Skip to content
This repository was archived by the owner on Nov 4, 2024. It is now read-only.

Commit 4c1daf3

Browse files
committed
fix: dispatch to loopvec for groupnorm
1 parent 685f42a commit 4c1daf3

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

src/impl/affine_normalize.jl

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -384,12 +384,11 @@ function __affine_normalize_gn_impl!(opmode::LoopedArrayOp, y::AbstractArray{<:N
384384
f::F, x::AbstractArray{<:Number, 4}, μ, σ²,
385385
scale::Optional{<:AbstractArray{<:Number, 4}},
386386
bias::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F}
387-
__affine_normalize_gn_impl!(opmode, y, nothing, x, μ, σ², scale, bias, ϵ)
387+
__affine_normalize_gn_impl_loopvec!(opmode, y, x, μ, σ², scale, bias, ϵ)
388388
_fast_activation!(f, y) # NOTE: don't fuse into the above loop
389389
end
390390

391-
function __affine_normalize_gn_impl!(
392-
::LoopedArrayOp, y::AbstractArray{<:Number, 4}, ::Nothing,
391+
function __affine_normalize_gn_impl_loopvec!(::LoopedArrayOp, y::AbstractArray{<:Number, 4},
393392
x::AbstractArray{<:Number, 4}, μ, σ², ::Nothing, ::Nothing, ϵ::Real)
394393
@tturbo for L in indices(y, 4), K in indices(y, 3)
395394
_sc = inv(sqrt(σ²[1, 1, K, L] + ϵ))
@@ -400,10 +399,9 @@ function __affine_normalize_gn_impl!(
400399
end
401400
end
402401

403-
function __affine_normalize_gn_impl!(
404-
::LoopedArrayOp, y::AbstractArray{<:Number, 4}, ::Nothing,
405-
x::AbstractArray{<:Number, 4}, μ, σ², scale::AbstractArray{<:Number, 4},
406-
bias::AbstractArray{<:Number, 4}, ϵ::Real)
402+
function __affine_normalize_gn_impl_loopvec!(
403+
::LoopedArrayOp, y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ,
404+
σ², scale::AbstractArray{<:Number, 4}, bias::AbstractArray{<:Number, 4}, ϵ::Real)
407405
@tturbo for L in indices(y, 4), K in indices(y, 3)
408406
idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ))
409407
for J in indices(y, 2)
@@ -417,7 +415,7 @@ function __affine_normalize_gn_impl!(
417415
end
418416

419417
@inbounds function __affine_normalize_gn_impl_no_turbo!(
420-
::LoopedArrayOp, y::AbstractArray{<:Number, 4}, ::Nothing,
418+
::LoopedArrayOp, y::AbstractArray{<:Number, 4},
421419
x::AbstractArray{<:Number, 4}, μ, σ², ::Nothing, ::Nothing, ϵ::Real)
422420
for L in indices(y, 4), K in indices(y, 3)
423421
_sc = inv(sqrt(σ²[1, 1, K, L] + ϵ))
@@ -431,9 +429,8 @@ end
431429
end
432430

433431
@inbounds function __affine_normalize_gn_impl_no_turbo!(
434-
::LoopedArrayOp, y::AbstractArray{<:Number, 4}, ::Nothing,
435-
x::AbstractArray{<:Number, 4}, μ, σ², scale::AbstractArray{<:Number, 4},
436-
bias::AbstractArray{<:Number, 4}, ϵ::Real)
432+
::LoopedArrayOp, y::AbstractArray{<:Number, 4}, x::AbstractArray{<:Number, 4}, μ,
433+
σ², scale::AbstractArray{<:Number, 4}, bias::AbstractArray{<:Number, 4}, ϵ::Real)
437434
for L in indices(y, 4), K in indices(y, 3)
438435
idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ))
439436
for J in indices(y, 2)

0 commit comments

Comments
 (0)