Skip to content

Commit 4d451c0

Browse files
committed
coeffs instead of derivatives
1 parent 2b0b1ce commit 4d451c0

File tree

4 files changed

+132
-170
lines changed

4 files changed

+132
-170
lines changed

src/chainrules.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,13 @@ function rrule(::typeof(partials), t::TaylorArray{T, N, A, P}) where {N, T, A, P
3333
return partials(t), value_pullback
3434
end
3535

36-
function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
37-
i::Integer) where {N, T}
36+
function rrule(::typeof(extract_derivative), t::TaylorScalar{T, P},
37+
q::Val{Q}) where {T, P, Q}
3838
function extract_derivative_pullback(d̄)
39-
NoTangent(), TaylorScalar(zero(T), ntuple(j -> j === i ?: zero(T), Val(N))),
39+
NoTangent(), TaylorScalar(zero(T), ntuple(j -> j === Q ?* factorial(Q) : zero(T), Val(P))),
4040
NoTangent()
4141
end
42-
return extract_derivative(t, i), extract_derivative_pullback
42+
return extract_derivative(t, q), extract_derivative_pullback
4343
end
4444

4545
function rrule(::typeof(Base.getindex), a::TaylorArray, i::Int...)

src/derivative.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@ function derivatives end
4747

4848
# `derivative` API: computes the `P - 1`-th derivative of `f` at `x`
4949
@inline derivative(f, x, l, p::Val{P}) where {P} = extract_derivative(
50-
derivatives(f, x, l, p), P)
50+
derivatives(f, x, l, p), p)
5151
@inline derivative(f!, y, x, l, p::Val{P}) where {P} = extract_derivative(
52-
derivatives(f!, y, x, l, p), P)
52+
derivatives(f!, y, x, l, p), p)
5353
@inline derivative!(result, f, x, l, p::Val{P}) where {P} = extract_derivative!(
54-
result, derivatives(f, x, l, p), P)
54+
result, derivatives(f, x, l, p), p)
5555
@inline derivative!(result, f!, y, x, l, p::Val{P}) where {P} = extract_derivative!(
56-
result, derivatives(f!, y, x, l, p), P)
56+
result, derivatives(f!, y, x, l, p), p)
5757

5858
# `derivatives` API: computes all derivatives of `f` at `x` up to p `P - 1`
5959

src/primitive.jl

Lines changed: 99 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,12 @@ Taylor = Union{TaylorScalar, TaylorArray}
1212

1313
@inline value(t::Taylor) = t.value
1414
@inline partials(t::Taylor) = t.partials
15-
@inline extract_derivative(t::Taylor, i::Integer) = t.partials[i]
16-
@inline extract_derivative(v::AbstractArray{<:TaylorScalar}, i::Integer) = map(
17-
t -> extract_derivative(t, i), v)
18-
@inline extract_derivative(r, i::Integer) = false
19-
@inline function extract_derivative!(result::AbstractArray, v::AbstractArray{T},
20-
i::Integer) where {T <: TaylorScalar}
21-
map!(t -> extract_derivative(t, i), result, v)
22-
end
15+
@inline extract_derivative(t::Taylor, ::Val{P}) where {P} = t.partials[P] * factorial(P)
16+
@inline extract_derivative(a::AbstractArray{<:TaylorScalar}, p) = map(
17+
t -> extract_derivative(t, p), a)
18+
@inline extract_derivative(_, p) = false
19+
@inline extract_derivative!(result, a::AbstractArray{<:TaylorScalar}, p) = map!(
20+
t -> extract_derivative(t, p), result, a)
2321

2422
@inline flatten(t::Taylor) = (value(t), partials(t)...)
2523

@@ -33,15 +31,6 @@ function (::Type{F})(x::TaylorScalar{T, P}) where {T, P, F <: AbstractFloat}
3331
end
3432

3533
# Unary
36-
@inline +(a::Number, b::TaylorScalar) = TaylorScalar(a + value(b), partials(b))
37-
@inline -(a::Number, b::TaylorScalar) = TaylorScalar(a - value(b), map(-, partials(b)))
38-
@inline *(a::Number, b::TaylorScalar) = TaylorScalar(a * value(b), a .* partials(b))
39-
@inline /(a::Number, b::TaylorScalar) = /(promote(a, b)...)
40-
41-
@inline +(a::TaylorScalar, b::Number) = TaylorScalar(value(a) + b, partials(a))
42-
@inline -(a::TaylorScalar, b::Number) = TaylorScalar(value(a) - b, partials(a))
43-
@inline *(a::TaylorScalar, b::Number) = TaylorScalar(value(a) * b, partials(a) .* b)
44-
@inline /(a::TaylorScalar, b::Number) = TaylorScalar(value(a) / b, partials(a) ./ b)
4534

4635
## Delegated
4736

@@ -50,57 +39,63 @@ end
5039
@inline inv(t::TaylorScalar) = one(t) / t
5140

5241
for func in (:exp, :expm1, :exp2, :exp10)
53-
@eval @generated function $func(t::TaylorScalar{T, N}) where {T, N}
42+
@eval @generated function $func(t::TaylorScalar{T, P}) where {T, P}
43+
v = [Symbol("v$i") for i in 0:P]
5444
ex = quote
55-
v = flatten(t)
56-
v1 = $($(QuoteNode(func)) == :expm1 ? :(exp(v[1])) : :($$func(v[1])))
45+
p = value(t)
46+
f = flatten(t)
47+
v0 = $($(QuoteNode(func)) == :expm1 ? :(exp(p)) : :($$func(p)))
5748
end
58-
for i in 2:(N + 1)
59-
ex = quote
60-
$ex
61-
$(Symbol('v', i)) = +($([:($(binomial(i - 2, j - 1)) * $(Symbol('v', j)) *
62-
v[$(i + 1 - j)])
63-
for j in 1:(i - 1)]...))
64-
end
49+
for i in 1:P
50+
push!(ex.args,
51+
:(
52+
$(v[begin + i]) = +($([:($(i - j) * $(v[begin + j]) * f[begin + $(i - j)])
53+
for j in 0:(i - 1)]...)) / $i
54+
))
6555
if $(QuoteNode(func)) == :exp2
66-
ex = :($ex; $(Symbol('v', i)) *= $(log(2)))
56+
push!(ex.args, :($(v[begin + i]) *= log(2)))
6757
elseif $(QuoteNode(func)) == :exp10
68-
ex = :($ex; $(Symbol('v', i)) *= $(log(10)))
58+
push!(ex.args, :($(v[begin + i]) *= log(10)))
6959
end
7060
end
7161
if $(QuoteNode(func)) == :expm1
72-
ex = :($ex; v1 = expm1(v[1]))
62+
push!(ex.args, :(v0 = expm1(f[1])))
7363
end
74-
ex = :($ex; TaylorScalar(tuple($([Symbol('v', i) for i in 1:(N + 1)]...))))
64+
push!(ex.args, :(TaylorScalar(tuple($(v...)))))
7565
return :(@inbounds $ex)
7666
end
7767
end
7868

7969
for func in (:sin, :cos)
80-
@eval @generated function $func(t::TaylorScalar{T, N}) where {T, N}
70+
@eval @generated function $func(t::TaylorScalar{T, P}) where {T, P}
71+
s = [Symbol("s$i") for i in 0:P]
72+
c = [Symbol("c$i") for i in 0:P]
8173
ex = quote
82-
v = flatten(t)
83-
s1 = sin(v[1])
84-
c1 = cos(v[1])
74+
$(Expr(:meta, :inline))
75+
f = flatten(t)
76+
s0 = sin(f[1])
77+
c0 = cos(f[1])
8578
end
86-
for i in 2:(N + 1)
87-
ex = :($ex;
88-
$(Symbol('s', i)) = +($([:($(binomial(i - 2, j - 1)) *
89-
$(Symbol('c', j)) *
90-
v[$(i + 1 - j)]) for j in 1:(i - 1)]...)))
91-
ex = :($ex;
92-
$(Symbol('c', i)) = +($([:($(-binomial(i - 2, j - 1)) *
93-
$(Symbol('s', j)) *
94-
v[$(i + 1 - j)]) for j in 1:(i - 1)]...)))
79+
for i in 1:P
80+
push!(ex.args,
81+
:($(s[begin + i]) = +($([:(
82+
$(i - j) * $(c[begin + j]) *
83+
f[begin + $(i - j)]) for j in 0:(i - 1)]...)) /
84+
$i)
85+
)
86+
push!(ex.args,
87+
:($(c[begin + i]) = +($([:(
88+
$(i - j) * $(s[begin + j]) *
89+
f[begin + $(i - j)]) for j in 0:(i - 1)]...)) /
90+
-$i)
91+
)
9592
end
9693
if $(QuoteNode(func)) == :sin
97-
ex = :($ex; TaylorScalar(tuple($([Symbol('s', i) for i in 1:(N + 1)]...))))
94+
push!(ex.args, :(TaylorScalar(tuple($(s...)))))
9895
else
99-
ex = :($ex; TaylorScalar(tuple($([Symbol('c', i) for i in 1:(N + 1)]...))))
100-
end
101-
return quote
102-
@inbounds $ex
96+
push!(ex.args, :(TaylorScalar(tuple($(c...)))))
10397
end
98+
return :(@inbounds $ex)
10499
end
105100
end
106101

@@ -109,6 +104,18 @@ end
109104

110105
# Binary
111106

107+
## Easy case
108+
109+
@inline +(a::Number, b::TaylorScalar) = TaylorScalar(a + value(b), partials(b))
110+
@inline -(a::Number, b::TaylorScalar) = TaylorScalar(a - value(b), .-partials(b))
111+
@inline *(a::Number, b::TaylorScalar) = TaylorScalar(a * value(b), a .* partials(b))
112+
@inline /(a::Number, b::TaylorScalar) = /(promote(a, b)...)
113+
114+
@inline +(a::TaylorScalar, b::Number) = TaylorScalar(value(a) + b, partials(a))
115+
@inline -(a::TaylorScalar, b::Number) = TaylorScalar(value(a) - b, partials(a))
116+
@inline *(a::TaylorScalar, b::Number) = TaylorScalar(value(a) * b, partials(a) .* b)
117+
@inline /(a::TaylorScalar, b::Number) = TaylorScalar(value(a) / b, partials(a) ./ b)
118+
112119
const AMBIGUOUS_TYPES = (AbstractFloat, Irrational, Integer, Rational, Real, RoundingMode)
113120

114121
for op in [:>, :<, :(==), :(>=), :(<=)]
@@ -126,75 +133,55 @@ end
126133

127134
@generated function *(a::TaylorScalar{T, N}, b::TaylorScalar{T, N}) where {T, N}
128135
return quote
136+
$(Expr(:meta, :inline))
129137
va, vb = flatten(a), flatten(b)
130-
r = tuple($([:(+($([:($(binomial(i - 1, j - 1)) * va[$j] *
131-
vb[$(i + 1 - j)]) for j in 1:i]...)))
132-
for i in 1:(N + 1)]...))
133-
@inbounds TaylorScalar(r[1], r[2:end])
138+
v = tuple($([:(
139+
+($([:(va[begin + $j] * vb[begin + $(i - j)]) for j in 0:i]...))
140+
) for i in 0:N]...))
141+
@inbounds TaylorScalar(v)
134142
end
135143
end
136144

137-
@generated function /(a::TaylorScalar{T, N}, b::TaylorScalar{T, N}) where {T, N}
145+
@generated function /(a::TaylorScalar{T, P}, b::TaylorScalar{T, P}) where {T, P}
146+
v = [Symbol("v$i") for i in 0:P]
138147
ex = quote
148+
$(Expr(:meta, :inline))
139149
va, vb = flatten(a), flatten(b)
140-
v1 = va[1] / vb[1]
150+
v0 = va[1] / vb[1]
151+
b0 = vb[1]
141152
end
142-
for i in 2:(N + 1)
143-
ex = quote
144-
$ex
145-
$(Symbol('v', i)) = (va[$i] -
146-
+($([:($(binomial(i - 1, j - 1)) * $(Symbol('v', j)) *
147-
vb[$(i + 1 - j)])
148-
for j in 1:(i - 1)]...))) / vb[1]
149-
end
150-
end
151-
ex = quote
152-
$ex
153-
v = tuple($([Symbol('v', i) for i in 1:(N + 1)]...))
154-
TaylorScalar(v)
153+
for i in 1:P
154+
push!(ex.args,
155+
:(
156+
$(v[begin + i]) = (va[begin + $i] -
157+
+($([:($(v[begin + j]) *
158+
vb[begin + $(i - j)])
159+
for j in 0:(i - 1)]...))) / b0
160+
)
161+
)
155162
end
163+
push!(ex.args, :(TaylorScalar(tuple($(v...)))))
156164
return :(@inbounds $ex)
157165
end
158166

159167
for R in (Integer, Real)
160-
@eval @generated function ^(t::TaylorScalar{T, N}, n::S) where {S <: $R, T, N}
168+
@eval @generated function ^(t::TaylorScalar{T, P}, n::S) where {S <: $R, T, P}
169+
v = [Symbol("v$i") for i in 0:P]
161170
ex = quote
162-
v = flatten(t)
163-
w11 = 1
164-
u1 = ^(v[1], n)
165-
end
166-
for k in 1:(N + 1)
167-
ex = quote
168-
$ex
169-
$(Symbol('p', k)) = ^(v[1], n - $(k - 1))
170-
end
171+
f = flatten(t)
172+
f0 = f[1]
173+
v0 = ^(f0, n)
171174
end
172-
for i in 2:(N + 1)
173-
subex = quote
174-
$(Symbol('w', i, 1)) = 0
175-
end
176-
for k in 2:i
177-
subex = quote
178-
$subex
179-
$(Symbol('w', i, k)) = +($([:((n * $(binomial(i - 2, j - 1)) -
180-
$(binomial(i - 2, j - 2))) *
181-
$(Symbol('w', j, k - 1)) *
182-
v[$(i + 1 - j)])
183-
for j in (k - 1):(i - 1)]...))
184-
end
185-
end
186-
ex = quote
187-
$ex
188-
$subex
189-
$(Symbol('u', i)) = +($([:($(Symbol('w', i, k)) * $(Symbol('p', k)))
190-
for k in 2:i]...))
191-
end
192-
end
193-
ex = quote
194-
$ex
195-
v = tuple($([Symbol('u', i) for i in 1:(N + 1)]...))
196-
TaylorScalar(v)
175+
for i in 1:P
176+
push!(ex.args,
177+
:(
178+
$(v[begin + i]) = +($([:(
179+
(n * $(i - j) - $j) * $(v[begin + j]) *
180+
f[begin + $(i - j)]
181+
) for j in 0:(i - 1)]...)) / ($i * f0)
182+
))
197183
end
184+
push!(ex.args, :(TaylorScalar(tuple($(v...)))))
198185
return :(@inbounds $ex)
199186
end
200187
@eval function ^(a::S, t::TaylorScalar{T, N}) where {S <: $R, T, N}
@@ -204,39 +191,14 @@ end
204191

205192
^(t::TaylorScalar, s::TaylorScalar) = exp(s * log(t))
206193

207-
@generated function raise(f::T, df::TaylorScalar{T, M},
208-
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
209-
return quote
210-
$(Expr(:meta, :inline))
211-
vdf, vt = flatten(df), flatten(t)
212-
partials = tuple($([:(+($([:($(binomial(i - 1, j - 1)) * vdf[$j] *
213-
vt[$(i + 2 - j)]) for j in 1:i]...)))
214-
for i in 1:(M + 1)]...))
215-
@inbounds TaylorScalar(f, partials)
216-
end
194+
@inline function lower(t::TaylorScalar{T, P}) where {T, P}
195+
s = partials(t)
196+
TaylorScalar(ntuple(i -> s[i] * i, Val(P)))
217197
end
218-
219-
raise(::T, df::S, t::TaylorScalar{T, N}) where {S <: Number, T, N} = df * t
220-
221-
@generated function raiseinv(f::T, df::TaylorScalar{T, M},
222-
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
223-
ex = quote
224-
vdf, vt = flatten(df), flatten(t)
225-
v1 = vt[2] / vdf[1]
226-
end
227-
for i in 2:(M + 1)
228-
ex = quote
229-
$ex
230-
$(Symbol('v', i)) = (vt[$(i + 1)] -
231-
+($([:($(binomial(i - 1, j - 1)) * $(Symbol('v', j)) *
232-
vdf[$(i + 1 - j)])
233-
for j in 1:(i - 1)]...))) / vdf[1]
234-
end
235-
end
236-
ex = quote
237-
$ex
238-
v = tuple($([Symbol('v', i) for i in 1:(M + 1)]...))
239-
TaylorScalar(f, v)
240-
end
241-
return :(@inbounds $ex)
198+
@inline function higher(t::TaylorScalar{T, P}) where {T, P}
199+
s = flatten(t)
200+
ntuple(i -> s[i] / i, Val(P + 1))
242201
end
202+
@inline raise(f, df::TaylorScalar, t) = TaylorScalar(f, higher(lower(t) * df))
203+
@inline raise(f, df::Number, t) = df * t
204+
@inline raiseinv(f, df, t) = TaylorScalar(f, higher(lower(t) / df))

0 commit comments

Comments
 (0)