Skip to content

Commit d40c361

Browse files
committed
fix(gfql): preserve edge filters in cudf same-path
1 parent 277e3f5 commit d40c361

File tree

1 file changed

+25
-133
lines changed

1 file changed

+25
-133
lines changed

graphistry/compute/gfql/cudf_executor.py

Lines changed: 25 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -555,12 +555,19 @@ def _materialize_filtered(self, path_state: "_PathState") -> Plottable:
555555
"""Build result graph from allowed node/edge ids and refresh alias frames."""
556556

557557
nodes_df = self.inputs.graph._nodes
558-
edges_df = self.inputs.graph._edges
559558
node_id = self._node_column
560559
edge_id = self._edge_column
561560
src = self._source_column
562561
dst = self._destination_column
563562

563+
edge_frames = [
564+
self.forward_steps[idx]._edges
565+
for idx, op in enumerate(self.inputs.chain)
566+
if isinstance(op, ASTEdge) and self.forward_steps[idx]._edges is not None
567+
]
568+
concatenated_edges = self._concat_frames(edge_frames)
569+
edges_df = concatenated_edges if concatenated_edges is not None else self.inputs.graph._edges
570+
564571
if nodes_df is None or edges_df is None or node_id is None or src is None or dst is None:
565572
raise ValueError("Graph bindings are incomplete for same-path execution")
566573

@@ -603,6 +610,23 @@ def _alias_for_step(self, step_index: int) -> Optional[str]:
603610
return alias
604611
return None
605612

613+
@staticmethod
614+
def _concat_frames(frames: Sequence[DataFrameT]) -> Optional[DataFrameT]:
615+
"""Concatenate a sequence of pandas or cuDF frames, preserving type."""
616+
617+
if not frames:
618+
return None
619+
first = frames[0]
620+
try:
621+
if first.__class__.__module__.startswith("cudf"):
622+
import cudf # type: ignore
623+
624+
return cudf.concat(frames, ignore_index=True)
625+
except Exception:
626+
# Fall back to pandas concat when cuDF is unavailable or mismatched
627+
pass
628+
return pd.concat(frames, ignore_index=True)
629+
606630

607631
def _apply_ready_clauses(self) -> None:
608632
if not self.inputs.where:
@@ -805,135 +829,3 @@ def _validate_where_aliases(
805829
raise ValueError(
806830
f"WHERE references aliases with no node/edge bindings: {missing_str}"
807831
)
808-
809-
# --- GPU helpers ---------------------------------------------------------------
810-
811-
def _compute_allowed_tags(self) -> Dict[str, Set[Any]]:
812-
"""Seed allowed ids from alias frames (post-forward pruning)."""
813-
814-
out: Dict[str, Set[Any]] = {}
815-
for alias, binding in self.inputs.alias_bindings.items():
816-
frame = self.alias_frames.get(alias)
817-
if frame is None:
818-
continue
819-
id_col = self._node_column if binding.kind == "node" else self._edge_column
820-
if id_col is None or id_col not in frame.columns:
821-
continue
822-
out[alias] = self._series_values(frame[id_col])
823-
return out
824-
825-
@dataclass
826-
class _PathState:
827-
allowed_nodes: Dict[int, Set[Any]]
828-
allowed_edges: Dict[int, Set[Any]]
829-
830-
def _backward_prune(self, allowed_tags: Dict[str, Set[Any]]) -> "_PathState":
831-
"""Propagate allowed ids backward across edges to enforce path coherence."""
832-
833-
node_indices: List[int] = []
834-
edge_indices: List[int] = []
835-
for idx, op in enumerate(self.inputs.chain):
836-
if isinstance(op, ASTNode):
837-
node_indices.append(idx)
838-
elif isinstance(op, ASTEdge):
839-
edge_indices.append(idx)
840-
if not node_indices:
841-
raise ValueError("Same-path executor requires at least one node step")
842-
if len(node_indices) != len(edge_indices) + 1:
843-
raise ValueError("Chain must alternate node/edge steps for same-path execution")
844-
845-
allowed_nodes: Dict[int, Set[Any]] = {}
846-
allowed_edges: Dict[int, Set[Any]] = {}
847-
848-
# Seed node allowances from tags or full frames
849-
for idx in node_indices:
850-
node_alias = self._alias_for_step(idx)
851-
frame = self.forward_steps[idx]._nodes
852-
if frame is None or self._node_column is None:
853-
continue
854-
if node_alias and node_alias in allowed_tags:
855-
allowed_nodes[idx] = set(allowed_tags[node_alias])
856-
else:
857-
allowed_nodes[idx] = self._series_values(frame[self._node_column])
858-
859-
# Walk edges backward
860-
for edge_idx, right_node_idx in reversed(list(zip(edge_indices, node_indices[1:]))):
861-
edge_alias = self._alias_for_step(edge_idx)
862-
left_node_idx = node_indices[node_indices.index(right_node_idx) - 1]
863-
edges_df = self.forward_steps[edge_idx]._edges
864-
if edges_df is None:
865-
continue
866-
867-
# Filter by destination
868-
filtered = edges_df
869-
if self._destination_column and self._destination_column in filtered.columns:
870-
allowed_dst = allowed_nodes.get(right_node_idx)
871-
if allowed_dst is not None:
872-
filtered = filtered[
873-
filtered[self._destination_column].isin(list(allowed_dst))
874-
]
875-
876-
# Filter by edge tags if supplied
877-
if edge_alias and edge_alias in allowed_tags:
878-
allowed_edge_ids = allowed_tags[edge_alias]
879-
if self._edge_column and self._edge_column in filtered.columns:
880-
filtered = filtered[
881-
filtered[self._edge_column].isin(list(allowed_edge_ids))
882-
]
883-
884-
# Capture allowed edges
885-
if self._edge_column and self._edge_column in filtered.columns:
886-
allowed_edges[edge_idx] = self._series_values(filtered[self._edge_column])
887-
888-
# Propagate allowed sources
889-
if self._source_column and self._source_column in filtered.columns:
890-
allowed_src = self._series_values(filtered[self._source_column])
891-
current = allowed_nodes.get(left_node_idx, set())
892-
allowed_nodes[left_node_idx] = current & allowed_src if current else allowed_src
893-
894-
return self._PathState(allowed_nodes=allowed_nodes, allowed_edges=allowed_edges)
895-
896-
def _materialize_filtered(self, path_state: "_PathState") -> Plottable:
897-
"""Build result graph from allowed node/edge ids and refresh alias frames."""
898-
899-
nodes_df = self.inputs.graph._nodes
900-
edges_df = self.inputs.graph._edges
901-
node_id = self._node_column
902-
edge_id = self._edge_column
903-
src = self._source_column
904-
dst = self._destination_column
905-
906-
if nodes_df is None or edges_df is None or node_id is None or src is None or dst is None:
907-
raise ValueError("Graph bindings are incomplete for same-path execution")
908-
909-
allowed_node_ids: Set[Any] = set().union(*path_state.allowed_nodes.values()) if path_state.allowed_nodes else set()
910-
allowed_edge_ids: Set[Any] = set().union(*path_state.allowed_edges.values()) if path_state.allowed_edges else set()
911-
912-
filtered_nodes = nodes_df[nodes_df[node_id].isin(list(allowed_node_ids))] if allowed_node_ids else nodes_df.iloc[0:0]
913-
filtered_edges = edges_df
914-
filtered_edges = filtered_edges[
915-
filtered_edges[dst].isin(list(allowed_node_ids))
916-
] if allowed_node_ids else filtered_edges.iloc[0:0]
917-
if allowed_edge_ids and edge_id in filtered_edges.columns:
918-
filtered_edges = filtered_edges[filtered_edges[edge_id].isin(list(allowed_edge_ids))]
919-
920-
# Refresh alias frames based on filtered data
921-
for alias, binding in self.inputs.alias_bindings.items():
922-
frame = (
923-
filtered_nodes if binding.kind == "node" else filtered_edges
924-
)
925-
id_col = self._node_column if binding.kind == "node" else self._edge_column
926-
if id_col is None or id_col not in frame.columns:
927-
continue
928-
required = set(self.inputs.column_requirements.get(alias, set()))
929-
required.add(id_col)
930-
subset = frame[[c for c in frame.columns if c in required]].copy()
931-
self.alias_frames[alias] = subset
932-
933-
return self._materialize_from_oracle(filtered_nodes, filtered_edges)
934-
935-
def _alias_for_step(self, step_index: int) -> Optional[str]:
936-
for alias, binding in self.inputs.alias_bindings.items():
937-
if binding.step_index == step_index:
938-
return alias
939-
return None

0 commit comments

Comments
 (0)