1212
1313
1414# see file data/rectangles/output.stan
15- @pytest .fixture (scope = "module" )
16- def rect_data ():
15+ @pytest .fixture (scope = "module" , params = [ True , False ], ids = [ "use_object" , "use_dtype" ] )
16+ def rect_data (request ):
1717 files = [DATA / "rectangles" / f"output_{ i } .csv" for i in range (1 , 5 )]
1818 header , data = read_csv (files )
1919 params = parse_header (header )
20- yield stan_variables (params , data )
20+ yield stan_variables (params , data , object = request . param )
2121
2222
2323def test_basic_shapes (rect_data ):
@@ -91,43 +91,93 @@ def test_basic_values(rect_data):
9191
9292
9393# see file data/tuples/output.stan
94- @pytest .fixture (scope = "module" )
95- def tuple_data ():
94+ @pytest .fixture (scope = "module" , params = [ True , False ], ids = [ "use_object" , "use_dtype" ] )
95+ def tuple_data (request ):
9696 files = [DATA / "tuples" / f"output_{ i } .csv" for i in range (1 , 5 )]
9797 header , data = read_csv (files )
9898 params = parse_header (header )
99- yield stan_variables (params , data )
99+ yield stan_variables (params , data , object = request . param )
100100
101101
102102def test_tuple_shapes (tuple_data ):
103- assert isinstance (tuple_data ["pair" ][0 , 0 ], tuple )
104103 assert len (tuple_data ["pair" ][0 , 0 ]) == 2
105104
106- assert isinstance (tuple_data ["nested" ][0 , 0 ], tuple )
107105 assert len (tuple_data ["nested" ][0 , 0 ]) == 2
108- assert isinstance (tuple_data ["nested" ][0 , 0 ][1 ], tuple )
109106 assert len (tuple_data ["nested" ][0 , 0 ][1 ]) == 2
110107
111108 assert tuple_data ["arr_pair" ].shape == (4 , 1000 , 2 )
112- assert isinstance (tuple_data ["arr_pair" ][0 , 0 , 0 ], tuple )
113109
114110 assert tuple_data ["arr_very_nested" ].shape == (4 , 1000 , 3 )
111+
112+ assert tuple_data ["arr_2d_pair" ].shape == (4 , 1000 , 3 , 2 )
113+
114+ assert tuple_data ["ultimate" ].shape == (4 , 1000 , 2 , 3 )
115+ assert tuple_data ["ultimate" ][0 , 0 , 0 , 0 ][0 ].shape == (2 ,)
116+ assert tuple_data ["ultimate" ][0 , 0 , 0 , 0 ][0 ][0 ][1 ].shape == (2 ,)
117+ assert tuple_data ["ultimate" ][0 , 0 , 0 , 0 ][1 ].shape == (4 , 5 )
118+
119+
120+ def check_tuple_shapes_objects (tuple_data ):
121+ assert isinstance (tuple_data ["pair" ][0 , 0 ], tuple )
122+
123+ assert isinstance (tuple_data ["nested" ][0 , 0 ], tuple )
124+ assert isinstance (tuple_data ["nested" ][0 , 0 ][1 ], tuple )
125+
126+ assert isinstance (tuple_data ["arr_pair" ][0 , 0 , 0 ], tuple )
127+
115128 assert isinstance (tuple_data ["arr_very_nested" ][0 , 0 , 0 ], tuple )
116129 assert isinstance (tuple_data ["arr_very_nested" ][0 , 0 , 0 ][0 ], tuple )
117130 assert isinstance (tuple_data ["arr_very_nested" ][0 , 0 , 0 ][0 ][1 ], tuple )
118131
119- assert tuple_data ["arr_2d_pair" ].shape == (4 , 1000 , 3 , 2 )
120132 assert isinstance (tuple_data ["arr_2d_pair" ][0 , 0 , 0 , 0 ], tuple )
121133
122- assert tuple_data ["ultimate" ].shape == (4 , 1000 , 2 , 3 )
123134 assert isinstance (tuple_data ["ultimate" ][0 , 0 , 0 , 0 ], tuple )
124- assert tuple_data ["ultimate" ][0 , 0 , 0 , 0 ][0 ].shape == (2 ,)
125135 assert isinstance (tuple_data ["ultimate" ][0 , 0 , 0 , 0 ][0 ][0 ], tuple )
126- assert tuple_data ["ultimate" ][0 , 0 , 0 , 0 ][0 ][0 ][1 ].shape == (2 ,)
127- assert tuple_data ["ultimate" ][0 , 0 , 0 , 0 ][1 ].shape == (4 , 5 )
136+
137+
138+ def check_tuple_shapes_custom_dtypes (tuple_data ):
139+ for value in tuple_data .values ():
140+ assert not value .dtype .hasobject
141+
142+ pair_dtype = np .dtype ([("1" , "f8" ), ("2" , "f8" )])
143+ assert tuple_data ["pair" ].dtype == pair_dtype
144+
145+ nested_dtype = np .dtype ([("1" , "f8" ), ("2" , [("1" , "f8" ), ("2" , "c16" )])])
146+ assert tuple_data ["nested" ].dtype == nested_dtype
147+ assert tuple_data ["nested" ][0 , 0 ][1 ].dtype == nested_dtype [1 ]
148+
149+ assert tuple_data ["arr_pair" ].dtype == pair_dtype
150+
151+ very_nested_dtype = np .dtype (
152+ [
153+ ("1" , nested_dtype ),
154+ ("2" , "f8" ),
155+ ]
156+ )
157+ assert tuple_data ["arr_very_nested" ].dtype == very_nested_dtype
158+ assert tuple_data ["arr_very_nested" ][0 , 0 , 0 ][0 ].dtype == nested_dtype
159+ assert tuple_data ["arr_very_nested" ][0 , 0 , 0 ][0 ][1 ].dtype == nested_dtype [1 ]
160+
161+ ultimate_dtype = np .dtype (
162+ [
163+ ("1" , ([("1" , "f8" ), ("2" , "(2,)f8" )], (2 ,))),
164+ ("2" , "(4,5)f8" ),
165+ ]
166+ )
167+ assert tuple_data ["ultimate" ].dtype == ultimate_dtype
168+
169+
170+ def test_tuple_dtypes (tuple_data ):
171+ if isinstance (tuple_data ["pair" ][0 , 0 ], tuple ):
172+ check_tuple_shapes_objects (tuple_data )
173+ else :
174+ check_tuple_shapes_custom_dtypes (tuple_data )
128175
129176
130177def assert_tuple_equal (t1 , t2 ):
178+ if hasattr (t1 , "dtype" ) and t1 .dtype .kind == "V" :
179+ t1 = t1 .tolist ()
180+
131181 assert len (t1 ) == len (t2 )
132182 for x , y in zip (t1 , t2 ):
133183 if isinstance (x , tuple ):
@@ -140,7 +190,7 @@ def check_tuples(tuple_data, chain, draw):
140190 base = tuple_data ["base" ][chain , draw ]
141191 base_i = tuple_data ["base_i" ][chain , draw ]
142192 pair_exp = (base , 2 * base )
143- np . testing . assert_almost_equal (tuple_data ["pair" ][chain , draw ], pair_exp )
193+ assert_tuple_equal (tuple_data ["pair" ][chain , draw ], pair_exp )
144194 nested_exp = (base * 3 , (base_i , 4j * base ))
145195 assert_tuple_equal (tuple_data ["nested" ][chain , draw ], nested_exp )
146196
0 commit comments