1
+ import torch
2
+ from kornia import create_meshgrid
3
+ import numpy as np
4
+ import matplotlib .pyplot as plt
5
+ import numba as nb
6
+
7
+ def homogenise_np (p ):
8
+ _1 = np .ones ((p .shape [0 ], 1 ), dtype = p .dtype )
9
+ return np .concatenate ([p , _1 ], axis = - 1 )
10
+
11
+
12
+ def inside_axis_aligned_box (pts , box_min , box_max ):
13
+ return torch .all (torch .cat ([pts >= box_min , pts <= box_max ], dim = 1 ), dim = 1 )
14
+
15
+ @nb .jit (nopython = True )
16
+ def bbox_intersection_batch (bounds , rays_o , rays_d ):
17
+ N_rays = rays_o .shape [0 ]
18
+ all_hit = np .empty ((N_rays ))
19
+ all_near = np .empty ((N_rays ))
20
+ all_far = np .empty ((N_rays ))
21
+ for idx , (o , d ) in enumerate (zip (rays_o , rays_d )):
22
+ hit , near , far = bbox_intersection (bounds , o , d )
23
+ # if hit == True:
24
+ # print("hit", hit)
25
+ all_hit [idx ] = hit
26
+ all_near [idx ] = near
27
+ all_far [idx ] = far
28
+ # return (h*w), (h*w, 3), (h*w, 3)
29
+ return all_hit , all_near , all_far
30
+
31
+ @nb .jit (nopython = True )
32
+ def bbox_intersection (bounds , orig , dir ):
33
+ # FIXME: currently, it is not working properly if the ray origin is inside the bounding box
34
+ # https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
35
+ # handle divide by zero
36
+ dir [dir == 0 ] = 1.0e-14
37
+ invdir = 1 / dir
38
+ sign = (invdir < 0 ).astype (np .int64 )
39
+
40
+ tmin = (bounds [sign [0 ]][0 ] - orig [0 ]) * invdir [0 ]
41
+ tmax = (bounds [1 - sign [0 ]][0 ] - orig [0 ]) * invdir [0 ]
42
+
43
+ tymin = (bounds [sign [1 ]][1 ] - orig [1 ]) * invdir [1 ]
44
+ tymax = (bounds [1 - sign [1 ]][1 ] - orig [1 ]) * invdir [1 ]
45
+
46
+ if tmin > tymax or tymin > tmax :
47
+ return False , 0 , 0
48
+ if tymin > tmin :
49
+ tmin = tymin
50
+ if tymax < tmax :
51
+ tmax = tymax
52
+
53
+ tzmin = (bounds [sign [2 ]][2 ] - orig [2 ]) * invdir [2 ]
54
+ tzmax = (bounds [1 - sign [2 ]][2 ] - orig [2 ]) * invdir [2 ]
55
+
56
+ if tmin > tzmax or tzmin > tmax :
57
+ return False , 0 , 0
58
+ if tzmin > tmin :
59
+ tmin = tzmin
60
+ if tzmax < tmax :
61
+ tmax = tzmax
62
+ # additionally, when the orig is inside the box, we return False
63
+ if tmin < 0 or tmax < 0 :
64
+ return False , 0 , 0
65
+ return True , tmin , tmax
66
+
67
+ def homogenise_torch (p ):
68
+ _1 = torch .ones_like (p [:, [0 ]])
69
+ return torch .cat ([p , _1 ], dim = - 1 )
70
+
71
+ def get_ray_directions (H , W , focal ):
72
+ """
73
+ Get ray directions for all pixels in camera coordinate.
74
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
75
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
76
+
77
+ Inputs:
78
+ H, W, focal: image height, width and focal length
79
+
80
+ Outputs:
81
+ directions: (H, W, 3), the direction of the rays in camera coordinate
82
+ """
83
+ grid = create_meshgrid (H , W , normalized_coordinates = False )[0 ]
84
+ i , j = grid .unbind (- 1 )
85
+ # the direction here is without +0.5 pixel centering as calibration is not so accurate
86
+ # see https://github.com/bmild/nerf/issues/24
87
+ directions = \
88
+ torch .stack ([(i - W / 2 )/ focal , - (j - H / 2 )/ focal , - torch .ones_like (i )], - 1 ) # (H, W, 3)
89
+
90
+ return directions
91
+
92
+
93
+ def get_rays_background (directions , c2w , coords ):
94
+ """
95
+ Get ray origin and normalized directions in world coordinate for all pixels in one image.
96
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
97
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
98
+
99
+ Inputs:
100
+ directions: (H, W, 3) precomputed ray directions in camera coordinate
101
+ c2w: (3, 4) transformation matrix from camera coordinate to world coordinate
102
+
103
+ Outputs:
104
+ rays_o: (H*W, 3), the origin of the rays in world coordinate
105
+ rays_d: (H*W, 3), the normalized direction of the rays in world coordinate
106
+ """
107
+ # Rotate ray directions from camera coordinate to the world coordinate
108
+ rays_d = directions @ c2w [:, :3 ].T # (H, W, 3)
109
+ rays_d /= torch .norm (rays_d , dim = - 1 , keepdim = True )
110
+ # The origin of all rays is the camera origin in world coordinate
111
+ rays_o = c2w [:, 3 ].expand (rays_d .shape ) # (H, W, 3)
112
+
113
+ rays_o = rays_o [coords [:, 0 ], coords [:, 1 ]]
114
+ rays_d = rays_d [coords [:, 0 ], coords [:, 1 ]]
115
+
116
+ return rays_o , rays_d
117
+
118
+ def get_rays (directions , c2w , output_view_dirs = False , output_radii = False ):
119
+ """
120
+ Get ray origin and normalized directions in world coordinate for all pixels in one image.
121
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
122
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
123
+
124
+ Inputs:
125
+ directions: (H, W, 3) precomputed ray directions in camera coordinate
126
+ c2w: (3, 4) transformation matrix from camera coordinate to world coordinate
127
+
128
+ Outputs:
129
+ rays_o: (H*W, 3), the origin of the rays in world coordinate
130
+ rays_d: (H*W, 3), the normalized direction of the rays in world coordinate
131
+ """
132
+ # Rotate ray directions from camera coordinate to the world coordinate
133
+ rays_d = directions @ c2w [:, :3 ].T # (H, W, 3)
134
+ #rays_d /= torch.norm(rays_d, dim=-1, keepdim=True)
135
+ # The origin of all rays is the camera origin in world coordinate
136
+ rays_o = c2w [:, 3 ].expand (rays_d .shape ) # (H, W, 3)
137
+
138
+ if output_radii :
139
+ rays_d_orig = directions @ c2w [:, :3 ].T
140
+ dx = torch .sqrt (torch .sum ((rays_d_orig [:- 1 , :, :] - rays_d_orig [1 :, :, :]) ** 2 , dim = - 1 ))
141
+ dx = torch .cat ([dx , dx [- 2 :- 1 , :]], dim = 0 )
142
+ radius = dx [..., None ] * 2 / torch .sqrt (torch .tensor (12 , dtype = torch .int8 ))
143
+ radius = radius .reshape (- 1 )
144
+
145
+ if output_view_dirs :
146
+ viewdirs = rays_d
147
+ viewdirs /= torch .norm (viewdirs , dim = - 1 , keepdim = True )
148
+ rays_d = rays_d .view (- 1 , 3 )
149
+ rays_o = rays_o .view (- 1 , 3 )
150
+ viewdirs = viewdirs .view (- 1 , 3 )
151
+ if output_radii :
152
+ return rays_o , viewdirs , rays_d , radius
153
+ else :
154
+ return rays_o , viewdirs , rays_d
155
+ else :
156
+ rays_d /= torch .norm (rays_d , dim = - 1 , keepdim = True )
157
+ rays_d = rays_d .view (- 1 , 3 )
158
+ rays_o = rays_o .view (- 1 , 3 )
159
+ return rays_o , rays_d
160
+
161
+
162
+ def transform_rays_camera (rays_o , rays_d , c2w ):
163
+ """
164
+ Get ray origin and normalized directions in world coordinate for all pixels in one image.
165
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
166
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
167
+
168
+ Inputs:
169
+ directions: (H, W, 3) precomputed ray directions in camera coordinate
170
+ c2w: (3, 4) transformation matrix from camera coordinate to world coordinate
171
+
172
+ Outputs:
173
+ rays_o: (H*W, 3), the origin of the rays in world coordinate
174
+ rays_d: (H*W, 3), the normalized direction of the rays in world coordinate
175
+ """
176
+ # Rotate ray directions from camera coordinate to the world coordinate
177
+ rays_d = rays_d @ c2w [:, :3 ].T # (H, W, 3)
178
+ rays_d /= torch .norm (rays_d , dim = - 1 , keepdim = True )
179
+ # The origin of all rays is the camera origin in world coordinate
180
+ rays_o = c2w [:, 3 ].expand (rays_d .shape ) + rays_o # (H, W, 3)
181
+
182
+ rays_d = rays_d .view (- 1 , 3 )
183
+ rays_o = rays_o .view (- 1 , 3 )
184
+
185
+ return rays_o , rays_d
186
+
187
+ def get_ndc_rays (H , W , focal , near , rays_o , rays_d ):
188
+ """
189
+ Transform rays from world coordinate to NDC.
190
+ NDC: Space such that the canvas is a cube with sides [-1, 1] in each axis.
191
+ For detailed derivation, please see:
192
+ http://www.songho.ca/opengl/gl_projectionmatrix.html
193
+ https://github.com/bmild/nerf/files/4451808/ndc_derivation.pdf
194
+
195
+ In practice, use NDC "if and only if" the scene is unbounded (has a large depth).
196
+ See https://github.com/bmild/nerf/issues/18
197
+
198
+ Inputs:
199
+ H, W, focal: image height, width and focal length
200
+ near: (N_rays) or float, the depths of the near plane
201
+ rays_o: (N_rays, 3), the origin of the rays in world coordinate
202
+ rays_d: (N_rays, 3), the direction of the rays in world coordinate
203
+
204
+ Outputs:
205
+ rays_o: (N_rays, 3), the origin of the rays in NDC
206
+ rays_d: (N_rays, 3), the direction of the rays in NDC
207
+ """
208
+ # Shift ray origins to near plane
209
+ t = - (near + rays_o [...,2 ]) / rays_d [...,2 ]
210
+ rays_o = rays_o + t [...,None ] * rays_d
211
+
212
+ # Store some intermediate homogeneous results
213
+ ox_oz = rays_o [...,0 ] / rays_o [...,2 ]
214
+ oy_oz = rays_o [...,1 ] / rays_o [...,2 ]
215
+
216
+ # Projection
217
+ o0 = - 1. / (W / (2. * focal )) * ox_oz
218
+ o1 = - 1. / (H / (2. * focal )) * oy_oz
219
+ o2 = 1. + 2. * near / rays_o [...,2 ]
220
+
221
+ d0 = - 1. / (W / (2. * focal )) * (rays_d [...,0 ]/ rays_d [...,2 ] - ox_oz )
222
+ d1 = - 1. / (H / (2. * focal )) * (rays_d [...,1 ]/ rays_d [...,2 ] - oy_oz )
223
+ d2 = 1 - o2
224
+
225
+ rays_o = torch .stack ([o0 , o1 , o2 ], - 1 ) # (B, 3)
226
+ rays_d = torch .stack ([d0 , d1 , d2 ], - 1 ) # (B, 3)
227
+
228
+ return rays_o , rays_d
229
+
230
+ def world_to_ndc (rotated_pcds , W , H , focal , near ):
231
+
232
+ ox_oz = rotated_pcds [...,0 ] / rotated_pcds [...,2 ]
233
+ oy_oz = rotated_pcds [...,1 ] / rotated_pcds [...,2 ]
234
+
235
+ # Projection
236
+ o0 = - 1. / (W / (2. * focal )) * ox_oz
237
+ o1 = - 1. / (H / (2. * focal )) * oy_oz
238
+ o2 = 1. + 2. * near / rotated_pcds [...,2 ]
239
+
240
+ # oz_ox = rotated_pcds[...,2] / rotated_pcds[...,0]
241
+ # oz_oy = rotated_pcds[...,2] / rotated_pcds[...,1]
242
+ # # Projection
243
+ # o0 = -(W/(2.*focal)) * oz_ox
244
+ # o1 = -(H/(2.*focal)) * oz_oy
245
+ # o2 = 1. + 2. * near / rotated_pcds[...,2]
246
+ # print("o1.shape", o1.shape)
247
+ rotated_pcd = np .concatenate ((np .expand_dims (o0 , axis = - 1 ), np .expand_dims (o1 , axis = - 1 ), np .expand_dims (o2 , axis = - 1 )), - 1 )
248
+ return rotated_pcd
249
+
250
+
251
+
252
+ def get_rays_segmented (masks , class_ids , rays_o , rays_d , W , H , N_rays ):
253
+ seg_mask = np .zeros ([H , W ])
254
+ for i in range (len (class_ids )):
255
+ seg_mask [masks [:,:,i ] > 0 ] = np .array (class_ids )[i ]
256
+ # print("classIds", class_ids)
257
+ # print("seg masks", (seg_mask>0).flatten().shape, (seg_mask>0).shape)
258
+ # print("(seg_mask>0).flatten()", np.count_nonzero((seg_mask>0).flatten()))
259
+ # print("seg mask ", np.count_nonzero(seg_mask))
260
+ # print("(seg_mask>0).flatten()", (seg_mask>0).flatten())
261
+ # print("seg mask", seg_mask)
262
+ # plt.imshow(seg_mask)
263
+ # plt.show()
264
+
265
+ rays_rgb_obj = []
266
+ rays_rgb_obj_dir = []
267
+ class_ids .sort ()
268
+
269
+ select_inds = []
270
+ for i in range (len (class_ids )):
271
+ rays_on_obj = np .where (seg_mask .flatten () == class_ids [i ])[0 ]
272
+ print ("rays_on_obj" , rays_on_obj .shape )
273
+ rays_on_obj = rays_on_obj [np .random .choice (rays_on_obj .shape [0 ], N_rays )]
274
+ select_inds .append (rays_on_obj )
275
+ obj_mask = np .zeros (len (rays_o ), np .bool )
276
+ obj_mask [rays_on_obj ] = 1
277
+ rays_rgb_obj .append (rays_o [obj_mask ])
278
+ rays_rgb_obj_dir .append (rays_d [obj_mask ])
279
+ select_inds = np .concatenate (select_inds , axis = 0 )
280
+ obj_mask = np .zeros (len (rays_o ), np .bool )
281
+ obj_mask [select_inds ] = 1
282
+
283
+ # for i in range(len(class_ids)):
284
+ # rays_on_obj = np.where(seg_mask.flatten() == class_ids[i])[0]
285
+ # obj_mask = np.zeros(len(rays_o), np.bool)
286
+ # obj_mask[rays_on_obj] = 1
287
+ # rays_rgb_obj.append(rays_o[obj_mask])
288
+ # rays_rgb_obj_dir.append(rays_d[obj_mask])
289
+
290
+
291
+ # N_rays = min(N_rays, H * W)
292
+ # select_inds = []
293
+ # for b in range(len(class_ids)):
294
+ # fg_inds = np.nonzero(seg_mask.flatten() == class_ids[b])
295
+ # fg_inds = np.transpose(np.asarray(fg_inds))
296
+ # fg_inds = fg_inds[np.random.choice(fg_inds.shape[0], N_rays)]
297
+ # select_inds.append(fg_inds)
298
+ # select_inds = np.concatenate(select_inds, axis=0)
299
+ # j, i = select_inds[..., 0], select_inds[..., 1]
300
+
301
+ # select_inds = j * W + i
302
+
303
+ return rays_rgb_obj , rays_rgb_obj_dir , class_ids , (seg_mask > 0 ).flatten ()
304
+
305
+
306
+ def convert_pose_PD_to_NeRF (C2W ):
307
+
308
+ flip_axes = np .array ([[1 ,0 ,0 ,0 ],
309
+ [0 ,0 ,- 1 ,0 ],
310
+ [0 ,1 ,0 ,0 ],
311
+ [0 ,0 ,0 ,1 ]])
312
+ C2W = np .matmul (C2W , flip_axes )
313
+ return C2W
314
+
315
+ def get_rays_mvs (H , W , focal , c2w ):
316
+ ys , xs = torch .meshgrid (torch .linspace (0 , H - 1 , H ), torch .linspace (0 , W - 1 , W )) # pytorch's meshgrid has indexing='ij'
317
+ ys , xs = ys .reshape (- 1 ), xs .reshape (- 1 )
318
+
319
+ dirs = torch .stack ([(xs - W / 2 )/ focal , (ys - H / 2 )/ focal , torch .ones_like (xs )], - 1 ) # use 1 instead of -1
320
+ rays_d = dirs @ c2w [:3 ,:3 ].t () # dot product, equals to: [c2w.dot(dir) for dir in dirs]
321
+ # Translate camera frame's origin to the world frame. It is the origin of all rays.
322
+ rays_o = c2w [:, 3 ].expand (rays_d .shape ) # (H, W, 3)
323
+ rays_d = rays_d .view (- 1 , 3 )
324
+ rays_o = rays_o .view (- 1 , 3 )
325
+ return rays_o , rays_d
0 commit comments