|
2 | 2 | rng = StableRNG(12345) |
3 | 3 |
|
4 | 4 | @testset "$mode" for (mode, aType, ongpu) in MODES |
5 | | - @testset "$T: $x_shape" for T in (Float16, Float32, Float64), |
| 5 | + @testset "$T: $x_shape" for T in (BFloat16, Float32, Float64), |
6 | 6 | x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) |
7 | 7 |
|
8 | 8 | x = randn(rng, T, x_shape) |> aType |
|
26 | 26 | __f = let rng = rng, T = T |
27 | 27 | x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) |
28 | 28 | end |
29 | | - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, |
30 | | - soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), |
31 | | - broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) |
| 29 | + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) |
32 | 30 |
|
33 | 31 | y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false), T(2), Colon()) |
34 | 32 |
|
|
48 | 46 | rng = StableRNG(12345) |
49 | 47 |
|
50 | 48 | @testset "$mode" for (mode, aType, ongpu) in MODES |
51 | | - @testset "$T: $x_shape" for T in (Float16, Float32, Float64), |
| 49 | + @testset "$T: $x_shape" for T in (BFloat16, Float32, Float64), |
52 | 50 | x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) |
53 | 51 |
|
54 | 52 | x = randn(rng, T, x_shape) |> aType |
|
76 | 74 | x -> sum(first(dropout( |
77 | 75 | rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) |
78 | 76 | end |
79 | | - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, |
80 | | - soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), |
81 | | - broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) |
| 77 | + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) |
82 | 78 |
|
83 | 79 | @jet sum(first(dropout( |
84 | 80 | rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) |
|
106 | 102 | x -> sum(first(dropout( |
107 | 103 | rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) |
108 | 104 | end |
109 | | - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, |
110 | | - soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), |
111 | | - broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) |
| 105 | + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) |
112 | 106 |
|
113 | 107 | @jet sum(first(dropout( |
114 | 108 | rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) |
|
137 | 131 | x -> sum(first(dropout( |
138 | 132 | rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) |
139 | 133 | end |
140 | | - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, |
141 | | - soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), |
142 | | - broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) |
| 134 | + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) |
143 | 135 |
|
144 | 136 | @jet sum(first(dropout( |
145 | 137 | rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) |
|
165 | 157 | rng = StableRNG(12345) |
166 | 158 |
|
167 | 159 | @testset "$mode" for (mode, aType, ongpu) in MODES |
168 | | - @testset "$T: $x_shape" for T in (Float16, Float32, Float64), |
| 160 | + @testset "$T: $x_shape" for T in (BFloat16, Float32, Float64), |
169 | 161 | x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) |
170 | 162 |
|
171 | 163 | x = randn(rng, T, x_shape) |> aType |
|
186 | 178 | __f = let rng = rng |
187 | 179 | x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) |
188 | 180 | end |
189 | | - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, |
190 | | - soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), |
191 | | - broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) |
| 181 | + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) |
192 | 182 |
|
193 | 183 | @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) |
194 | 184 | @test @inferred(alpha_dropout(rng, x, T(0.5), Val(false))) isa Any |
|
0 commit comments