55# LICENSE file in the root directory of this source tree.
66
77import warnings
8- from typing import Optional
8+ from typing import Optional , Tuple , Union
99
1010import torch
1111from pytorch3d .common .compat import meshgrid_ij
12+ from pytorch3d .ops import padded_to_packed
1213from pytorch3d .renderer .cameras import CamerasBase
13- from pytorch3d .renderer .implicit .utils import RayBundle
14+ from pytorch3d .renderer .implicit .utils import HeterogeneousRayBundle , RayBundle
1415from torch .nn import functional as F
1516
1617
@@ -73,6 +74,7 @@ def __init__(
7374 min_depth : float ,
7475 max_depth : float ,
7576 n_rays_per_image : Optional [int ] = None ,
77+ n_rays_total : Optional [int ] = None ,
7678 unit_directions : bool = False ,
7779 stratified_sampling : bool = False ,
7880 ) -> None :
@@ -88,6 +90,11 @@ def __init__(
8890 min_depth: The minimum depth of a ray-point.
8991 max_depth: The maximum depth of a ray-point.
9092 n_rays_per_image: If given, this amount of rays are sampled from the grid.
93+ n_rays_total: How many rays in total to sample from the cameras provided. The result
94+ is as if `n_rays_total` cameras were sampled with replacement from the
95+ cameras provided and for every camera one ray was sampled. If set, this disables
96+ `n_rays_per_image` and returns the HeterogeneousRayBundle with
97+ batch_size=n_rays_total.
9198 unit_directions: whether to normalize direction vectors in ray bundle.
9299 stratified_sampling: if True, performs stratified random sampling
93100 along the ray; otherwise takes ray points at deterministic offsets.
@@ -97,6 +104,7 @@ def __init__(
97104 self ._min_depth = min_depth
98105 self ._max_depth = max_depth
99106 self ._n_rays_per_image = n_rays_per_image
107+ self ._n_rays_total = n_rays_total
100108 self ._unit_directions = unit_directions
101109 self ._stratified_sampling = stratified_sampling
102110
@@ -125,8 +133,9 @@ def forward(
125133 n_rays_per_image : Optional [int ] = None ,
126134 n_pts_per_ray : Optional [int ] = None ,
127135 stratified_sampling : Optional [bool ] = None ,
136+ n_rays_total : Optional [int ] = None ,
128137 ** kwargs ,
129- ) -> RayBundle :
138+ ) -> Union [ RayBundle , HeterogeneousRayBundle ] :
130139 """
131140 Args:
132141 cameras: A batch of `batch_size` cameras from which the rays are emitted.
@@ -138,8 +147,15 @@ def forward(
138147 n_pts_per_ray: The number of points sampled along each ray.
139148 stratified_sampling: if set, overrides stratified_sampling provided
140149 in __init__.
150+ n_rays_total: How many rays in total to sample from the cameras provided. The result
151+ is as if `n_rays_total` cameras were sampled with replacement from the
152+ cameras provided and for every camera one ray was sampled. If set, this disables
153+ `n_rays_per_image` and returns the HeterogeneousRayBundle with
154+ batch_size=n_rays_total.
141155 Returns:
142- A named tuple RayBundle with the following fields:
156+ A named tuple RayBundle or dataclass HeterogeneousRayBundle with the
157+ following fields:
158+
143159 origins: A tensor of shape
144160 `(batch_size, s1, s2, 3)`
145161 denoting the locations of ray origins in the world coordinates.
@@ -153,23 +169,56 @@ def forward(
153169 `(batch_size, s1, s2, 2)`
154170 containing the 2D image coordinates of each ray or,
155171 if mask is given, `(batch_size, n, 1, 2)`
156- Here `s1, s2` refer to spatial dimensions. Unless the mask is
157- given, they equal `(image_height, image_width)`, otherwise `(n, 1)`,
158- where `n` is `n_rays_per_image` if provided, otherwise the minimum
159- cardinality of the mask in the batch.
172+ Here `s1, s2` refer to spatial dimensions.
173+ `(s1, s2)` refer to (highest priority first):
174+ - `(1, 1)` if `n_rays_total` is provided, (batch_size=n_rays_total)
175+ - `(n_rays_per_image, 1)` if `n_rays_per_image` is provided,
176+ - `(n, 1)` where n is the minimum cardinality of the mask
177+ in the batch if `mask` is provided
178+ - `(image_height, image_width)` if nothing from above is satisfied
179+
180+ `HeterogeneousRayBundle` has additional members:
181+ - camera_ids: tensor of shape (M,), where `M` is the number of unique sampled
182+ cameras. It represents unique ids of sampled cameras.
183+ - camera_counts: tensor of shape (M,), where `M` is the number of unique sampled
184+ cameras. Represents how many times each camera from `camera_ids` was sampled
185+
186+ `HeterogeneousRayBundle` is returned if `n_rays_total` is provided else `RayBundle`
187+ is returned.
160188 """
189+ n_rays_total = n_rays_total or self ._n_rays_total
190+ n_rays_per_image = n_rays_per_image or self ._n_rays_per_image
191+ assert (n_rays_total is None ) or (
192+ n_rays_per_image is None
193+ ), "`n_rays_total` and `n_rays_per_image` cannot both be defined."
194+ if n_rays_total :
195+ (
196+ cameras ,
197+ mask ,
198+ camera_ids , # unique ids of sampled cameras
199+ camera_counts , # number of times unique camera id was sampled
200+ # `n_rays_per_image` is equal to the max number of times a single camera
201+ # was sampled. We sample all cameras at `camera_ids` `n_rays_per_image` times
202+ # and then discard the unneeded rays.
203+ # pyre-ignore[9]
204+ n_rays_per_image ,
205+ ) = _sample_cameras_and_masks (n_rays_total , cameras , mask )
206+ else :
207+ camera_ids = torch .range (0 , len (cameras ), dtype = torch .long )
208+
161209 batch_size = cameras .R .shape [0 ]
162210 device = cameras .device
163211
164212 # expand the (H, W, 2) grid batch_size-times to (B, H, W, 2)
165213 xy_grid = self ._xy_grid .to (device ).expand (batch_size , - 1 , - 1 , - 1 )
166214
167- num_rays = n_rays_per_image or self ._n_rays_per_image
168- if mask is not None and num_rays is None :
215+ if mask is not None and n_rays_per_image is None :
169216 # if num rays not given, sample according to the smallest mask
170- num_rays = num_rays or mask .sum (dim = (1 , 2 )).min ().int ().item ()
217+ n_rays_per_image = (
218+ n_rays_per_image or mask .sum (dim = (1 , 2 )).min ().int ().item ()
219+ )
171220
172- if num_rays is not None :
221+ if n_rays_per_image is not None :
173222 if mask is not None :
174223 assert mask .shape == xy_grid .shape [:3 ]
175224 weights = mask .reshape (batch_size , - 1 )
@@ -181,7 +230,9 @@ def forward(
181230 weights = xy_grid .new_ones (batch_size , width * height )
182231 # pyre-fixme[6]: For 2nd param expected `int` but got `Union[bool,
183232 # float, int]`.
184- rays_idx = _safe_multinomial (weights , num_rays )[..., None ].expand (- 1 , - 1 , 2 )
233+ rays_idx = _safe_multinomial (weights , n_rays_per_image )[..., None ].expand (
234+ - 1 , - 1 , 2
235+ )
185236
186237 xy_grid = torch .gather (xy_grid .reshape (batch_size , - 1 , 2 ), 1 , rays_idx )[
187238 :, :, None
@@ -198,7 +249,7 @@ def forward(
198249 else self ._stratified_sampling
199250 )
200251
201- return _xy_to_ray_bundle (
252+ ray_bundle = _xy_to_ray_bundle (
202253 cameras ,
203254 xy_grid ,
204255 min_depth ,
@@ -208,6 +259,13 @@ def forward(
208259 stratified_sampling ,
209260 )
210261
262+ return (
263+ # pyre-ignore[61]
264+ _pack_ray_bundle (ray_bundle , camera_ids , camera_counts )
265+ if n_rays_total
266+ else ray_bundle
267+ )
268+
211269
212270class NDCMultinomialRaysampler (MultinomialRaysampler ):
213271 """
@@ -231,6 +289,7 @@ def __init__(
231289 min_depth : float ,
232290 max_depth : float ,
233291 n_rays_per_image : Optional [int ] = None ,
292+ n_rays_total : Optional [int ] = None ,
234293 unit_directions : bool = False ,
235294 stratified_sampling : bool = False ,
236295 ) -> None :
@@ -254,6 +313,7 @@ def __init__(
254313 min_depth = min_depth ,
255314 max_depth = max_depth ,
256315 n_rays_per_image = n_rays_per_image ,
316+ n_rays_total = n_rays_total ,
257317 unit_directions = unit_directions ,
258318 stratified_sampling = stratified_sampling ,
259319 )
@@ -281,6 +341,7 @@ def __init__(
281341 min_depth : float ,
282342 max_depth : float ,
283343 * ,
344+ n_rays_total : Optional [int ] = None ,
284345 unit_directions : bool = False ,
285346 stratified_sampling : bool = False ,
286347 ) -> None :
@@ -294,6 +355,11 @@ def __init__(
294355 n_pts_per_ray: The number of points sampled along each ray.
295356 min_depth: The minimum depth of each ray-point.
296357 max_depth: The maximum depth of each ray-point.
358+ n_rays_total: How many rays in total to sample from the cameras provided. The result
359+ is as if `n_rays_total` cameras were sampled with replacement from the
360+ cameras provided and for every camera one ray was sampled. If set, this disables
361+ `n_rays_per_image` and returns the HeterogeneousRayBundle with
362+ batch_size=n_rays_total.
297363 unit_directions: whether to normalize direction vectors in ray bundle.
298364 stratified_sampling: if True, performs stratified sampling in n_pts_per_ray
299365 bins for each ray; otherwise takes n_pts_per_ray deterministic points
@@ -308,6 +374,7 @@ def __init__(
308374 self ._n_pts_per_ray = n_pts_per_ray
309375 self ._min_depth = min_depth
310376 self ._max_depth = max_depth
377+ self ._n_rays_total = n_rays_total
311378 self ._unit_directions = unit_directions
312379 self ._stratified_sampling = stratified_sampling
313380
@@ -317,15 +384,16 @@ def forward(
317384 * ,
318385 stratified_sampling : Optional [bool ] = None ,
319386 ** kwargs ,
320- ) -> RayBundle :
387+ ) -> Union [ RayBundle , HeterogeneousRayBundle ] :
321388 """
322389 Args:
323390 cameras: A batch of `batch_size` cameras from which the rays are emitted.
324391 stratified_sampling: if set, overrides stratified_sampling provided
325392 in __init__.
326-
327393 Returns:
328- A named tuple RayBundle with the following fields:
394+ A named tuple `RayBundle` or dataclass `HeterogeneousRayBundle` with the
395+ following fields:
396+
329397 origins: A tensor of shape
330398 `(batch_size, n_rays_per_image, 3)`
331399 denoting the locations of ray origins in the world coordinates.
@@ -338,7 +406,31 @@ def forward(
338406 xys: A tensor of shape
339407 `(batch_size, n_rays_per_image, 2)`
340408 containing the 2D image coordinates of each ray.
409+ If `n_rays_total` is provided, `batch_size=n_rays_total` and
410+ `n_rays_per_image=1` and `HeterogeneousRayBundle` is returned else `RayBundle`
411+ is returned.
412+
413+ `HeterogeneousRayBundle` has additional members:
414+ - camera_ids: tensor of shape (M,), where `M` is the number of unique sampled
415+ cameras. It represents unique ids of sampled cameras.
416+ - camera_counts: tensor of shape (M,), where `M` is the number of unique sampled
417+ cameras. Represents how many times each camera from `camera_ids` was sampled
341418 """
419+ assert (self ._n_rays_total is None ) or (
420+ self ._n_rays_per_image is None
421+ ), "`self.n_rays_total` and `self.n_rays_per_image` cannot both be defined."
422+
423+ if self ._n_rays_total :
424+ (
425+ cameras ,
426+ _ ,
427+ camera_ids ,
428+ camera_counts ,
429+ n_rays_per_image ,
430+ ) = _sample_cameras_and_masks (self ._n_rays_total , cameras , None )
431+ else :
432+ camera_ids = torch .range (0 , len (cameras ), dtype = torch .long )
433+ n_rays_per_image = self ._n_rays_per_image
342434
343435 batch_size = cameras .R .shape [0 ]
344436
@@ -349,7 +441,7 @@ def forward(
349441 rays_xy = torch .cat (
350442 [
351443 torch .rand (
352- size = (batch_size , self . _n_rays_per_image , 1 ),
444+ size = (batch_size , n_rays_per_image , 1 ),
353445 dtype = torch .float32 ,
354446 device = device ,
355447 )
@@ -369,7 +461,7 @@ def forward(
369461 else self ._stratified_sampling
370462 )
371463
372- return _xy_to_ray_bundle (
464+ ray_bundle = _xy_to_ray_bundle (
373465 cameras ,
374466 rays_xy ,
375467 self ._min_depth ,
@@ -379,6 +471,13 @@ def forward(
379471 stratified_sampling ,
380472 )
381473
474+ return (
475+ # pyre-ignore[61]
476+ _pack_ray_bundle (ray_bundle , camera_ids , camera_counts )
477+ if self ._n_rays_total
478+ else ray_bundle
479+ )
480+
382481
383482# Settings for backwards compatibility
384483def GridRaysampler (
@@ -602,3 +701,74 @@ def _jiggle_within_stratas(bin_centers: torch.Tensor) -> torch.Tensor:
602701 # Samples in those intervals.
603702 jiggled = lower + (upper - lower ) * torch .rand_like (lower )
604703 return jiggled
704+
705+
def _sample_cameras_and_masks(
    n_samples: int, cameras: CamerasBase, mask: Optional[torch.Tensor] = None
) -> Tuple[
    CamerasBase, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor
]:
    """
    Draws `n_samples` camera indices at random with replacement and collapses
    the draws into the distinct indices together with how often each was drawn.

    Args:
        n_samples: how many camera and mask pairs to sample
        cameras: A batch of `batch_size` cameras from which the rays are emitted.
        mask: Optional. Should be of size (batch_size, image_height, image_width).
    Returns:
        tuple of a form (
            cameras indexed by the distinct sampled ids,
            mask indexed by the same ids (None if no mask was given),
            the distinct sampled ids as returned by `torch.unique`,
            number of times each distinct id was sampled,
            the largest of those counts (a 0-dim tensor),
        )
    """
    n_cameras = len(cameras)
    draws = torch.randint(
        low=0, high=n_cameras, size=(n_samples,), dtype=torch.long
    )
    # Duplicated draws are folded into one id plus a multiplicity.
    distinct_ids, multiplicities = torch.unique(draws, return_counts=True)

    picked_cameras = cameras[distinct_ids]
    picked_mask = None if mask is None else mask[distinct_ids]
    largest_count = multiplicities.max()

    return picked_cameras, picked_mask, distinct_ids, multiplicities, largest_count
740+
741+
def _pack_ray_bundle(
    ray_bundle: RayBundle, camera_ids: torch.Tensor, camera_counts: torch.Tensor
) -> HeterogeneousRayBundle:
    """
    Repack a ray bundle laid out as [n_cameras, max(rays_per_camera), ...]
    into the layout [total_num_rays, 1, ...].

    Args:
        ray_bundle: The ray bundle to pack.
        camera_ids: Unique ids of the cameras that were sampled; passed through
            unchanged to the result.
        camera_counts: One count per 'row' of `ray_bundle`, saying how many rays
            of that row are taken and packed. It is moved to the device of
            `ray_bundle.origins` before use.
    Returns:
        HeterogeneousRayBundle where batch_size=sum(camera_counts) and
        n_rays_per_image=1
    """
    counts = camera_counts.to(ray_bundle.origins.device)

    # Exclusive prefix sum: the packed offset at which each row's rays begin.
    running_total = torch.cumsum(counts, dim=0, dtype=torch.long)
    leading_zero = counts.new_zeros((1,), dtype=torch.long)
    first_idxs = torch.cat((leading_zero, running_total[:-1]))
    total_rays = int(counts.sum())

    # Every per-ray field is packed the same way, then given a singleton
    # rays-per-image dimension.
    packed_fields = {
        field: padded_to_packed(getattr(ray_bundle, field), first_idxs, total_rays)[
            :, None
        ]
        for field in ("origins", "directions", "lengths", "xys")
    }
    return HeterogeneousRayBundle(
        camera_ids=camera_ids,
        camera_counts=counts,
        **packed_fields,
    )
0 commit comments