Skip to content
This repository was archived by the owner on Nov 4, 2024. It is now read-only.

Commit ed02845

Browse files
committed
fix: wrong function in macro
1 parent 26f4889 commit ed02845

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

src/impl/affine_normalize.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -384,12 +384,13 @@ function __affine_normalize_gn_impl!(opmode::LoopedArrayOp, y::AbstractArray{<:N
384384
f::F, x::AbstractArray{<:Number, 4}, μ, σ²,
385385
scale::Optional{<:AbstractArray{<:Number, 4}},
386386
bias::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F}
387-
__affine_normalize_gn_impl_loopvec!(opmode, y, x, μ, σ², scale, bias, ϵ)
387+
__affine_normalize_gn_impl_loopvec!(y, x, μ, σ², scale, bias, ϵ)
388388
_fast_activation!(f, y) # NOTE: don't fuse into the above loop
389389
end
390390

391-
function __affine_normalize_gn_impl_loopvec!(::LoopedArrayOp, y::AbstractArray{<:Number, 4},
392-
x::AbstractArray{<:Number, 4}, μ, σ², ::Nothing, ::Nothing, ϵ::Real)
391+
function __affine_normalize_gn_impl_loopvec!(
392+
y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4},
393+
μ, σ², ::Nothing, ::Nothing, ϵ::Real)
393394
@tturbo for L in indices(y, 4), K in indices(y, 3)
394395
_sc = inv(sqrt(σ²[1, 1, K, L] + ϵ))
395396
_bc = -μ[1, 1, K, L] * _sc
@@ -400,8 +401,8 @@ function __affine_normalize_gn_impl_loopvec!(::LoopedArrayOp, y::AbstractArray{<
400401
end
401402

402403
function __affine_normalize_gn_impl_loopvec!(
403-
::LoopedArrayOp, y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ,
404-
σ², scale::AbstractArray{<:Number, 4}, bias::AbstractArray{<:Number, 4}, ϵ::Real)
404+
y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ, σ²,
405+
scale::AbstractArray{<:Number, 4}, bias::AbstractArray{<:Number, 4}, ϵ::Real)
405406
@tturbo for L in indices(y, 4), K in indices(y, 3)
406407
idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ))
407408
for J in indices(y, 2)
@@ -415,8 +416,8 @@ function __affine_normalize_gn_impl_loopvec!(
415416
end
416417

417418
@inbounds function __affine_normalize_gn_impl_no_turbo!(
418-
::LoopedArrayOp, y::AbstractArray{<:Number, 4},
419-
x::AbstractArray{<:Number, 4}, μ, σ², ::Nothing, ::Nothing, ϵ::Real)
419+
y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4},
420+
μ, σ², ::Nothing, ::Nothing, ϵ::Real)
420421
for L in indices(y, 4), K in indices(y, 3)
421422
_sc = inv(sqrt(σ²[1, 1, K, L] + ϵ))
422423
_bc = -μ[1, 1, K, L] * _sc
@@ -429,8 +430,8 @@ end
429430
end
430431

431432
@inbounds function __affine_normalize_gn_impl_no_turbo!(
432-
::LoopedArrayOp, y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ,
433-
σ², scale::AbstractArray{<:Number, 4}, bias::AbstractArray{<:Number, 4}, ϵ::Real)
433+
y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ, σ²,
434+
scale::AbstractArray{<:Number, 4}, bias::AbstractArray{<:Number, 4}, ϵ::Real)
434435
for L in indices(y, 4), K in indices(y, 3)
435436
idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ))
436437
for J in indices(y, 2)
@@ -443,7 +444,7 @@ end
443444
end
444445
end
445446

446-
@enzyme_reverse_alternative __affine_normalize_gn_impl! __affine_normalize_gn_impl_no_turbo!
447+
@enzyme_reverse_alternative __affine_normalize_gn_impl_loopvec! __affine_normalize_gn_impl_no_turbo!
447448

448449
function __affine_normalize_gn_impl!(::GPUBroadcastOp, y::AbstractArray{<:Number, 4}, f::F,
449450
x::AbstractArray{<:Number, 4}, μ, σ², scale::Optional{<:AbstractArray},

0 commit comments

Comments
 (0)