|
24 | 24 | import torch |
25 | 25 | import torch_xla2.ops.jaten # Import to load torch_xla2 ops |
26 | 26 | import torch_xla2.ops.ops_registry # Import to load torch_xla2 ops |
| 27 | +import numpy as np |
27 | 28 |
|
28 | 29 | LoweringContext = context.LoweringContext |
29 | 30 |
|
@@ -71,8 +72,7 @@ def lower_by_torch_xla2(op): |
71 | 72 | lower_by_torch_xla2(torch.ops.aten._local_scalar_dense) |
72 | 73 | lower_by_torch_xla2(torch.ops.aten._local_scalar_dense) |
73 | 74 | lower_by_torch_xla2(torch.ops.aten._log_softmax) |
74 | | -lower_by_torch_xla2(torch.ops.aten._native_batch_norm_legit) |
75 | | -lower_by_torch_xla2(torch.ops.aten._native_batch_norm_legit_no_training) |
| 75 | +lower_by_torch_xla2(torch.ops.aten.native_batch_norm) |
76 | 76 | lower_by_torch_xla2(torch.ops.aten._pdist_forward) |
77 | 77 | lower_by_torch_xla2(torch.ops.aten._softmax) |
78 | 78 | lower_by_torch_xla2(torch.ops.aten._unsafe_index) |
@@ -158,24 +158,20 @@ def lower_by_torch_xla2(op): |
158 | 158 | lower_by_torch_xla2(torch.ops.aten.logical_or) |
159 | 159 | lower_by_torch_xla2(torch.ops.aten.logical_xor) |
160 | 160 | lower_by_torch_xla2(torch.ops.aten.max) |
161 | | -lower_by_torch_xla2(torch.ops.aten.max_pool2d_with_indices) |
162 | 161 | lower_by_torch_xla2(torch.ops.aten.max_pool2d_with_indices_backward) |
163 | 162 | lower_by_torch_xla2(torch.ops.aten.max_pool2d_with_indices_backward) |
164 | | -lower_by_torch_xla2(torch.ops.aten.max_pool3d_with_indices) |
165 | 163 | lower_by_torch_xla2(torch.ops.aten.maximum) |
166 | 164 | lower_by_torch_xla2(torch.ops.aten.mean) |
167 | 165 | lower_by_torch_xla2(torch.ops.aten.min) |
168 | 166 | lower_by_torch_xla2(torch.ops.aten.minimum) |
169 | 167 | lower_by_torch_xla2(torch.ops.aten.mm) |
170 | | -lower_by_torch_xla2(torch.ops.aten.native_batch_norm) |
171 | 168 | lower_by_torch_xla2(torch.ops.aten.native_layer_norm_backward) |
172 | 169 | lower_by_torch_xla2(torch.ops.aten.ne) |
173 | 170 | lower_by_torch_xla2(torch.ops.aten.neg) |
174 | 171 | lower_by_torch_xla2(torch.ops.aten.nonzero) |
175 | 172 | lower_by_torch_xla2(torch.ops.aten.outer) |
176 | 173 | lower_by_torch_xla2(torch.ops.aten.permute) |
177 | 174 | lower_by_torch_xla2(torch.ops.aten.permute_copy) |
178 | | -lower_by_torch_xla2(torch.ops.aten.pixel_shuffle) |
179 | 175 | lower_by_torch_xla2(torch.ops.aten.pow) |
180 | 176 | lower_by_torch_xla2(torch.ops.aten.prod) |
181 | 177 | lower_by_torch_xla2(torch.ops.aten.reciprocal) |
@@ -240,6 +236,249 @@ def lower_by_torch_xla2(op): |
240 | 236 | lower_by_torch_xla2(torch.ops.prims.var) |
241 | 237 |
|
242 | 238 |
|
| 239 | +def _ceil_mode_padding( |
| 240 | + padding: list[int], |
| 241 | + input_shape: list[int], |
| 242 | + kernel_size: list[int], |
| 243 | + stride: list[int], |
| 244 | + dilation: list[int], |
| 245 | + ceil_mode: bool, |
| 246 | +): |
| 247 | + """Creates low and high padding specification for the given padding (which is symmetric) and ceil mode. |
| 248 | +
|
| 249 | + Additional high padding could be required when ceil mode is set. |
| 250 | + """ |
| 251 | + ceil_mode_padding = [] |
| 252 | + for i in range(len(padding)): |
| 253 | + left_padding = padding[i] |
| 254 | + right_padding = left_padding |
| 255 | + |
| 256 | + input_size = input_shape[2 + i] |
| 257 | + output_size_rem = ( |
| 258 | + input_size |
| 259 | + + 2 * left_padding |
| 260 | + - (kernel_size[i] - 1) * dilation[i] |
| 261 | + - 1 |
| 262 | + ) % stride[i] |
| 263 | + if ceil_mode and output_size_rem != 0: |
| 264 | + extra_padding = stride[i] - output_size_rem |
| 265 | + new_output_size = ( |
| 266 | + input_size |
| 267 | + + left_padding |
| 268 | + + right_padding |
| 269 | + + extra_padding |
| 270 | + - (kernel_size[i] - 1) * dilation[i] |
| 271 | + - 1 |
| 272 | + + stride[i] |
| 273 | + - 1 |
| 274 | + ) // stride[i] + 1 |
| 275 | + # Ensure that the last pooling starts inside the image. |
| 276 | + size_to_compare = input_size + left_padding |
| 277 | + |
| 278 | + if (new_output_size - 1) * stride[i] < size_to_compare: |
| 279 | + right_padding += extra_padding |
| 280 | + |
| 281 | + ceil_mode_padding.append((left_padding, right_padding)) |
| 282 | + return ceil_mode_padding |
| 283 | + |
| 284 | + |
| 285 | +def max_pool( |
| 286 | + inputs, |
| 287 | + kernel_size, |
| 288 | + strides=None, |
| 289 | + padding=0, |
| 290 | + dilation=1, |
| 291 | + ceil_mode=False, |
| 292 | + with_index=False, |
| 293 | +): |
| 294 | + num_spatial_dims = len(kernel_size) |
| 295 | + num_batch_dims = inputs.ndim - num_spatial_dims - 1 |
| 296 | + kernel_size_tup = tuple(kernel_size) |
| 297 | + # Default stride is kernel_size |
| 298 | + strides_tup = tuple(strides) if strides else kernel_size_tup |
| 299 | + if isinstance(padding, int): |
| 300 | + padding_list = [padding for _ in range(num_spatial_dims)] |
| 301 | + elif not padding: # padding can be [], meaning all zeros. |
| 302 | + padding_list = [0 for _ in range(num_spatial_dims)] |
| 303 | + else: |
| 304 | + padding_list = padding |
| 305 | + |
| 306 | + if isinstance(dilation, int): |
| 307 | + dilation_tup = tuple(dilation for _ in range(num_spatial_dims)) |
| 308 | + elif not dilation: |
| 309 | + dilation_tup = tuple(1 for _ in range(num_spatial_dims)) |
| 310 | + elif isinstance(dilation, list): |
| 311 | + dilation_tup = tuple(dilation) |
| 312 | + else: |
| 313 | + dilation_tup = dilation |
| 314 | + |
| 315 | + input_shape_for_ceil = inputs.shape |
| 316 | + if num_batch_dims == 0: |
| 317 | + input_shape_for_ceil = [1, *input_shape_for_ceil] |
| 318 | + padding_pairs = _ceil_mode_padding( |
| 319 | + padding_list, |
| 320 | + input_shape_for_ceil, |
| 321 | + kernel_size_tup, |
| 322 | + strides_tup, |
| 323 | + dilation_tup, |
| 324 | + ceil_mode, |
| 325 | + ) |
| 326 | + |
| 327 | + assert len(kernel_size_tup) == len(strides_tup), ( |
| 328 | + f"len({kernel_size_tup=}) must equal len({strides_tup=})" |
| 329 | + ) |
| 330 | + assert len(kernel_size_tup) == len(dilation_tup), ( |
| 331 | + f"len({kernel_size_tup=}) must equal len({dilation_tup=})" |
| 332 | + ) |
| 333 | + |
| 334 | + is_single_input = False |
| 335 | + if num_batch_dims == 0: |
| 336 | + inputs = inputs[None] |
| 337 | + is_single_input = True |
| 338 | + |
| 339 | + reduce_window_strides = (1,) * (inputs.ndim - num_spatial_dims) + strides_tup |
| 340 | + reduce_window_dims = (1,) * (inputs.ndim - num_spatial_dims) + kernel_size_tup |
| 341 | + reduce_window_dilation = ( |
| 342 | + 1, |
| 343 | + ) * (inputs.ndim - num_spatial_dims) + dilation_tup |
| 344 | + |
| 345 | + assert inputs.ndim == len( |
| 346 | + reduce_window_dims |
| 347 | + ), f"len({inputs.shape}) != len({reduce_window_dims})" |
| 348 | + if not isinstance(padding_pairs, str): |
| 349 | + padding_pairs_tup = tuple(padding_pairs) |
| 350 | + assert all([len(x) == 2 for x in padding_pairs_tup]), ( |
| 351 | + f"each entry in padding {padding_pairs_tup} must be length 2" |
| 352 | + ) |
| 353 | + padding_lax = ( |
| 354 | + ((0, 0),) * (inputs.ndim - len(padding_pairs_tup)) + padding_pairs_tup |
| 355 | + ) |
| 356 | + else: |
| 357 | + padding_lax = padding_pairs |
| 358 | + |
| 359 | + indices = jnp.arange(np.prod(inputs.shape[-num_spatial_dims:]), dtype=jnp.int64) |
| 360 | + indices = indices.reshape(inputs.shape[-num_spatial_dims:]) |
| 361 | + indices_shape = (1,) * (inputs.ndim - indices.ndim) + indices.shape |
| 362 | + indices = jnp.broadcast_to(indices.reshape(indices_shape), inputs.shape) |
| 363 | + |
| 364 | + return_dtype = inputs.dtype |
| 365 | + if jnp.issubdtype(inputs.dtype, jnp.integer): |
| 366 | + init_val = jnp.int32(jnp.iinfo(jnp.int32).min) |
| 367 | + inputs = inputs.astype(jnp.int32) |
| 368 | + else: |
| 369 | + init_val = jnp.float32(-jnp.inf) |
| 370 | + inputs = inputs.astype(jnp.float32) |
| 371 | + |
| 372 | + if not with_index: |
| 373 | + y = jax.lax.reduce_window( |
| 374 | + inputs, |
| 375 | + init_val, |
| 376 | + jax.lax.max, |
| 377 | + reduce_window_dims, |
| 378 | + reduce_window_strides, |
| 379 | + padding_lax, |
| 380 | + window_dilation=reduce_window_dilation, |
| 381 | + ) |
| 382 | + if is_single_input: |
| 383 | + y = jnp.squeeze(y, axis=0) |
| 384 | + return y.astype(return_dtype) |
| 385 | + else: |
| 386 | + |
| 387 | + def reduce_fn(a, b): |
| 388 | + ai, av = a |
| 389 | + bi, bv = b |
| 390 | + which = av >= bv |
| 391 | + return jnp.where(which, ai, bi), jnp.where(which, av, bv) |
| 392 | + |
| 393 | + indices, y = jax.lax.reduce_window( |
| 394 | + (indices, inputs), |
| 395 | + (jnp.int64(0), init_val), |
| 396 | + reduce_fn, |
| 397 | + reduce_window_dims, |
| 398 | + reduce_window_strides, |
| 399 | + padding_lax, |
| 400 | + window_dilation=reduce_window_dilation, |
| 401 | + ) |
| 402 | + if is_single_input: |
| 403 | + indices = jnp.squeeze(indices, axis=0) |
| 404 | + y = jnp.squeeze(y, axis=0) |
| 405 | + y = y.astype(return_dtype) |
| 406 | + return y, indices |
| 407 | + |
| 408 | + |
| 409 | +@lower_by_jax(torch.ops.aten.max_pool2d_with_indices) |
| 410 | +def _aten_max_pool2d_with_indices( |
| 411 | + self, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False |
| 412 | +): |
| 413 | + stride = stride if stride is not None else [] |
| 414 | + y = max_pool( |
| 415 | + self, |
| 416 | + kernel_size, |
| 417 | + strides=stride, |
| 418 | + padding=padding, |
| 419 | + dilation=dilation, |
| 420 | + ceil_mode=ceil_mode, |
| 421 | + with_index=False, |
| 422 | + ) |
| 423 | + # TFLite's reduce_window kernel doesn't support multiple inputs/outputs, |
| 424 | + # so we emit reduce_window with a single output and return dummy indices. |
| 425 | + return y, jnp.zeros_like(y, dtype=jnp.int64) |
| 426 | + |
| 427 | + |
| 428 | +@lower_by_jax(torch.ops.aten.max_pool3d_with_indices.default) |
| 429 | +def _aten_max_pool3d_with_indices( |
| 430 | + self, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False |
| 431 | +): |
| 432 | + stride = stride if stride is not None else [] |
| 433 | + y = max_pool( |
| 434 | + self, |
| 435 | + kernel_size, |
| 436 | + strides=stride, |
| 437 | + padding=padding, |
| 438 | + dilation=dilation, |
| 439 | + ceil_mode=ceil_mode, |
| 440 | + with_index=False, |
| 441 | + ) |
| 442 | + # TFLite's reduce_window kernel doesn't support multiple inputs/outputs, |
| 443 | + # so we emit reduce_window with a single output and return dummy indices. |
| 444 | + return y, jnp.zeros_like(y, dtype=jnp.int64) |
| 445 | + |
| 446 | + |
| 447 | +@lower_by_jax(torch.ops.aten.pixel_shuffle) |
| 448 | +def _aten_pixel_shuffle(x, upscale_factor): |
| 449 | + """PixelShuffle implementation in JAX lowering. |
| 450 | +
|
| 451 | + Args: |
| 452 | + x: Input tensor. Typically a feature map. |
| 453 | + upscale_factor: Integer by which to upscale the spatial dimensions. |
| 454 | +
|
| 455 | + Returns: |
| 456 | + Tensor after PixelShuffle operation. |
| 457 | + """ |
| 458 | + |
| 459 | + batch_size, channels, height, width = x.shape |
| 460 | + |
| 461 | + if channels % (upscale_factor**2) != 0: |
| 462 | + raise ValueError( |
| 463 | + "Number of channels must be divisible by the square of the upscale" |
| 464 | + " factor." |
| 465 | + ) |
| 466 | + |
| 467 | + new_channels = channels // (upscale_factor**2) |
| 468 | + new_height = height * upscale_factor |
| 469 | + new_width = width * upscale_factor |
| 470 | + |
| 471 | + x = x.reshape( |
| 472 | + batch_size, new_channels, upscale_factor, upscale_factor, height, width |
| 473 | + ) |
| 474 | + x = jnp.transpose( |
| 475 | + x, (0, 1, 4, 2, 5, 3) |
| 476 | + ) # Move channels to spatial dimensions |
| 477 | + x = x.reshape(batch_size, new_channels, new_height, new_width) |
| 478 | + |
| 479 | + return x |
| 480 | + |
| 481 | + |
243 | 482 | @lower_by_jax(torch.ops.aten.unbind) |
244 | 483 | def _aten_copy(self, *args, **kwargs): |
245 | 484 | return _TORCH_XLA2_IMPLS[torch.ops.aten.unbind_copy](self, *args, **kwargs) |
|
0 commit comments