Commit a791dc1

update
1 parent 7cef95a commit a791dc1

2 files changed: +3 −15 lines

src/broadcast.jl

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ broadcast_ops = [
 # "fdim",
 ("invxback","invxback","(-xi*yi*yi)"),
 ("reluback","reluback","(yi>0?xi:0)"),
+("eluback", "eluback", "ifelse(yi>0,dyi,yi+1)"),
 ("sigmback","sigmback","(xi*yi*(1-yi))"),
 ("tanhback","tanhback","(xi*(1-yi*yi))"),
 ("rpow","rpow","pow(yi,xi)"), # need this for Array.^Scalar

src/unary.jl

Lines changed: 2 additions & 15 deletions
@@ -45,6 +45,7 @@ unary_ops = [
 # "normcdfinv",
 # "rcbrt",
 ("relu", "relu", "(xi>0?xi:0)"),
+("elu", "elu", "(xi>0?xi:exp(xi)-1)"),
 # "rint",
 "round",
 # "rsqrt",
@@ -99,6 +100,7 @@ end
 for (f,g,y,dx) in
 ((:invx, :invxback, :(one(T)/xi), :(-yi*yi*dyi)),
 (:relu, :reluback, :(max(zero(T),xi)), :(ifelse(yi>0,dyi,zero(T)))),
+(:elu, :eluback, :(ifelse(xi>0,xi,exp(xi)-1)), :(ifelse(yi>0,dyi,yi+1))),
 (:tanx, :tanhback, :(tanh(xi)), :(dyi*(one(T)-yi*yi))),
 (:sigm, :sigmback,
 # Numerically stable implementation from
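
These two hunks follow the same pattern as the broadcast op above: the unary_ops entry supplies the C expression for the elementwise forward kernel, and the (:elu, :eluback, y, dx) tuple supplies the Julia expressions that the surrounding loop expands into concrete definitions, with the forward written in terms of the input xi and the backward in terms of the saved output yi and incoming gradient dyi (so the backward pass can reuse the stored output rather than recompute exp). A hypothetical usage sketch, assuming elu ends up broadcastable and differentiable like the existing relu; none of this appears in the diff itself:

    using Knet

    x = randn(Float32, 5)
    y = elu.(x)                # elementwise ELU forward, alpha = 1

    loss(w) = sum(elu.(w))     # toy scalar-valued loss
    g = grad(loss)(x)          # AutoGrad gradient of loss w.r.t. x, presumably backed by the eluback rule above
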
@@ -153,21 +155,6 @@ broadcast(::typeof(+), a::KnetArray)=a
 +(a::KnetArray)=a
 -(a::KnetArray)=broadcast(-,a)
 
-"""
-    elu(x, alpha=1)
-
-Exponential Linear Unit. Returns
-`max(0,x) + alpha*(exp(min(x,0)) - 1)`
-
-Paper Ref. :
-"Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs) (ICLR 2016)"
-"""
-function elu(x, alpha=1)
-    p = relu(x)
-    m = -relu(-x)
-    return p + alpha*(exp(m) - 1)
-end
-
 """
     selu(x)
 
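
The removed pure-Julia elu(x, alpha=1) assembled the same function from relu calls: relu(x) keeps the positive part and -relu(-x) keeps the negative part, so max(0,x) + alpha*(exp(min(x,0)) - 1) reduces to the new kernel expression (xi>0?xi:exp(xi)-1) when alpha = 1, since for x > 0 the second term is alpha*(exp(0) - 1) = 0. Note that the generic alpha argument goes away with this commit; the kernel versions fix alpha = 1. A quick equivalence check (helper names are illustrative only):

    # Removed formulation (with alpha = 1) vs. the new kernel expression.
    old_elu(x, alpha=1) = max(0, x) + alpha * (exp(min(x, 0)) - 1)
    new_elu(x) = ifelse(x > 0, x, exp(x) - 1)

    all(old_elu(x) ≈ new_elu(x) for x in -3.0:0.25:3.0)   # expected: true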
