This repository was archived by the owner on Nov 4, 2024. It is now read-only.

Commit 2b4fd85

test: swap Float16 tests with BFloat16
1 parent e31aa74 commit 2b4fd85
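
For context, BFloat16 (provided by the BFloat16s.jl package this commit adds to the test dependencies) keeps Float32's exponent range but carries only 8 bits of significand precision, so it exercises a different low-precision path than Float16, and the updated tests use a single tolerance for every element type. A minimal sketch, not part of the commit, of the type as the updated tests consume it:

```julia
using BFloat16s   # registered package; added to [extras] and the test target below

x = rand(BFloat16, 4, 3)         # the updated tests draw low-precision inputs like this
eps(BFloat16)                    # 0.0078125 — coarser than eps(Float16) ≈ 0.000977
promote_type(Float32, BFloat16)  # Float32: mixed-precision ops widen to Float32
```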

10 files changed: 51 additions, 86 deletions


Project.toml

Lines changed: 3 additions & 1 deletion
@@ -44,6 +44,7 @@ LuxLibcuDNNExt = ["CUDA", "cuDNN"]
 AMDGPU = "0.9.6"
 Aqua = "0.8.7"
 ArrayInterface = "7.9"
+BFloat16s = "0.5.0"
 CUDA = "5.3.2"
 ChainRulesCore = "1.24"
 ComponentArrays = "0.15.16"
@@ -86,6 +87,7 @@ julia = "1.10"

 [extras]
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
+BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
@@ -104,4 +106,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [targets]
-test = ["Aqua", "ComponentArrays", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "JLArrays", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "StaticArrays", "Test", "Tracker", "Zygote"]
+test = ["Aqua", "BFloat16s", "ComponentArrays", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "JLArrays", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "StaticArrays", "Test", "Tracker", "Zygote"]

test/common_ops/activation_tests.jl

Lines changed: 3 additions & 4 deletions
@@ -8,17 +8,16 @@
 @testset "$mode" for (mode, aType, ongpu) in MODES
     @testset "$f: $T" for f in [identity, relu, sigmoid, sigmoid_fast, softplus,
             logsigmoid, gelu, swish, lisht, tanh, tanh_fast],
-        T in [Float16, Float32, Float64]
+        T in [BFloat16, Float32, Float64]

         x = rand(rng, T, 4, 3) |> aType

         y1 = apply_act(f, x)
         y2 = apply_act_fast(f, x)
         y3 = apply_act_fast2(f, x)

-        fp16 = T == Float16
-        atol = fp16 ? 1.0f-1 : 1.0f-3
-        rtol = fp16 ? 1.0f-1 : 1.0f-3
+        atol = 1.0f-3
+        rtol = 1.0f-3

        @test y1 ≈ y2 atol=atol rtol=rtol
        @test y1 ≈ y3 atol=atol rtol=rtol

test/common_ops/conv_tests.jl

Lines changed: 5 additions & 7 deletions
@@ -28,9 +28,8 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding,

     y_generic = LuxLib._generic_conv_bias_activation(activation, weight, x, bias, cdims)

-    fp16 = Tx == Float16 || Tw == Float16
-    atol = fp16 ? 1.0f-1 : 1.0f-3
-    rtol = fp16 ? 1.0f-1 : 1.0f-3
+    atol = 1.0f-3
+    rtol = 1.0f-3
     # Operation reordering has an effect on the accuracy of the results
     @test y ≈ y_generic atol=atol rtol=rtol
     @test eltype(y) == promote_type(Tw, Tx)
@@ -61,14 +60,13 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding,
     mp && push!(skip_backends, AutoReverseDiff())
     ((mp && ongpu) || (mode == "amdgpu" && (Tx == Float64 || Tw == Float64))) &&
         push!(skip_backends, AutoTracker())
-    test_gradients(__f_grad, weight, x, bias; atol, rtol, skip_backends,
-        soft_fail=(fp16 ? [AutoFiniteDiff()] : []))
+    test_gradients(__f_grad, weight, x, bias; atol, rtol, skip_backends)
 end

 anonact = x -> gelu(x)

-const ELTYPES = [(Float16, Float16), (Float32, Float16), (Float32, Float32),
-    (Float32, Float64), (Float64, Float64)]
+const ELTYPES = [(BFloat16, BFloat16), (Float32, BFloat16),
+    (Float32, Float32), (Float32, Float64), (Float64, Float64)]
 const ACTIVATIONS = [
     identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, swish, anonact]
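
The mixed-precision entries in ELTYPES, such as (Float32, BFloat16), rely on the `@test eltype(y) == promote_type(Tw, Tx)` check above: BFloat16 must widen predictably to the larger type. A standalone sketch of that property (not from the repo, with a plain matrix product standing in for the fused conv kernel), assuming only BFloat16s.jl and the Test stdlib:

```julia
using BFloat16s, Test

Tw, Tx = Float32, BFloat16        # one of the mixed-precision ELTYPES pairs
weight = rand(Tw, 3, 4)
x = rand(Tx, 4, 3)

y = weight * x                    # stand-in for the conv + bias + activation call
@test eltype(y) == promote_type(Tw, Tx)   # promotes to Float32
```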

test/common_ops/dense_tests.jl

Lines changed: 5 additions & 8 deletions
@@ -25,24 +25,21 @@ function run_dense_testing(gen_f, Tw, Tx, M, N, hasbias, activation, aType, mode
         @test length(@inferred(Zygote.gradient(__f, activation, w, x, bias)))==4 broken=true
     end

-    fp16 = Tx == Float16 || Tw == Float16
-    atol = fp16 ? 1.0f-1 : 1.0f-3
-    rtol = fp16 ? 1.0f-1 : 1.0f-3
+    atol = 1.0f-3
+    rtol = 1.0f-3

     skip_backends = []
     Tw != Tx && push!(skip_backends, AutoReverseDiff())
-    fp16 && push!(skip_backends, AutoFiniteDiff())

     __f_grad = let activation = activation
         (w, x, b) -> __f(activation, w, x, b)
     end
-    test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends,
-        soft_fail=(fp16 ? [AutoFiniteDiff()] : []))
+    test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends)
 end

 const ALL_TEST_CONFIGS = Iterators.product(
-    ((Float16, Float16), (Float32, Float16), (Float32, Float32),
-        (Float32, Float64), (Float64, Float64)),
+    ((BFloat16, BFloat16), (Float32, BFloat16),
+        (Float32, Float32), (Float32, Float64), (Float64, Float64)),
     (4, 8),
     (4, 8),
     (true, false),

test/common_ops/dropout_tests.jl

Lines changed: 8 additions & 18 deletions
@@ -2,7 +2,7 @@
 rng = StableRNG(12345)

 @testset "$mode" for (mode, aType, ongpu) in MODES
-    @testset "$T: $x_shape" for T in (Float16, Float32, Float64),
+    @testset "$T: $x_shape" for T in (BFloat16, Float32, Float64),
         x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1))

         x = randn(rng, T, x_shape) |> aType
@@ -26,9 +26,7 @@
         __f = let rng = rng, T = T
             x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon())))
         end
-        test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
-            soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
-            broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
+        test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3)

         y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false), T(2), Colon())

@@ -48,7 +46,7 @@ end
 rng = StableRNG(12345)

 @testset "$mode" for (mode, aType, ongpu) in MODES
-    @testset "$T: $x_shape" for T in (Float16, Float32, Float64),
+    @testset "$T: $x_shape" for T in (BFloat16, Float32, Float64),
         x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1))

         x = randn(rng, T, x_shape) |> aType
@@ -76,9 +74,7 @@ end
             x -> sum(first(dropout(
                 rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())))
         end
-        test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
-            soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
-            broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
+        test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3)

         @jet sum(first(dropout(
             rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())))
@@ -106,9 +102,7 @@ end
             x -> sum(first(dropout(
                 rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
         end
-        test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
-            soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
-            broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
+        test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3)

         @jet sum(first(dropout(
             rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
@@ -137,9 +131,7 @@ end
             x -> sum(first(dropout(
                 rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
         end
-        test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
-            soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
-            broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
+        test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3)

         @jet sum(first(dropout(
             rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
@@ -165,7 +157,7 @@ end
 rng = StableRNG(12345)

 @testset "$mode" for (mode, aType, ongpu) in MODES
-    @testset "$T: $x_shape" for T in (Float16, Float32, Float64),
+    @testset "$T: $x_shape" for T in (BFloat16, Float32, Float64),
         x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1))

         x = randn(rng, T, x_shape) |> aType
@@ -186,9 +178,7 @@ end
         __f = let rng = rng
             x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))
         end
-        test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
-            soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
-            broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
+        test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3)

         @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))
         @test @inferred(alpha_dropout(rng, x, T(0.5), Val(false))) isa Any

test/normalization/batchnorm_tests.jl

Lines changed: 4 additions & 18 deletions
@@ -41,9 +41,8 @@ function run_batchnorm_testing(
     y_simple, nt_simple = __batchnorm_basic(
         x, scale, bias, rm, rv, training, act, T(0.9), epsilon)

-    fp16 = T == Float16
-    atol = fp16 ? 1.0f-2 : 1.0f-3
-    rtol = fp16 ? 1.0f-2 : 1.0f-3
+    atol = 1.0f-3
+    rtol = 1.0f-3

     @test y ≈ y_simple atol=atol rtol=rtol
     if track_stats
@@ -82,22 +81,9 @@ function run_batchnorm_testing(
         skip_backends = []
         act === relu && push!(skip_backends, AutoFiniteDiff())

-        soft_fail = if fp16
-            if Sys.iswindows()
-                [AutoTracker(), AutoFiniteDiff(), AutoReverseDiff(), AutoForwardDiff()]
-            else
-                true
-            end
-        else
-            false
-        end
-
-        broken_backends = Sys.iswindows() && fp16 ? [AutoEnzyme()] : []
-
         __f = (args...) -> sum(first(batchnorm(
             args..., rm, rv, training, act, T(0.9), epsilon)))
-        test_gradients(
-            __f, x, scale, bias; atol, rtol, skip_backends, soft_fail, broken_backends)
+        test_gradients(__f, x, scale, bias; atol, rtol, skip_backends)
     end

     if anonact !== act
@@ -109,7 +95,7 @@ function run_batchnorm_testing(
 end

 const ALL_TEST_CONFIGS = Iterators.product(
-    [Float16, Float32, Float64], ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)),
+    [BFloat16, Float32, Float64], ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)),
     (Val(true), Val(false)), (true, false), (true, false),
     (identity, relu, tanh_fast, sigmoid_fast, anonact))

test/normalization/groupnorm_tests.jl

Lines changed: 9 additions & 12 deletions
@@ -34,20 +34,17 @@ function run_groupnorm_testing(gen_f, T, sz, groups, act, aType, mode, ongpu)

     y_simple = _f2(x, scale, bias)

-    fp16 = T == Float16
-    atol = fp16 ? 1.0f-2 : 1.0f-3
-    rtol = fp16 ? 1.0f-2 : 1.0f-3
+    atol = 1.0f-3
+    rtol = 1.0f-3

     @test y ≈ y_simple atol=atol rtol=rtol

     # Check the rrules
-    if !fp16
-        ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias)
-        ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f2, x, scale, bias)
-        @test ∂x ≈ ∂x_simple atol=atol rtol=rtol
-        @test ∂scale ≈ ∂scale_simple atol=atol rtol=rtol
-        @test ∂bias ≈ ∂bias_simple atol=atol rtol=rtol
-    end
+    ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias)
+    ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f2, x, scale, bias)
+    @test ∂x ≈ ∂x_simple atol=atol rtol=rtol
+    @test ∂scale ≈ ∂scale_simple atol=atol rtol=rtol
+    @test ∂bias ≈ ∂bias_simple atol=atol rtol=rtol

     @test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any
     @jet groupnorm(x, scale, bias, groups, act, epsilon)
@@ -61,11 +58,11 @@ function run_groupnorm_testing(gen_f, T, sz, groups, act, aType, mode, ongpu)
     @test size(y) == sz

     __f = (args...) -> sum(groupnorm(args..., groups, act, epsilon))
-    soft_fail = fp16 ? fp16 : [AutoFiniteDiff()]
+    soft_fail = [AutoFiniteDiff()]
     test_gradients(__f, x, scale, bias; atol, rtol, soft_fail)
 end

-const ALL_TEST_CONFIGS = Iterators.product([Float16, Float32, Float64],
+const ALL_TEST_CONFIGS = Iterators.product([BFloat16, Float32, Float64],
     ((6, 2), (4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2),
         (4, 4, 6, 2), (2, 2, 6, 2), (3, 3, 12, 4)),
     (2, 3),

test/normalization/instancenorm_tests.jl

Lines changed: 9 additions & 12 deletions
@@ -21,20 +21,17 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongp

     y_simple, nt_simple = instancenorm(x, scale, bias, training, act, epsilon)

-    fp16 = T == Float16
-    atol = fp16 ? 1.0f-2 : 1.0f-3
-    rtol = fp16 ? 1.0f-2 : 1.0f-3
+    atol = 1.0f-3
+    rtol = 1.0f-3

     @test y ≈ y_simple atol=atol rtol=rtol

     # Check the rrules
-    if !fp16
-        ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias)
-        ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f, x, scale, bias)
-        @test ∂x ≈ ∂x_simple atol=atol rtol=rtol
-        @test ∂scale ≈ ∂scale_simple atol=atol rtol=rtol
-        @test ∂bias ≈ ∂bias_simple atol=atol rtol=rtol
-    end
+    ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias)
+    ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f, x, scale, bias)
+    @test ∂x ≈ ∂x_simple atol=atol rtol=rtol
+    @test ∂scale ≈ ∂scale_simple atol=atol rtol=rtol
+    @test ∂bias ≈ ∂bias_simple atol=atol rtol=rtol

     @test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any
     @jet instancenorm(x, scale, bias, training, act, epsilon)
@@ -49,13 +46,13 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongp

     if __is_training(training)
         __f = (args...) -> sum(first(instancenorm(args..., training, act, epsilon)))
-        soft_fail = fp16 ? fp16 : [AutoFiniteDiff()]
+        soft_fail = [AutoFiniteDiff()]
         test_gradients(__f, x, scale, bias; atol, rtol, soft_fail)
     end
 end

 const ALL_TEST_CONFIGS = Iterators.product(
-    [Float16, Float32, Float64], ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)),
+    [BFloat16, Float32, Float64], ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)),
     (Val(true), Val(false)), (identity, relu, tanh_fast, sigmoid_fast, anonact))

 const TEST_BLOCKS = collect(Iterators.partition(

test/normalization/layernorm_tests.jl

Lines changed: 4 additions & 5 deletions
@@ -33,11 +33,10 @@ function run_layernorm_testing(gen_f, aType, T, x_size, affine_shape, act, ongpu
         @test check_approx(std(y; dims), 1; atol=1e-1, rtol=1e-1)
     end

-    fp16 = T == Float16
-    atol = fp16 ? 1.0f-2 : 1.0f-3
-    rtol = fp16 ? 1.0f-2 : 1.0f-3
+    atol = 1.0f-3
+    rtol = 1.0f-3

-    soft_fail = fp16 ? fp16 : [AutoFiniteDiff()]
+    soft_fail = [AutoFiniteDiff()]
     if affine_shape !== nothing
         __f = (args...) -> sum(_f(args...))
         test_gradients(__f, x, scale, bias; atol, rtol, soft_fail)
@@ -56,7 +55,7 @@ anonact = x -> x^3

 const ALL_TEST_CONFIGS = Any[]

-for T in (Float16, Float32, Float64),
+for T in (BFloat16, Float32, Float64),
     x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)),
     affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])),
     act in (identity, relu, tanh_fast, sigmoid_fast, anonact)

test/shared_testsetup.jl

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 import Reexport: @reexport

 using LuxLib, MLDataDevices
-@reexport using LuxTestUtils, StableRNGs, Test, Enzyme, Zygote
+@reexport using BFloat16s, LuxTestUtils, StableRNGs, Test, Enzyme, Zygote

 LuxTestUtils.jet_target_modules!(["LuxLib"])

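
Because the shared setup re-exports BFloat16s, every test file that loads it sees `BFloat16` directly, which is why none of the diffs above add their own `using BFloat16s`. A minimal sketch of that mechanism (not from the repo; the module name here is a hypothetical stand-in for the real test setup), assuming Reexport.jl and BFloat16s.jl are installed:

```julia
module HypotheticalTestSetup        # stand-in for test/shared_testsetup.jl
    using Reexport: @reexport
    @reexport using BFloat16s       # BFloat16s' exports become exports of this module
end

using .HypotheticalTestSetup
x = BFloat16(0.5)                   # visible here without a direct `using BFloat16s`
```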

0 commit comments
