2727import pytensor .tensor as pt
2828import xarray as xr
2929
30- from narwhals .typing import IntoFrameT , IntoLazyFrameT , IntoSeriesT
30+ from narwhals .typing import IntoFrameT , IntoSeriesT
3131from pytensor .compile import SharedVariable
3232from pytensor .compile .builders import OpFromGraph
3333from pytensor .graph .basic import Variable
@@ -174,11 +174,11 @@ def _handle_none_dims(
174174
175175@singledispatch
176176def determine_coords (
177- value ,
177+ value : typing . Any ,
178178 model : "Model" ,
179179 dims : Sequence [str | None ] | None = None ,
180180 coords : dict [str , Sequence | np .ndarray ] | None = None ,
181- ) -> tuple [dict [str , Sequence | np .ndarray ], Sequence [str | None ] | Sequence [None ]]:
181+ ) -> tuple [typing . Any , dict [str , Sequence | np .ndarray ], Sequence [str | None ] | Sequence [None ]]:
182182 """Determine coordinate values from data or the model (via ``dims``)."""
183183 raise NotImplementedError (
184184 f"Cannot determine coordinates for data of type { type (value )} , please provide `coords` explicitly or "
@@ -192,12 +192,12 @@ def determine_array_coords(
192192 model : "Model" ,
193193 dims : Sequence [str ] | None = None ,
194194 coords : dict [str , Sequence | np .ndarray ] | None = None ,
195- ) -> tuple [dict [str , Sequence | np .ndarray ], Sequence [str | None ] | Sequence [None ]]:
195+ ) -> tuple [np . ndarray , dict [str , Sequence | np .ndarray ], Sequence [str | None ] | Sequence [None ]]:
196196 if coords is None :
197197 coords = {}
198198
199199 if dims is None :
200- return coords , _handle_none_dims (dims , value .ndim )
200+ return value , coords , _handle_none_dims (dims , value .ndim )
201201
202202 if len (dims ) != value .ndim :
203203 raise ShapeError (
@@ -211,7 +211,7 @@ def determine_array_coords(
211211 if coord is None and dim is not None :
212212 coords [dim ] = range (size )
213213
214- return coords , _handle_none_dims (dims , value .ndim )
214+ return value , coords , _handle_none_dims (dims , value .ndim )
215215
216216
217217@determine_coords .register (xr .DataArray )
@@ -220,41 +220,41 @@ def determine_xarray_coords(
220220 model : "Model" ,
221221 dims : Sequence [str | None ] | None = None ,
222222 coords : dict [str , Sequence | np .ndarray ] | None = None ,
223- ) -> tuple [dict [str , Sequence | np .ndarray ], Sequence [str | None ] | Sequence [None ]]:
223+ ) -> tuple [xr . DataArray , dict [str , Sequence | np .ndarray ], Sequence [str | None ] | Sequence [None ]]:
224224 if coords is None :
225225 coords = {}
226226
227227 if dims is None :
228- return coords , _handle_none_dims (dims , value .ndim )
228+ return value , coords , _handle_none_dims (dims , value .ndim )
229229
230230 for dim in dims :
231231 dim_name = dim
232232 # str is applied because dim entries may be None
233233 coords [str (dim_name )] = cast (xr .DataArray , value [dim ]).to_numpy ()
234234
235- return coords , _handle_none_dims (dims , value .ndim )
235+ return value , coords , _handle_none_dims (dims , value .ndim )
236236
237237
238238def _dataframe_agnostic_coords (
239- value : IntoFrameT | IntoLazyFrameT | nw . DataFrame | nw . LazyFrame ,
239+ value : IntoFrameT ,
240240 model : "Model" ,
241241 ndim_in : int = 2 ,
242242 dims : Sequence [str | None ] | None = None ,
243243 coords : dict [str , Sequence | np .ndarray ] | None = None ,
244- ) -> tuple [dict [str , Sequence | np .ndarray ], Sequence [str | None ] | Sequence [None ]]:
244+ ) -> tuple [IntoFrameT , dict [str , Sequence | np .ndarray ], Sequence [str | None ] | Sequence [None ]]:
245245 if coords is None :
246246 coords = {}
247247
248248 value = cast (nw .DataFrame | nw .LazyFrame , nw .from_native (value , allow_series = False )) # type: ignore[type-var]
249249 if isinstance (value , nw .LazyFrame ):
250250 value = value .collect ()
251251
252- index = nw .maybe_get_index (value )
253- if index is not None :
254- value = value .with_columns (** {index .name : index .to_numpy ()})
255-
256252 if dims is None :
257- return coords , _handle_none_dims (dims , ndim_in )
253+ if ndim_in == 1 :
254+ value = value [value .columns [0 ]]
255+ return value .to_native (), coords , _handle_none_dims (dims , ndim_in )
256+
257+ index = nw .maybe_get_index (value )
258258
259259 if len (dims ) != ndim_in :
260260 raise ShapeError (
@@ -265,13 +265,13 @@ def _dataframe_agnostic_coords(
265265
266266 index_dim = dims [0 ]
267267 if index_dim is not None :
268- if index_dim in value . columns :
269- coords [index_dim ] = tuple (value . select ( nw . col ( index_dim )). to_numpy (). flatten () )
268+ if index is not None :
269+ coords [index_dim ] = tuple (index )
270270 elif index_dim in model .coords :
271271 coords [index_dim ] = model .coords [index_dim ] # type: ignore[assignment]
272272 else :
273273 raise ValueError (
274- f"Dimension '{ index_dim } ' not found in DataFrame columns or model coordinates. Cannot infer "
274+ f"Dimension '{ index_dim } ' not found in DataFrame index or model coordinates. Cannot infer "
275275 "index coordinates."
276276 )
277277
@@ -281,7 +281,10 @@ def _dataframe_agnostic_coords(
281281 select_expr = nw .exclude (index_dim ) if index_dim is not None else nw .all ()
282282 coords [column_dim ] = value .select (select_expr ).columns
283283
284- return coords , _handle_none_dims (dims , ndim_in )
284+ if ndim_in == 1 :
285+ value = value [value .columns [0 ]]
286+
287+ return value .to_native (), coords , _handle_none_dims (dims , ndim_in )
285288
286289
287290def _series_agnostic_coords (
@@ -319,7 +322,9 @@ def determine_dataframe_coords(
319322 model : "Model" ,
320323 dims : Sequence [str ] | None = None ,
321324 coords : dict [str , Sequence | np .ndarray ] | None = None ,
322- ) -> tuple [dict [str , Sequence | np .ndarray ], Sequence [str | None ] | Sequence [None ]]:
325+ ) -> tuple [
326+ IntoFrameT , dict [str , Sequence | np .ndarray ], Sequence [str | None ] | Sequence [None ]
327+ ]:
323328 return _dataframe_agnostic_coords (value , model = model , dims = dims , coords = coords )
324329
325330 except ImportError :
@@ -453,7 +458,7 @@ def Data(
453458
454459 new_dims : Sequence [str | None ] | Sequence [None ] | None
455460 if infer_dims_and_coords :
456- coords , new_dims = determine_coords (value , model , dims )
461+ value , coords , new_dims = determine_coords (value , model , dims )
457462 else :
458463 new_dims = dims
459464
0 commit comments