Skip to content

Commit d290dd3

Browse files
committed
ENH: R poly compatibility
Travis fixes
1 parent 8b6c712 commit d290dd3

File tree

5 files changed

+286
-0
lines changed

5 files changed

+286
-0
lines changed

doc/API-reference.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,11 @@ Spline regression
198198
.. autofunction:: cc
199199
.. autofunction:: te
200200

201+
Orthogonal Polynomial
202+
---------------------
203+
204+
.. autofunction:: poly
205+
201206
Working with formulas programmatically
202207
--------------------------------------
203208

patsy/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,5 +113,8 @@ def _reexport(mod):
113113
import patsy.mgcv_cubic_splines
114114
_reexport(patsy.mgcv_cubic_splines)
115115

116+
import patsy.poly
117+
_reexport(patsy.poly)
118+
116119
# XX FIXME: we aren't exporting any of the explicit parsing interface
117120
# yet. Need to figure out how to do that.

patsy/poly.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# This file is part of Patsy
2+
# Copyright (C) 2012-2013 Nathaniel Smith <[email protected]>
3+
# See file LICENSE.txt for license information.
4+
5+
# R-compatible poly function
6+
7+
# These are made available in the patsy.* namespace
8+
__all__ = ["poly"]
9+
10+
import numpy as np
11+
12+
from patsy.util import have_pandas, no_pickling, assert_no_pickling
13+
from patsy.state import stateful_transform
14+
15+
if have_pandas:
16+
import pandas
17+
18+
class Poly(object):
19+
"""poly(x, degree=1, raw=False)
20+
21+
Generates an orthogonal polynomial transformation of x of degree.
22+
Generic usage is something along the lines of::
23+
24+
y ~ 1 + poly(x, 4)
25+
26+
to fit ``y`` as a function of ``x``, with a 4th degree polynomial.
27+
28+
:arg degree: The number of degrees for the polynomial expansion.
29+
:arg raw: When raw is False (the default), will return orthogonal
30+
polynomials.
31+
32+
.. versionadded:: 0.4.1
33+
"""
34+
def __init__(self):
35+
self._tmp = {}
36+
self._degree = None
37+
self._raw = None
38+
39+
def memorize_chunk(self, x, degree=3, raw=False):
40+
args = {"degree": degree,
41+
"raw": raw
42+
}
43+
self._tmp["args"] = args
44+
# XX: check whether we need x values before saving them
45+
x = np.atleast_1d(x)
46+
if x.ndim == 2 and x.shape[1] == 1:
47+
x = x[:, 0]
48+
if x.ndim > 1:
49+
raise ValueError("input to 'poly' must be 1-d, "
50+
"or a 2-d column vector")
51+
# There's no better way to compute exact quantiles than memorizing
52+
# all data.
53+
x = np.array(x, dtype=float)
54+
self._tmp.setdefault("xs", []).append(x)
55+
56+
def memorize_finish(self):
57+
tmp = self._tmp
58+
args = tmp["args"]
59+
del self._tmp
60+
61+
if args["degree"] < 1:
62+
raise ValueError("degree must be greater than 0 (not %r)"
63+
% (args["degree"],))
64+
if int(args["degree"]) != args["degree"]:
65+
raise ValueError("degree must be an integer (not %r)"
66+
% (self._degree,))
67+
68+
# These are guaranteed to all be 1d vectors by the code above
69+
scores = np.concatenate(tmp["xs"])
70+
scores_mean = scores.mean()
71+
# scores -= scores_mean
72+
self.scores_mean = scores_mean
73+
n = args['degree']
74+
self.degree = n
75+
raw_poly = scores.reshape((-1, 1)) ** np.arange(n + 1).reshape((1, -1))
76+
raw = args['raw']
77+
self.raw = raw
78+
if not raw:
79+
q, r = np.linalg.qr(raw_poly)
80+
# Q is now orthognoal of degree n. To match what R is doing, we
81+
# need to use the three-term recurrence technique to calculate
82+
# new alpha, beta, and norm.
83+
84+
self.alpha = (np.sum(scores.reshape((-1, 1)) * q[:, :n] ** 2,
85+
axis=0) /
86+
np.sum(q[:, :n] ** 2, axis=0))
87+
88+
# For reasons I don't understand, the norms R uses are based off
89+
# of the diagonal of the r upper triangular matrix.
90+
91+
self.norm = np.linalg.norm(q * np.diag(r), axis=0)
92+
self.beta = (self.norm[1:] / self.norm[:n]) ** 2
93+
94+
def transform(self, x, degree=3, raw=False):
95+
if have_pandas:
96+
if isinstance(x, (pandas.Series, pandas.DataFrame)):
97+
to_pandas = True
98+
idx = x.index
99+
else:
100+
to_pandas = False
101+
else:
102+
to_pandas = False
103+
x = np.array(x, ndmin=1).flatten()
104+
105+
if self.raw:
106+
n = self.degree
107+
p = x.reshape((-1, 1)) ** np.arange(n + 1).reshape((1, -1))
108+
else:
109+
# This is where the three-term recurrance technique is unwound.
110+
111+
p = np.empty((x.shape[0], self.degree + 1))
112+
p[:, 0] = 1
113+
114+
for i in np.arange(self.degree):
115+
p[:, i + 1] = (x - self.alpha[i]) * p[:, i]
116+
if i > 0:
117+
p[:, i + 1] = (p[:, i + 1] -
118+
(self.beta[i - 1] * p[:, i - 1]))
119+
p /= self.norm
120+
121+
p = p[:, 1:]
122+
if to_pandas:
123+
p = pandas.DataFrame(p)
124+
p.index = idx
125+
return p
126+
127+
__getstate__ = no_pickling
128+
129+
poly = stateful_transform(Poly)
130+
131+
132+
def test_poly_compat():
133+
from patsy.test_state import check_stateful
134+
from patsy.test_poly_data import (R_poly_test_x,
135+
R_poly_test_data,
136+
R_poly_num_tests)
137+
lines = R_poly_test_data.split("\n")
138+
tests_ran = 0
139+
start_idx = lines.index("--BEGIN TEST CASE--")
140+
while True:
141+
if not lines[start_idx] == "--BEGIN TEST CASE--":
142+
break
143+
start_idx += 1
144+
stop_idx = lines.index("--END TEST CASE--", start_idx)
145+
block = lines[start_idx:stop_idx]
146+
test_data = {}
147+
for line in block:
148+
key, value = line.split("=", 1)
149+
test_data[key] = value
150+
# Translate the R output into Python calling conventions
151+
kwargs = {
152+
# integer
153+
"degree": int(test_data["degree"]),
154+
# boolen
155+
"raw": (test_data["raw"] == 'TRUE')
156+
}
157+
# Special case: in R, setting intercept=TRUE increases the effective
158+
# dof by 1. Adjust our arguments to match.
159+
# if kwargs["df"] is not None and kwargs["include_intercept"]:
160+
# kwargs["df"] += 1
161+
output = np.asarray(eval(test_data["output"]))
162+
# Do the actual test
163+
check_stateful(Poly, False, R_poly_test_x, output, **kwargs)
164+
tests_ran += 1
165+
# Set up for the next one
166+
start_idx = stop_idx + 1
167+
assert tests_ran == R_poly_num_tests
168+
169+
170+
def test_poly_errors():
171+
from nose.tools import assert_raises
172+
x = np.arange(27)
173+
# Invalid input shape
174+
assert_raises(ValueError, poly, x.reshape((3, 3, 3)))
175+
assert_raises(ValueError, poly, x.reshape((3, 3, 3)), raw=True)
176+
# Invalid degree
177+
assert_raises(ValueError, poly, x, degree=-1)
178+
assert_raises(ValueError, poly, x, degree=0)
179+
assert_raises(ValueError, poly, x, degree=3.5)

patsy/test_poly_data.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# This file auto-generated by tools/get-R-bs-test-vectors.R
2+
# Using: R version 3.2.4 Revised (2016-03-16 r70336)
3+
import numpy as np
4+
R_poly_test_x = np.array([1, 1.5, 2.25, 3.375, 5.0625, 7.59375, 11.390625, 17.0859375, 25.62890625, 38.443359375, 57.6650390625, 86.49755859375, 129.746337890625, 194.6195068359375, 291.92926025390625, 437.89389038085938, 656.84083557128906, 985.26125335693359, 1477.8918800354004, 2216.8378200531006, ])
5+
R_poly_test_data = """
6+
--BEGIN TEST CASE--
7+
degree=1
8+
raw=TRUE
9+
output=np.array([1, 1.5, 2.25, 3.375, 5.0625, 7.59375, 11.390625, 17.0859375, 25.62890625, 38.443359375, 57.6650390625, 86.49755859375, 129.746337890625, 194.6195068359375, 291.92926025390625, 437.89389038085938, 656.84083557128906, 985.26125335693359, 1477.8918800354004, 2216.8378200531006, ]).reshape((20, 1, ), order="F")
10+
--END TEST CASE--
11+
--BEGIN TEST CASE--
12+
degree=1
13+
raw=FALSE
14+
output=np.array([-0.12865949508274149, -0.12846539500908838, -0.12817424489860868, -0.12773751973288924, -0.12708243198431005, -0.12609980036144131, -0.12462585292713815, -0.12241493177568342, -0.11909855004850137, -0.11412397745772825, -0.10666211857156857, -0.095469330242329037, -0.07868014774846975, -0.053496374007680828, -0.015720713396497447, 0.040942777520277619, 0.12593801389544024, 0.25343086845818413, 0.4446701503023, 0.73152907306847381, ]).reshape((20, 1, ), order="F")
15+
--END TEST CASE--
16+
--BEGIN TEST CASE--
17+
degree=3
18+
raw=TRUE
19+
output=np.array([1, 1.5, 2.25, 3.375, 5.0625, 7.59375, 11.390625, 17.0859375, 25.62890625, 38.443359375, 57.6650390625, 86.49755859375, 129.746337890625, 194.6195068359375, 291.92926025390625, 437.89389038085938, 656.84083557128906, 985.26125335693359, 1477.8918800354004, 2216.8378200531006, 1, 2.25, 5.0625, 11.390625, 25.62890625, 57.6650390625, 129.746337890625, 291.92926025390625, 656.84083557128906, 1477.8918800354004, 3325.2567300796509, 7481.8276426792145, 16834.112196028233, 37876.752441063523, 85222.692992392927, 191751.05923288409, 431439.8832739892, 970739.73736647563, 2184164.4090745705, 4914369.9204177829, 1, 3.375, 11.390625, 38.443359375, 129.746337890625, 437.89389038085938, 1477.8918800354004, 4987.8850951194763, 16834.112196028233, 56815.128661595285, 191751.05923288409, 647159.82491098379, 2184164.4090745705, 7371554.8806266747, 24878997.722115029, 83966617.312138215, 283387333.4284665, 956432250.32107437, 3227958844.8336263, 10894361101.313488, ]).reshape((20, 3, ), order="F")
20+
--END TEST CASE--
21+
--BEGIN TEST CASE--
22+
degree=3
23+
raw=FALSE
24+
output=np.array([-0.12865949508274149, -0.12846539500908838, -0.12817424489860868, -0.12773751973288924, -0.12708243198431005, -0.12609980036144131, -0.12462585292713815, -0.12241493177568342, -0.11909855004850137, -0.11412397745772825, -0.10666211857156857, -0.095469330242329037, -0.07868014774846975, -0.053496374007680828, -0.015720713396497447, 0.040942777520277619, 0.12593801389544024, 0.25343086845818413, 0.4446701503023, 0.73152907306847381, 0.11682670564764953, 0.11622774987820758, 0.1153299112445243, 0.11398449209008393, 0.11196937564961051, 0.10895347864407183, 0.10444488285989936, 0.097716301062945765, 0.087700630095951776, 0.072850827534442664, 0.05096695744238839, 0.019020528242278005, -0.026920519697452645, -0.091380250921070119, -0.1780532062130448, -0.28552519567824058, -0.39602393206231051, -0.44767622905753701, -0.26843910749340033, 0.57802660073100254, -0.11560888340228653, -0.11436481217184656, -0.11250218782662975, -0.10971608089390825, -0.10555451667328646, -0.099351692934324679, -0.090136150925525155, -0.076511614727544461, -0.05651941299388475, -0.027522538371457093, 0.013772191900731716, 0.070864671671751547, 0.14593497036033168, 0.23591981919395397, 0.32391016867398448, 0.36336942185480259, 0.25890497941187346, -0.11572025100301592, -0.66076386903314166, 0.27159578788942196, ]).reshape((20, 3, ), order="F")
25+
--END TEST CASE--
26+
--BEGIN TEST CASE--
27+
degree=5
28+
raw=TRUE
29+
output=np.array([1, 1.5, 2.25, 3.375, 5.0625, 7.59375, 11.390625, 17.0859375, 25.62890625, 38.443359375, 57.6650390625, 86.49755859375, 129.746337890625, 194.6195068359375, 291.92926025390625, 437.89389038085938, 656.84083557128906, 985.26125335693359, 1477.8918800354004, 2216.8378200531006, 1, 2.25, 5.0625, 11.390625, 25.62890625, 57.6650390625, 129.746337890625, 291.92926025390625, 656.84083557128906, 1477.8918800354004, 3325.2567300796509, 7481.8276426792145, 16834.112196028233, 37876.752441063523, 85222.692992392927, 191751.05923288409, 431439.8832739892, 970739.73736647563, 2184164.4090745705, 4914369.9204177829, 1, 3.375, 11.390625, 38.443359375, 129.746337890625, 437.89389038085938, 1477.8918800354004, 4987.8850951194763, 16834.112196028233, 56815.128661595285, 191751.05923288409, 647159.82491098379, 2184164.4090745705, 7371554.8806266747, 24878997.722115029, 83966617.312138215, 283387333.4284665, 956432250.32107437, 3227958844.8336263, 10894361101.313488, 1, 5.0625, 25.62890625, 129.746337890625, 656.84083557128906, 3325.2567300796509, 16834.112196028233, 85222.692992392927, 431439.8832739892, 2184164.4090745705, 11057332.320940012, 55977744.87475881, 283387333.4284665, 1434648375.4816115, 7262907400.875659, 36768468716.933022, 186140372879.47342, 942335637702.33411, 4770574165868.0674, 24151031714707.086, 1, 7.59375, 57.6650390625, 437.89389038085938, 3325.2567300796509, 25251.168294042349, 191751.05923288409, 1456109.6060497134, 11057332.320940012, 83966617.31213823, 637621500.21404958, 4841938267.2504387, 36768468716.933022, 279210559319.21014, 2120255184830.252, 16100687809804.727, 122264598055704.64, 928446791485507, 7050392822843070, 53538920498464552, ]).reshape((20, 5, ), order="F")
30+
--END TEST CASE--
31+
--BEGIN TEST CASE--
32+
degree=5
33+
raw=FALSE
34+
output=np.array([-0.12865949508274149, -0.12846539500908838, -0.12817424489860868, -0.12773751973288924, -0.12708243198431005, -0.12609980036144131, -0.12462585292713815, -0.12241493177568342, -0.11909855004850137, -0.11412397745772825, -0.10666211857156857, -0.095469330242329037, -0.07868014774846975, -0.053496374007680828, -0.015720713396497447, 0.040942777520277619, 0.12593801389544024, 0.25343086845818413, 0.4446701503023, 0.73152907306847381, 0.11682670564764953, 0.11622774987820758, 0.1153299112445243, 0.11398449209008393, 0.11196937564961051, 0.10895347864407183, 0.10444488285989936, 0.097716301062945765, 0.087700630095951776, 0.072850827534442664, 0.05096695744238839, 0.019020528242278005, -0.026920519697452645, -0.091380250921070119, -0.1780532062130448, -0.28552519567824058, -0.39602393206231051, -0.44767622905753701, -0.26843910749340033, 0.57802660073100254, -0.11560888340228653, -0.11436481217184656, -0.11250218782662975, -0.10971608089390825, -0.10555451667328646, -0.099351692934324679, -0.090136150925525155, -0.076511614727544461, -0.05651941299388475, -0.027522538371457093, 0.013772191900731716, 0.070864671671751547, 0.14593497036033168, 0.23591981919395397, 0.32391016867398448, 0.36336942185480259, 0.25890497941187346, -0.11572025100301592, -0.66076386903314166, 0.27159578788942196, 0.11925766326375063, 0.11701962699862156, 0.11367531238125347, 0.10868744714732725, 0.10126981942884175, 0.090287103769210786, 0.074134201646975206, 0.050620044131431986, 0.016933017097416861, -0.030116712154368355, -0.093138533517390085, -0.17160263551697441, -0.25618209006285081, -0.3183631162695052, -0.29707753517866498, -0.10102478727647804, 0.30185248746535442, 0.55289166632880227, -0.46108564710186972, 0.081962667419115426, -0.12626707822019206, -0.12250155553682644, -0.11689136915447108, -0.10856147160045609, -0.096257598068575617, -0.078227654788373013, -0.052128116579684983, -0.015063001240831148, 0.035988153544508683, 0.10280803884977513, 0.18263307034840112, 0.26144732880503613, 0.30325203347309243, 0.24116709207723347, -0.00082575540196283526, -0.37830141983168153, -0.42887161757203512, 0.55207091753656046, -0.17171017635275559, 0.016240179713238136, ]).reshape((20, 5, ), order="F")
35+
--END TEST CASE--
36+
"""
37+
R_poly_num_tests = 6

tools/get-R-poly-test-vectors.R

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
cat("# This file auto-generated by tools/get-R-bs-test-vectors.R\n")
2+
cat(sprintf("# Using: %s\n", R.Version()$version.string))
3+
cat("import numpy as np\n")
4+
5+
options(digits=20)
6+
library(splines)
7+
x <- (1.5)^(0:19)
8+
9+
MISSING <- "MISSING"
10+
11+
is.missing <- function(obj) {
12+
length(obj) == 1 && obj == MISSING
13+
}
14+
15+
pyprint <- function(arr) {
16+
if (is.missing(arr)) {
17+
cat("None\n")
18+
} else {
19+
cat("np.array([")
20+
for (val in arr) {
21+
cat(val)
22+
cat(", ")
23+
}
24+
cat("])")
25+
if (!is.null(dim(arr))) {
26+
cat(".reshape((")
27+
for (size in dim(arr)) {
28+
cat(sprintf("%s, ", size))
29+
}
30+
cat("), order=\"F\")")
31+
}
32+
cat("\n")
33+
}
34+
}
35+
36+
num.tests <- 0
37+
dump.poly <- function(degree, raw) {
38+
cat("--BEGIN TEST CASE--\n")
39+
cat(sprintf("degree=%s\n", degree))
40+
cat(sprintf("raw=%s\n", raw))
41+
42+
args <- list(x=x, degree=degree, raw=raw)
43+
44+
result <- do.call(poly, args)
45+
46+
cat("output=")
47+
pyprint(result)
48+
cat("--END TEST CASE--\n")
49+
assign("num.tests", num.tests + 1, envir=.GlobalEnv)
50+
}
51+
52+
cat("R_poly_test_x = ")
53+
pyprint(x)
54+
cat("R_poly_test_data = \"\"\"\n")
55+
56+
for (degree in c(1, 3, 5)) {
57+
for (raw in c(TRUE, FALSE)) {
58+
dump.poly(degree, raw)
59+
}
60+
}
61+
cat("\"\"\"\n")
62+
cat(sprintf("R_poly_num_tests = %s\n", num.tests))

0 commit comments

Comments
 (0)