Skip to content

Commit 9ef7cdb

Browse files
committed
Improve readability
1 parent 44c3ab0 commit 9ef7cdb

File tree

4 files changed

+129
-109
lines changed

4 files changed

+129
-109
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "0.2.5"
66
[deps]
77
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
88
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
9+
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
910
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
1011
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
1112

src/TaylorDiff.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ can_taylorize(::Type) = false
1414
" If the type behaves as a scalar, define TaylorDiff.can_taylorize(::Type{$V}) = true."))
1515
end
1616

17+
include("utils.jl")
1718
include("scalar.jl")
1819
include("array.jl")
1920
include("primitive.jl")
20-
include("utils.jl")
2121
include("codegen.jl")
2222
include("derivative.jl")
2323
include("chainrules.jl")

src/primitive.jl

Lines changed: 64 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,8 @@ import Base: abs, abs2
22
import Base: exp, exp2, exp10, expm1, log, log2, log10, log1p, inv, sqrt, cbrt
33
import Base: sin, cos, tan, cot, sec, csc, sinh, cosh, tanh, coth, sech, csch, sinpi, cospi
44
import Base: asin, acos, atan, acot, asec, acsc, asinh, acosh, atanh, acoth, asech, acsch
5-
import Base: sinc, cosc
65
import Base: +, -, *, /, \, ^, >, <, >=, <=, ==
7-
import Base: hypot, max, min
8-
import Base: tail
9-
import Base: convert, promote_rule
6+
import Base: sinc, cosc, hypot, max, min, literal_pow
107

118
Taylor = Union{TaylorScalar, TaylorArray}
129

@@ -22,7 +19,7 @@ Taylor = Union{TaylorScalar, TaylorArray}
2219

2320
@inline flatten(t::Taylor) = (value(t), partials(t)...)
2421

25-
function promote_rule(::Type{TaylorScalar{T, P}},
22+
function Base.promote_rule(::Type{TaylorScalar{T, P}},
2623
::Type{S}) where {T, S, P}
2724
TaylorScalar{promote_type(T, S), P}
2825
end
@@ -35,78 +32,49 @@ end
3532

3633
## Delegated
3734

38-
@inline +(t::TaylorScalar) = t
3935
@inline -(t::TaylorScalar) = TaylorScalar(-value(t), .-partials(t))
4036
@inline sqrt(t::TaylorScalar) = t^0.5
4137
@inline cbrt(t::TaylorScalar) = ^(t, 1 / 3)
4238
@inline inv(t::TaylorScalar) = one(t) / t
39+
@inline sinpi(t::TaylorScalar) = sin* t)
40+
@inline cospi(t::TaylorScalar) = cos* t)
41+
@inline exp10(t::TaylorScalar) = exp(t * log(10))
42+
@inline exp2(t::TaylorScalar) = exp(t * log(2))
43+
@inline expm1(t::TaylorScalar) = TaylorScalar(expm1(value(t)), partials(exp(t)))
4344

44-
for func in (:exp, :expm1, :exp2, :exp10)
45-
@eval @generated function $func(t::TaylorScalar{T, P}) where {T, P}
46-
v = [Symbol("v$i") for i in 0:P]
47-
ex = quote
48-
$(Expr(:meta, :inline))
49-
p = value(t)
50-
f = flatten(t)
51-
v0 = $($(QuoteNode(func)) == :expm1 ? :(exp(p)) : :($$func(p)))
52-
end
53-
for i in 1:P
54-
push!(ex.args,
55-
:(
56-
$(v[begin + i]) = +($([:($(i - j) * $(v[begin + j]) *
57-
f[begin + $(i - j)])
58-
for j in 0:(i - 1)]...)) / $i
59-
))
60-
if $(QuoteNode(func)) == :exp2
61-
push!(ex.args, :($(v[begin + i]) *= log(2)))
62-
elseif $(QuoteNode(func)) == :exp10
63-
push!(ex.args, :($(v[begin + i]) *= log(10)))
64-
end
65-
end
66-
if $(QuoteNode(func)) == :expm1
67-
push!(ex.args, :(v0 = expm1(f[1])))
45+
## Hand-written exp, sin, cos
46+
47+
@to_static function exp(t::TaylorScalar{T, P}) where {P, T}
48+
f = flatten(t)
49+
v[0] = exp(f[0])
50+
for i in 1:P
51+
v[i] = zero(T)
52+
for j in 0:(i - 1)
53+
v[i] += (i - j) * v[j] * f[i - j]
6854
end
69-
push!(ex.args, :(TaylorScalar(tuple($(v...)))))
70-
return :(@inbounds $ex)
55+
v[i] /= i
7156
end
57+
return TaylorScalar(v)
7258
end
7359

7460
for func in (:sin, :cos)
75-
@eval @generated function $func(t::TaylorScalar{T, P}) where {T, P}
76-
s = [Symbol("s$i") for i in 0:P]
77-
c = [Symbol("c$i") for i in 0:P]
78-
ex = quote
79-
$(Expr(:meta, :inline))
80-
f = flatten(t)
81-
s0 = sin(f[1])
82-
c0 = cos(f[1])
83-
end
61+
@eval @to_static function $func(t::TaylorScalar{T, P}) where {T, P}
62+
f = flatten(t)
63+
s[0], c[0] = sincos(f[0])
8464
for i in 1:P
85-
push!(ex.args,
86-
:($(s[begin + i]) = +($([:(
87-
$(i - j) * $(c[begin + j]) *
88-
f[begin + $(i - j)]) for j in 0:(i - 1)]...)) /
89-
$i)
90-
)
91-
push!(ex.args,
92-
:($(c[begin + i]) = +($([:(
93-
$(i - j) * $(s[begin + j]) *
94-
f[begin + $(i - j)]) for j in 0:(i - 1)]...)) /
95-
-$i)
96-
)
97-
end
98-
if $(QuoteNode(func)) == :sin
99-
push!(ex.args, :(TaylorScalar(tuple($(s...)))))
100-
else
101-
push!(ex.args, :(TaylorScalar(tuple($(c...)))))
65+
s[i] = zero(T)
66+
c[i] = zero(T)
67+
for j in 0:(i - 1)
68+
s[i] += (i - j) * c[j] * f[i - j]
69+
c[i] -= (i - j) * s[j] * f[i - j]
70+
end
71+
s[i] /= i
72+
c[i] /= i
10273
end
103-
return :(@inbounds $ex)
74+
return $(func == :sin ? :(TaylorScalar(s)) : :(TaylorScalar(c)))
10475
end
10576
end
10677

107-
@inline sinpi(t::TaylorScalar) = sin* t)
108-
@inline cospi(t::TaylorScalar) = cos* t)
109-
11078
# Binary
11179

11280
## Easy case
@@ -136,63 +104,51 @@ end
136104
@inline -(a::TaylorScalar, b::TaylorScalar) = TaylorScalar(
137105
value(a) - value(b), map(-, partials(a), partials(b)))
138106

139-
@generated function *(a::TaylorScalar{T, N}, b::TaylorScalar{T, N}) where {T, N}
140-
return quote
141-
$(Expr(:meta, :inline))
142-
va, vb = flatten(a), flatten(b)
143-
v = tuple($([:(
144-
+($([:(va[begin + $j] * vb[begin + $(i - j)]) for j in 0:i]...))
145-
) for i in 0:N]...))
146-
@inbounds TaylorScalar(v)
107+
@to_static function *(a::TaylorScalar{T, P}, b::TaylorScalar{T, P}) where {T, P}
108+
va, vb = flatten(a), flatten(b)
109+
for i in 0:P
110+
v[i] = zero(T)
111+
for j in 0:i
112+
v[i] += va[j] * vb[i - j]
113+
end
147114
end
115+
TaylorScalar(v)
148116
end
149117

150-
@generated function /(a::TaylorScalar{T, P}, b::TaylorScalar{T, P}) where {T, P}
151-
v = [Symbol("v$i") for i in 0:P]
152-
ex = quote
153-
$(Expr(:meta, :inline))
154-
va, vb = flatten(a), flatten(b)
155-
v0 = va[1] / vb[1]
156-
b0 = vb[1]
157-
end
118+
@to_static function /(a::TaylorScalar{T, P}, b::TaylorScalar{T, P}) where {T, P}
119+
va, vb = flatten(a), flatten(b)
120+
v[0] = va[0] / vb[0]
158121
for i in 1:P
159-
push!(ex.args,
160-
:(
161-
$(v[begin + i]) = (va[begin + $i] -
162-
+($([:($(v[begin + j]) *
163-
vb[begin + $(i - j)])
164-
for j in 0:(i - 1)]...))) / b0
165-
)
166-
)
122+
v[i] = va[i]
123+
for j in 0:(i - 1)
124+
v[i] -= vb[i - j] * v[j]
125+
end
126+
v[i] /= vb[0]
167127
end
168-
push!(ex.args, :(TaylorScalar(tuple($(v...)))))
169-
return :(@inbounds $ex)
128+
TaylorScalar(v)
170129
end
171130

131+
@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{0}) = one(x)
132+
@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{1}) = x
133+
@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{2}) = x*x
134+
@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{3}) = x*x*x
135+
@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{-1}) = inv(x)
136+
@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{-2}) = (i=inv(x); i*i)
137+
172138
for R in (Integer, Real)
173-
@eval @generated function ^(t::TaylorScalar{T, P}, n::S) where {S <: $R, T, P}
174-
v = [Symbol("v$i") for i in 0:P]
175-
ex = quote
176-
$(Expr(:meta, :inline))
177-
f = flatten(t)
178-
f0 = f[1]
179-
v0 = ^(f0, n)
180-
end
139+
@eval @to_static function ^(t::TaylorScalar{T, P}, n::S) where {S <: $R, T, P}
140+
f = flatten(t)
141+
v[0] = f[0]^n
181142
for i in 1:P
182-
push!(ex.args,
183-
:(
184-
$(v[begin + i]) = +($([:(
185-
(n * $(i - j) - $j) * $(v[begin + j]) *
186-
f[begin + $(i - j)]
187-
) for j in 0:(i - 1)]...)) / ($i * f0)
188-
))
143+
v[i] = zero(T)
144+
for j in 0:(i - 1)
145+
v[i] += (n * (i - j) - j) * v[j] * f[i - j]
146+
end
147+
v[i] /= (i * f[0])
189148
end
190-
push!(ex.args, :(TaylorScalar(tuple($(v...)))))
191-
return :(@inbounds $ex)
192-
end
193-
@eval function ^(a::S, t::TaylorScalar{T, N}) where {S <: $R, T, N}
194-
exp(t * log(a))
149+
return TaylorScalar(v)
195150
end
151+
@eval ^(a::S, t::TaylorScalar) where {S <: $R} = exp(t * log(a))
196152
end
197153

198154
^(t::TaylorScalar, s::TaylorScalar) = exp(s * log(t))

src/utils.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ using ChainRules
22
using ChainRulesCore
33
using Symbolics: @variables, @rule, unwrap, isdiv
44
using SymbolicUtils.Code: toexpr
5+
using MacroTools
6+
using MacroTools: prewalk, postwalk
57

68
"""
79
Pick a strategy for raising the derivative of a function. If the derivative is like 1 over something, raise with the division rule; otherwise, raise with the multiplication rule.
@@ -45,3 +47,64 @@ function define_unary_function(func, m)
4547
end
4648
end
4749
end
50+
51+
tuplen(::Type{NTuple{N, T}}) where {N, T} = N
52+
function interpolate(ex::Expr, dict)
53+
func = ex.args[1]
54+
args = map(x -> interpolate(x, dict), ex.args[2:end])
55+
getproperty(Base, func)(args...)
56+
end
57+
interpolate(ex::Symbol, dict) = get(dict, ex, ex)
58+
interpolate(ex::Any, _) = ex
59+
60+
function unroll_loop(start, stop, var, body, d)
61+
ex = Expr(:block)
62+
start = interpolate(start, d)
63+
stop = interpolate(stop, d)
64+
for i in start:stop
65+
iter = prewalk(x -> x === var ? i : x, body)
66+
args = filter(x -> !(x isa LineNumberNode), iter.args)
67+
append!(ex.args, args)
68+
end
69+
ex
70+
end
71+
72+
function process(d, expr)
73+
# Unroll loops
74+
expr = prewalk(expr) do x
75+
@match x begin
76+
for var_ in start_:stop_
77+
body_
78+
end => unroll_loop(start, stop, var, body, d)
79+
_ => x
80+
end
81+
end
82+
# Modify indices
83+
magic_names = (:v, :s, :c)
84+
expr = postwalk(expr) do x
85+
@match x begin
86+
a_[idx_] => a in magic_names ? Symbol(a, idx) : :($a[begin + $idx])
87+
TaylorScalar(v_) => :(TaylorScalar(tuple($([Symbol(v, idx) for idx in 0:d[:P]]...))))
88+
_ => x
89+
end
90+
end
91+
# Add inline meta
92+
return quote
93+
$(Expr(:meta, :inline))
94+
$expr
95+
end
96+
end
97+
98+
macro to_static(def)
99+
dict = splitdef(def)
100+
pairs = Any[]
101+
for symbol in dict[:whereparams]
102+
push!(pairs, :($(QuoteNode(symbol)) => $symbol))
103+
end
104+
esc(quote
105+
@generated function $(dict[:name])($(dict[:args]...)) where {$(dict[:whereparams]...)}
106+
d = Dict($(pairs...))
107+
process(d, $(QuoteNode(dict[:body])))
108+
end
109+
end)
110+
end

0 commit comments

Comments
 (0)