Skip to content

Commit 7a86ac5

Browse files
committed
first commit
0 parents  commit 7a86ac5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+16906
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
hub/checkpoints/*

.vscode/settings.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"[python]": {
3+
"editor.defaultFormatter": "ms-python.black-formatter"
4+
},
5+
"python.formatting.provider": "none"
6+
}

README.md

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Articulated Object Neural Radiance Field
2+
3+
# :computer: Overview
4+
Experimental Repo for Modelling Neural Radiance Field for Articulated Objects. Currently Supported Experiments:
5+
6+
- Sapien Dataset (Single Instance Overfitting)
7+
8+
- Sapien Dataset (Single Instance Articulated Overfitting)
9+
10+
- Sapien Dataset (Single Instance Auto-Encoder Articulated NeRF)
11+
12+
- Future: Sapien Dataset (Single Instance Auto-Decoder Articulated NeRF)
13+
14+
15+
# :computer: Installation
16+
17+
## Hardware
18+
19+
* OS: Ubuntu 18.04
20+
* NVIDIA GPU with **CUDA>=10.2** (tested with 1 RTX2080Ti)
21+
22+
## Software
23+
24+
* Clone this repo by `git clone --recursive https://github.com/zubair-irshad/articulated-object-nerf`
25+
* Python>=3.7 (installation via [anaconda](https://www.anaconda.com/distribution/) is recommended, use `conda create -n ao-nerf python=3.7` to create a conda environment and activate it by `conda activate nerf_pl`)
26+
* Python libraries
27+
* Install core requirements by `pip install -r requirements.txt`
28+
29+
# :key: Training
30+
31+
32+
# :key: Evaluation
33+
34+
35+
# :key: Generate Sapien Dataset
36+
* Coming Soon
37+

datasets/ray_utils.py

Lines changed: 325 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,325 @@
1+
import torch
2+
from kornia import create_meshgrid
3+
import numpy as np
4+
import matplotlib.pyplot as plt
5+
import numba as nb
6+
7+
def homogenise_np(p):
8+
_1 = np.ones((p.shape[0], 1), dtype=p.dtype)
9+
return np.concatenate([p, _1], axis=-1)
10+
11+
12+
def inside_axis_aligned_box(pts, box_min, box_max):
13+
return torch.all(torch.cat([pts >= box_min, pts <= box_max], dim=1), dim=1)
14+
15+
@nb.jit(nopython=True)
16+
def bbox_intersection_batch(bounds, rays_o, rays_d):
17+
N_rays = rays_o.shape[0]
18+
all_hit = np.empty((N_rays))
19+
all_near = np.empty((N_rays))
20+
all_far = np.empty((N_rays))
21+
for idx, (o, d) in enumerate(zip(rays_o, rays_d)):
22+
hit, near, far = bbox_intersection(bounds, o, d)
23+
# if hit == True:
24+
# print("hit", hit)
25+
all_hit[idx] = hit
26+
all_near[idx] = near
27+
all_far[idx] = far
28+
# return (h*w), (h*w, 3), (h*w, 3)
29+
return all_hit, all_near, all_far
30+
31+
@nb.jit(nopython=True)
32+
def bbox_intersection(bounds, orig, dir):
33+
# FIXME: currently, it is not working properly if the ray origin is inside the bounding box
34+
# https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
35+
# handle divide by zero
36+
dir[dir == 0] = 1.0e-14
37+
invdir = 1 / dir
38+
sign = (invdir < 0).astype(np.int64)
39+
40+
tmin = (bounds[sign[0]][0] - orig[0]) * invdir[0]
41+
tmax = (bounds[1 - sign[0]][0] - orig[0]) * invdir[0]
42+
43+
tymin = (bounds[sign[1]][1] - orig[1]) * invdir[1]
44+
tymax = (bounds[1 - sign[1]][1] - orig[1]) * invdir[1]
45+
46+
if tmin > tymax or tymin > tmax:
47+
return False, 0, 0
48+
if tymin > tmin:
49+
tmin = tymin
50+
if tymax < tmax:
51+
tmax = tymax
52+
53+
tzmin = (bounds[sign[2]][2] - orig[2]) * invdir[2]
54+
tzmax = (bounds[1 - sign[2]][2] - orig[2]) * invdir[2]
55+
56+
if tmin > tzmax or tzmin > tmax:
57+
return False, 0, 0
58+
if tzmin > tmin:
59+
tmin = tzmin
60+
if tzmax < tmax:
61+
tmax = tzmax
62+
# additionally, when the orig is inside the box, we return False
63+
if tmin < 0 or tmax < 0:
64+
return False, 0, 0
65+
return True, tmin, tmax
66+
67+
def homogenise_torch(p):
68+
_1 = torch.ones_like(p[:, [0]])
69+
return torch.cat([p, _1], dim=-1)
70+
71+
def get_ray_directions(H, W, focal):
72+
"""
73+
Get ray directions for all pixels in camera coordinate.
74+
Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
75+
ray-tracing-generating-camera-rays/standard-coordinate-systems
76+
77+
Inputs:
78+
H, W, focal: image height, width and focal length
79+
80+
Outputs:
81+
directions: (H, W, 3), the direction of the rays in camera coordinate
82+
"""
83+
grid = create_meshgrid(H, W, normalized_coordinates=False)[0]
84+
i, j = grid.unbind(-1)
85+
# the direction here is without +0.5 pixel centering as calibration is not so accurate
86+
# see https://github.com/bmild/nerf/issues/24
87+
directions = \
88+
torch.stack([(i-W/2)/focal, -(j-H/2)/focal, -torch.ones_like(i)], -1) # (H, W, 3)
89+
90+
return directions
91+
92+
93+
def get_rays_background(directions, c2w, coords):
94+
"""
95+
Get ray origin and normalized directions in world coordinate for all pixels in one image.
96+
Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
97+
ray-tracing-generating-camera-rays/standard-coordinate-systems
98+
99+
Inputs:
100+
directions: (H, W, 3) precomputed ray directions in camera coordinate
101+
c2w: (3, 4) transformation matrix from camera coordinate to world coordinate
102+
103+
Outputs:
104+
rays_o: (H*W, 3), the origin of the rays in world coordinate
105+
rays_d: (H*W, 3), the normalized direction of the rays in world coordinate
106+
"""
107+
# Rotate ray directions from camera coordinate to the world coordinate
108+
rays_d = directions @ c2w[:, :3].T # (H, W, 3)
109+
rays_d /= torch.norm(rays_d, dim=-1, keepdim=True)
110+
# The origin of all rays is the camera origin in world coordinate
111+
rays_o = c2w[:, 3].expand(rays_d.shape) # (H, W, 3)
112+
113+
rays_o = rays_o[coords[:, 0], coords[:, 1]]
114+
rays_d = rays_d[coords[:, 0], coords[:, 1]]
115+
116+
return rays_o, rays_d
117+
118+
def get_rays(directions, c2w, output_view_dirs = False, output_radii = False):
119+
"""
120+
Get ray origin and normalized directions in world coordinate for all pixels in one image.
121+
Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
122+
ray-tracing-generating-camera-rays/standard-coordinate-systems
123+
124+
Inputs:
125+
directions: (H, W, 3) precomputed ray directions in camera coordinate
126+
c2w: (3, 4) transformation matrix from camera coordinate to world coordinate
127+
128+
Outputs:
129+
rays_o: (H*W, 3), the origin of the rays in world coordinate
130+
rays_d: (H*W, 3), the normalized direction of the rays in world coordinate
131+
"""
132+
# Rotate ray directions from camera coordinate to the world coordinate
133+
rays_d = directions @ c2w[:, :3].T # (H, W, 3)
134+
#rays_d /= torch.norm(rays_d, dim=-1, keepdim=True)
135+
# The origin of all rays is the camera origin in world coordinate
136+
rays_o = c2w[:, 3].expand(rays_d.shape) # (H, W, 3)
137+
138+
if output_radii:
139+
rays_d_orig = directions @ c2w[:, :3].T
140+
dx = torch.sqrt(torch.sum((rays_d_orig[:-1, :, :] - rays_d_orig[1:, :, :]) ** 2, dim=-1))
141+
dx = torch.cat([dx, dx[-2:-1, :]], dim=0)
142+
radius = dx[..., None] * 2 / torch.sqrt(torch.tensor(12, dtype=torch.int8))
143+
radius = radius.reshape(-1)
144+
145+
if output_view_dirs:
146+
viewdirs = rays_d
147+
viewdirs /= torch.norm(viewdirs, dim=-1, keepdim=True)
148+
rays_d = rays_d.view(-1, 3)
149+
rays_o = rays_o.view(-1, 3)
150+
viewdirs = viewdirs.view(-1, 3)
151+
if output_radii:
152+
return rays_o, viewdirs, rays_d, radius
153+
else:
154+
return rays_o, viewdirs, rays_d
155+
else:
156+
rays_d /= torch.norm(rays_d, dim=-1, keepdim=True)
157+
rays_d = rays_d.view(-1, 3)
158+
rays_o = rays_o.view(-1, 3)
159+
return rays_o, rays_d
160+
161+
162+
def transform_rays_camera(rays_o, rays_d, c2w):
163+
"""
164+
Get ray origin and normalized directions in world coordinate for all pixels in one image.
165+
Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
166+
ray-tracing-generating-camera-rays/standard-coordinate-systems
167+
168+
Inputs:
169+
directions: (H, W, 3) precomputed ray directions in camera coordinate
170+
c2w: (3, 4) transformation matrix from camera coordinate to world coordinate
171+
172+
Outputs:
173+
rays_o: (H*W, 3), the origin of the rays in world coordinate
174+
rays_d: (H*W, 3), the normalized direction of the rays in world coordinate
175+
"""
176+
# Rotate ray directions from camera coordinate to the world coordinate
177+
rays_d = rays_d @ c2w[:, :3].T # (H, W, 3)
178+
rays_d /= torch.norm(rays_d, dim=-1, keepdim=True)
179+
# The origin of all rays is the camera origin in world coordinate
180+
rays_o = c2w[:, 3].expand(rays_d.shape) + rays_o # (H, W, 3)
181+
182+
rays_d = rays_d.view(-1, 3)
183+
rays_o = rays_o.view(-1, 3)
184+
185+
return rays_o, rays_d
186+
187+
def get_ndc_rays(H, W, focal, near, rays_o, rays_d):
188+
"""
189+
Transform rays from world coordinate to NDC.
190+
NDC: Space such that the canvas is a cube with sides [-1, 1] in each axis.
191+
For detailed derivation, please see:
192+
http://www.songho.ca/opengl/gl_projectionmatrix.html
193+
https://github.com/bmild/nerf/files/4451808/ndc_derivation.pdf
194+
195+
In practice, use NDC "if and only if" the scene is unbounded (has a large depth).
196+
See https://github.com/bmild/nerf/issues/18
197+
198+
Inputs:
199+
H, W, focal: image height, width and focal length
200+
near: (N_rays) or float, the depths of the near plane
201+
rays_o: (N_rays, 3), the origin of the rays in world coordinate
202+
rays_d: (N_rays, 3), the direction of the rays in world coordinate
203+
204+
Outputs:
205+
rays_o: (N_rays, 3), the origin of the rays in NDC
206+
rays_d: (N_rays, 3), the direction of the rays in NDC
207+
"""
208+
# Shift ray origins to near plane
209+
t = -(near + rays_o[...,2]) / rays_d[...,2]
210+
rays_o = rays_o + t[...,None] * rays_d
211+
212+
# Store some intermediate homogeneous results
213+
ox_oz = rays_o[...,0] / rays_o[...,2]
214+
oy_oz = rays_o[...,1] / rays_o[...,2]
215+
216+
# Projection
217+
o0 = -1./(W/(2.*focal)) * ox_oz
218+
o1 = -1./(H/(2.*focal)) * oy_oz
219+
o2 = 1. + 2. * near / rays_o[...,2]
220+
221+
d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - ox_oz)
222+
d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - oy_oz)
223+
d2 = 1 - o2
224+
225+
rays_o = torch.stack([o0, o1, o2], -1) # (B, 3)
226+
rays_d = torch.stack([d0, d1, d2], -1) # (B, 3)
227+
228+
return rays_o, rays_d
229+
230+
def world_to_ndc(rotated_pcds, W, H, focal, near):
231+
232+
ox_oz = rotated_pcds[...,0] / rotated_pcds[...,2]
233+
oy_oz = rotated_pcds[...,1] / rotated_pcds[...,2]
234+
235+
# Projection
236+
o0 = -1./(W/(2.*focal)) * ox_oz
237+
o1 = -1./(H/(2.*focal)) * oy_oz
238+
o2 = 1. + 2. * near / rotated_pcds[...,2]
239+
240+
# oz_ox = rotated_pcds[...,2] / rotated_pcds[...,0]
241+
# oz_oy = rotated_pcds[...,2] / rotated_pcds[...,1]
242+
# # Projection
243+
# o0 = -(W/(2.*focal)) * oz_ox
244+
# o1 = -(H/(2.*focal)) * oz_oy
245+
# o2 = 1. + 2. * near / rotated_pcds[...,2]
246+
# print("o1.shape", o1.shape)
247+
rotated_pcd = np.concatenate((np.expand_dims(o0, axis=-1), np.expand_dims(o1, axis=-1), np.expand_dims(o2, axis=-1)), -1)
248+
return rotated_pcd
249+
250+
251+
252+
def get_rays_segmented(masks, class_ids, rays_o, rays_d, W, H, N_rays):
253+
seg_mask = np.zeros([H, W])
254+
for i in range(len(class_ids)):
255+
seg_mask[masks[:,:,i] > 0] = np.array(class_ids)[i]
256+
# print("classIds", class_ids)
257+
# print("seg masks", (seg_mask>0).flatten().shape, (seg_mask>0).shape)
258+
# print("(seg_mask>0).flatten()", np.count_nonzero((seg_mask>0).flatten()))
259+
# print("seg mask ", np.count_nonzero(seg_mask))
260+
# print("(seg_mask>0).flatten()", (seg_mask>0).flatten())
261+
# print("seg mask", seg_mask)
262+
# plt.imshow(seg_mask)
263+
# plt.show()
264+
265+
rays_rgb_obj = []
266+
rays_rgb_obj_dir = []
267+
class_ids.sort()
268+
269+
select_inds = []
270+
for i in range(len(class_ids)):
271+
rays_on_obj = np.where(seg_mask.flatten() == class_ids[i])[0]
272+
print("rays_on_obj", rays_on_obj.shape)
273+
rays_on_obj = rays_on_obj[np.random.choice(rays_on_obj.shape[0], N_rays)]
274+
select_inds.append(rays_on_obj)
275+
obj_mask = np.zeros(len(rays_o), np.bool)
276+
obj_mask[rays_on_obj] = 1
277+
rays_rgb_obj.append(rays_o[obj_mask])
278+
rays_rgb_obj_dir.append(rays_d[obj_mask])
279+
select_inds = np.concatenate(select_inds, axis=0)
280+
obj_mask = np.zeros(len(rays_o), np.bool)
281+
obj_mask[select_inds] = 1
282+
283+
# for i in range(len(class_ids)):
284+
# rays_on_obj = np.where(seg_mask.flatten() == class_ids[i])[0]
285+
# obj_mask = np.zeros(len(rays_o), np.bool)
286+
# obj_mask[rays_on_obj] = 1
287+
# rays_rgb_obj.append(rays_o[obj_mask])
288+
# rays_rgb_obj_dir.append(rays_d[obj_mask])
289+
290+
291+
# N_rays = min(N_rays, H * W)
292+
# select_inds = []
293+
# for b in range(len(class_ids)):
294+
# fg_inds = np.nonzero(seg_mask.flatten() == class_ids[b])
295+
# fg_inds = np.transpose(np.asarray(fg_inds))
296+
# fg_inds = fg_inds[np.random.choice(fg_inds.shape[0], N_rays)]
297+
# select_inds.append(fg_inds)
298+
# select_inds = np.concatenate(select_inds, axis=0)
299+
# j, i = select_inds[..., 0], select_inds[..., 1]
300+
301+
# select_inds = j * W + i
302+
303+
return rays_rgb_obj, rays_rgb_obj_dir, class_ids, (seg_mask>0).flatten()
304+
305+
306+
def convert_pose_PD_to_NeRF(C2W):
307+
308+
flip_axes = np.array([[1,0,0,0],
309+
[0,0,-1,0],
310+
[0,1,0,0],
311+
[0,0,0,1]])
312+
C2W = np.matmul(C2W, flip_axes)
313+
return C2W
314+
315+
def get_rays_mvs(H, W, focal, c2w):
316+
ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H), torch.linspace(0, W - 1, W)) # pytorch's meshgrid has indexing='ij'
317+
ys, xs = ys.reshape(-1), xs.reshape(-1)
318+
319+
dirs = torch.stack([(xs-W/2)/focal, (ys-H/2)/focal, torch.ones_like(xs)], -1) # use 1 instead of -1
320+
rays_d = dirs @ c2w[:3,:3].t() # dot product, equals to: [c2w.dot(dir) for dir in dirs]
321+
# Translate camera frame's origin to the world frame. It is the origin of all rays.
322+
rays_o = c2w[:, 3].expand(rays_d.shape) # (H, W, 3)
323+
rays_d = rays_d.view(-1, 3)
324+
rays_o = rays_o.view(-1, 3)
325+
return rays_o, rays_d

0 commit comments

Comments
 (0)