 from typing import Any, Dict, Optional

 import torch
+from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
 from pytorch3d.implicitron.tools import metric_utils as utils
 from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
+from pytorch3d.ops import packed_to_padded, padded_to_packed
 from pytorch3d.renderer import utils as rend_utils

 from .renderer.base import RendererOutput
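The two newly imported ops do the packed/padded conversion that the rest of this diff relies on. A minimal sketch of their round-trip behavior, assuming toy shapes and ray counts of my own choosing (and assuming a PyTorch3D version where these ops accept multi-dimensional inputs, as the code below requires):

```python
import torch
from pytorch3d.ops import packed_to_padded, padded_to_packed

# Three cameras contributing 2, 1 and 3 rays, packed along the first dim.
xys_packed = torch.arange(12, dtype=torch.float32).reshape(6, 2)
counts = torch.tensor([2, 1, 3])
first_idxs = torch.cat([counts.new_zeros(1), counts.cumsum(0)[:-1]]).long()

# (6, 2) packed -> (3, 3, 2) padded, zero-filled past each camera's count.
xys_padded = packed_to_padded(xys_packed, first_idxs, int(counts.max()))

# ... and back: the padding is dropped and the original packed tensor returns.
assert torch.equal(padded_to_packed(xys_padded, first_idxs, 6), xys_packed)
```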
@@ -60,7 +62,7 @@ def __post_init__(self) -> None:
     def forward(
         self,
         raymarched: RendererOutput,
-        xys: torch.Tensor,
+        ray_bundle: ImplicitronRayBundle,
         image_rgb: Optional[torch.Tensor] = None,
         depth_map: Optional[torch.Tensor] = None,
         fg_probability: Optional[torch.Tensor] = None,
@@ -79,10 +81,8 @@ def forward(
             names of the output metrics `metric_name_i` with their corresponding
             values `metric_value_i` represented as 0-dimensional float tensors.
             raymarched: Output of the renderer.
-            xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
-                the predictions are defined. All ground truth inputs are sampled at
-                these locations in order to extract values that correspond to the
-                predictions.
+            ray_bundle: ImplicitronRayBundle object which was used to produce
+                the raymarched object.
             image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
                 values.
             depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
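For callers, the signature change amounts to handing over the whole bundle instead of only its sample locations. A hedged sketch of a hypothetical call site (`training_step_sketch` and all variable names are illustrative, not taken from this diff):

```python
from typing import Dict
import torch

def training_step_sketch(
    view_metrics, renderer_output, bundle, image_rgb, depth_map
) -> Dict[str, torch.Tensor]:
    # After this diff, the metrics module receives the full ray bundle so it
    # can detect and handle the packed case itself.
    return view_metrics(
        raymarched=renderer_output,
        ray_bundle=bundle,  # previously: xys=bundle.xys
        image_rgb=image_rgb,
        depth_map=depth_map,
    )
```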
@@ -141,7 +141,7 @@ class ViewMetrics(ViewMetricsBase):
     def forward(
         self,
         raymarched: RendererOutput,
-        xys: torch.Tensor,
+        ray_bundle: ImplicitronRayBundle,
         image_rgb: Optional[torch.Tensor] = None,
         depth_map: Optional[torch.Tensor] = None,
         fg_probability: Optional[torch.Tensor] = None,
@@ -165,10 +165,8 @@ def forward(
                 input 3D coordinates used to compute the eikonal loss.
             raymarched.aux["density_grid"]: A tensor of shape `(B, Hg, Wg, Dg, 1)`
                 containing a `Hg x Wg x Dg` voxel grid of density values.
-            xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
-                the predictions are defined. All ground truth inputs are sampled at
-                these locations in order to extract values that correspond to the
-                predictions.
+            ray_bundle: ImplicitronRayBundle object which was used to produce
+                the raymarched object.
             image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
                 values.
             depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
@@ -209,7 +207,7 @@ def forward(
         """
         metrics = self._calculate_stage(
             raymarched,
-            xys,
+            ray_bundle,
             image_rgb,
             depth_map,
             fg_probability,
@@ -221,7 +219,7 @@ def forward(
             metrics.update(
                 self(
                     raymarched.prev_stage,
-                    xys,
+                    ray_bundle,
                     image_rgb,
                     depth_map,
                     fg_probability,
@@ -235,7 +233,7 @@ def forward(
     def _calculate_stage(
         self,
         raymarched: RendererOutput,
-        xys: torch.Tensor,
+        ray_bundle: ImplicitronRayBundle,
         image_rgb: Optional[torch.Tensor] = None,
         depth_map: Optional[torch.Tensor] = None,
         fg_probability: Optional[torch.Tensor] = None,
@@ -253,6 +251,27 @@ def _calculate_stage(
             _reshape_nongrid_var(x)
             for x in [raymarched.features, raymarched.masks, raymarched.depths]
         ]
+        xys = ray_bundle.xys
+
+        # If the ray_bundle is packed, then we can sample the images in the padded
+        # state to lower memory requirements: instead of one image for every
+        # element in the ray_bundle, we can then have one image per unique
+        # sampled camera.
+        if ray_bundle.is_packed():
+            # pyre-ignore[6]
+            cumsum = torch.cumsum(ray_bundle.camera_counts, dim=0, dtype=torch.long)
+            first_idxs = torch.cat(
+                (
+                    # pyre-ignore[16]
+                    ray_bundle.camera_counts.new_zeros((1,), dtype=torch.long),
+                    cumsum[:-1],
+                )
+            )
+            # pyre-ignore[16]
+            num_inputs = int(ray_bundle.camera_counts.sum())
+            # pyre-ignore[6]
+            max_size = int(torch.max(ray_bundle.camera_counts))
+            xys = packed_to_padded(xys, first_idxs, max_size)
+
         # reshape the sampling grid as well
         # TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
         # now that we use rend_utils.ndc_grid_sample
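To make the index bookkeeping in the added block concrete, here is a worked example with toy numbers of my own choosing (only the semantics the code above already implies are used: `camera_counts` holds the per-unique-camera ray counts of a packed bundle):

```python
import torch

# camera_counts = [3, 1, 2]: rays 0..2 come from camera 0, ray 3 from
# camera 1, and rays 4..5 from camera 2, packed along the first dim.
camera_counts = torch.tensor([3, 1, 2])
cumsum = torch.cumsum(camera_counts, dim=0, dtype=torch.long)  # [3, 4, 6]
first_idxs = torch.cat(
    (camera_counts.new_zeros((1,), dtype=torch.long), cumsum[:-1])
)                                           # [0, 3, 4]: start of each camera's rays
num_inputs = int(camera_counts.sum())       # 6 rays in total
max_size = int(torch.max(camera_counts))    # 3, the longest per-camera run
# packed_to_padded(xys, first_idxs, max_size) then turns the packed
# (6, ..., 2) grid into a per-camera (3, 3, ..., 2) padded grid.
```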
@@ -262,7 +281,20 @@ def _calculate_stage(
         def sample(tensor, mode):
             if tensor is None:
                 return tensor
-            return rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
+            if ray_bundle.is_packed():
+                # Select the images that correspond to the sampled cameras if the
+                # ray_bundle is packed.
+                tensor = tensor[ray_bundle.camera_ids]
+            result = rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
+            if ray_bundle.is_packed():
+                # Images after sampling have the form [batch, 3, max_num_rays, 1];
+                # padded_to_packed combines the first two dimensions, so we need to
+                # swap the 1st and 2nd dimensions. The result is
+                # [n_rays_total_training, 1, 3, 1] (we use keepdim=True).
+                result = result.transpose(1, 2)
+                result = padded_to_packed(result, first_idxs, num_inputs)[:, None]
+                result = result.transpose(1, 2)
+
+            return result

         # eval all results in this size
         image_rgb = sample(image_rgb, mode="bilinear")
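To ground the `sample` helper, a minimal sketch of `rend_utils.ndc_grid_sample` in the simple, non-packed case; the function is real PyTorch3D API, while the toy batch, image, and grid shapes below are illustrative assumptions:

```python
import torch
from pytorch3d.renderer import utils as rend_utils

image_rgb = torch.rand(2, 3, 8, 8)      # (B, C, H, W) ground-truth images
xys = torch.rand(2, 16, 1, 2) * 2 - 1   # (B, ..., 2) sample locations in NDC
sampled = rend_utils.ndc_grid_sample(image_rgb, xys, mode="bilinear")
print(sampled.shape)                    # torch.Size([2, 3, 16, 1])
```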