Skip to content

Commit efa3cd6

Browse files
committed
Removing other polytypes, standardize procedure.
1 parent 8fa9eef commit efa3cd6

File tree

2 files changed

+15
-71
lines changed

2 files changed

+15
-71
lines changed

patsy/contrasts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def _code_either(self, intercept, levels):
264264
# quadratic, etc., functions of the raw scores, and then use 'qr' to
265265
# orthogonalize each column against those to its left.
266266
scores -= scores.mean()
267-
raw_poly = Polynomial.vander(scores, n - 1, 'poly')
267+
raw_poly = Polynomial.vander(scores, n - 1)
268268
alpha, norm, beta = Polynomial.gen_qr(raw_poly, n - 1)
269269
q = Polynomial.apply_qr(raw_poly, n - 1, alpha, norm, beta)
270270
q[:, 0] = 1

patsy/polynomials.py

Lines changed: 14 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,19 @@
55
# R-compatible poly function
66

77
# These are made available in the patsy.* namespace
8-
__all__ = ["poly"]
9-
108
import numpy as np
119

1210
from patsy.util import have_pandas, no_pickling, assert_no_pickling
1311
from patsy.state import stateful_transform
1412

13+
__all__ = ["poly"]
14+
1515
if have_pandas:
1616
import pandas
1717

18+
1819
class Poly(object):
19-
"""poly(x, degree=3, polytype='poly', raw=False, scaler=None)
20+
"""poly(x, degree=3, raw=False)
2021
2122
Generates an orthogonal polynomial transformation of x of degree.
2223
Generic usage is something along the lines of::
@@ -26,29 +27,17 @@ class Poly(object):
2627
to fit ``y`` as a function of ``x``, with a 4th degree polynomial.
2728
2829
:arg degree: The number of degrees for the polynomial expansion.
29-
:arg polytype: Either poly (the default), legendre, laguerre, hermite, or
30-
hermanite_e.
3130
:arg raw: When raw is False (the default), will return orthogonal
3231
polynomials.
33-
:arg scaler: Choice of 'qr' (default when raw=False) for QR-
34-
decomposition or 'standardize'.
3532
3633
.. versionadded:: 0.4.1
3734
"""
3835
def __init__(self):
3936
self._tmp = {}
4037

41-
def memorize_chunk(self, x, degree=3, polytype='poly', raw=False,
42-
scaler=None):
43-
if not raw and (scaler is None):
44-
scaler = 'qr'
45-
if scaler not in ('qr', 'standardize', None):
46-
raise ValueError('input to \'scaler\' %s is not a valid '
47-
'scaling technique' % scaler)
38+
def memorize_chunk(self, x, degree=3, raw=False):
4839
args = {"degree": degree,
49-
"raw": raw,
50-
"scaler": scaler,
51-
'polytype': polytype
40+
"raw": raw
5241
}
5342
self._tmp["args"] = args
5443
# XX: check whether we need x values before saving them
@@ -80,19 +69,12 @@ def memorize_finish(self):
8069

8170
n = args['degree']
8271
self.degree = n
83-
self.scaler = args['scaler']
8472
self.raw = args['raw']
85-
self.polytype = args['polytype']
86-
87-
if self.scaler is not None:
88-
raw_poly = self.vander(scores, n, self.polytype)
8973

90-
if self.scaler == 'qr':
74+
if not self.raw:
75+
raw_poly = self.vander(scores, n)
9176
self.alpha, self.norm, self.beta = self.gen_qr(raw_poly, n)
9277

93-
if self.scaler == 'standardize':
94-
self.mean, self.var = self.gen_standardize(raw_poly)
95-
9678
def transform(self, x, degree=3, polytype='poly', raw=False, scaler=None):
9779
if have_pandas:
9880
if isinstance(x, (pandas.Series, pandas.DataFrame)):
@@ -120,23 +102,17 @@ def transform(self, x, degree=3, polytype='poly', raw=False, scaler=None):
120102
return p
121103

122104
@staticmethod
123-
def vander(x, n, polytype):
124-
v_func = {'poly': np.polynomial.polynomial.polyvander,
125-
'cheb': np.polynomial.chebyshev.chebvander,
126-
'legendre': np.polynomial.legendre.legvander,
127-
'laguerre': np.polynomial.laguerre.lagvander,
128-
'hermite': np.polynomial.hermite.hermvander,
129-
'hermite_e': np.polynomial.hermite_e.hermevander}
130-
raw_poly = v_func[polytype](x, n)
105+
def vander(x, n):
106+
raw_poly = np.polynomial.polynomial.polyvander(x, n)
131107
return raw_poly
132108

133109
@staticmethod
134110
def gen_qr(raw_poly, n):
111+
x = raw_poly[:, 1]
112+
q, r = np.linalg.qr(raw_poly)
135113
# Q is now orthognoal of degree n. To match what R is doing, we
136114
# need to use the three-term recurrence technique to calculate
137115
# new alpha, beta, and norm.
138-
x = raw_poly[:, 1]
139-
q, r = np.linalg.qr(raw_poly)
140116
alpha = (np.sum(x.reshape((-1, 1)) * q[:, :n] ** 2, axis=0) /
141117
np.sum(q[:, :n] ** 2, axis=0))
142118

@@ -147,10 +123,6 @@ def gen_qr(raw_poly, n):
147123
beta = (norm[1:] / norm[:n]) ** 2
148124
return alpha, norm, beta
149125

150-
@staticmethod
151-
def gen_standardize(raw_poly):
152-
return raw_poly.mean(axis=0), raw_poly.var(axis=0)
153-
154126
@staticmethod
155127
def apply_qr(x, n, alpha, norm, beta):
156128
# This is where the three-term recurrence is unwound for the QR
@@ -166,13 +138,6 @@ def apply_qr(x, n, alpha, norm, beta):
166138
p[:, i + 1] = (p[:, i + 1] - (beta[i - 1] * p[:, i - 1]))
167139
p /= norm
168140
return p
169-
170-
@staticmethod
171-
def apply_standardize(x, mean, var):
172-
x[:, 1:] = ((x[:, 1:] - mean[1:]) / (var[1:] ** 0.5))
173-
return x
174-
175-
176141
__getstate__ = no_pickling
177142

178143
poly = stateful_transform(Poly)
@@ -215,24 +180,6 @@ def test_poly_compat():
215180
start_idx = stop_idx + 1
216181
assert tests_ran == R_poly_num_tests
217182

218-
def test_poly_smoke():
219-
# Test that standardized values match.
220-
x = np.arange(27)
221-
vanders = ['poly', 'cheb', 'legendre', 'laguerre', 'hermite', 'hermite_e']
222-
scalers = ['raw', 'qr', 'standardize']
223-
for v in vanders:
224-
p1 = poly(x, polytype=v, scaler='standardize')
225-
p2 = poly(x, polytype=v, raw=True)
226-
p2 = (p2 - p2.mean(axis=0)) / p2.std(axis=0)
227-
np.testing.assert_allclose(p1, p2)
228-
229-
# Don't have tests for all this... so just make sure it works.
230-
for v in vanders:
231-
for s in scalers:
232-
if s == 'raw':
233-
poly(x, raw=True, polytype=v)
234-
else:
235-
poly(x, scaler=s, polytype=v)
236183

237184
def test_poly_errors():
238185
from nose.tools import assert_raises
@@ -245,8 +192,5 @@ def test_poly_errors():
245192
assert_raises(ValueError, poly, x, degree=0)
246193
assert_raises(ValueError, poly, x, degree=3.5)
247194

248-
#Invalid Poly Type
249-
assert_raises(KeyError, poly, x, polytype='foo')
250-
251-
#Invalid scaling type
252-
assert_raises(ValueError, poly, x, scaler='bar')
195+
p = poly(np.arange(1, 10), degree=3)
196+
assert_no_pickling(p)

0 commit comments

Comments
 (0)