Skip to content

Commit 10b5fc1

Browse files
committed
add topic extract
1 parent 2567256 commit 10b5fc1

File tree

1 file changed

+43
-29
lines changed

1 file changed

+43
-29
lines changed

python/graphy/utils/data_extractor.py

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,6 @@ def extract_data(self, dimension_node_names=[]):
176176
if "reference" in paper_data:
177177
del paper_data["reference"]
178178
paper_id = paper_data["id"]
179-
if paper_id not in self.facts_dict:
180-
self.facts_dict[paper_id] = sanitize_data(paper_data)
181-
182179
if not dimension_node_names:
183180
dimension_node_names = self.persist_store.get_total_states(folder)
184181
try:
@@ -187,7 +184,11 @@ def extract_data(self, dimension_node_names=[]):
187184
except ValueError:
188185
pass # Do nothing if "Paper" is not in the list
189186
for node_name in dimension_node_names:
190-
self._extract_dimension_data(paper_id, folder, node_name)
187+
self._extract_dimension_data(
188+
paper_id, paper_data, folder, node_name
189+
)
190+
if paper_id not in self.facts_dict:
191+
self.facts_dict[paper_id] = sanitize_data(paper_data)
191192
if edge_data:
192193
for edge_name, edge_pairs in edge_data.items():
193194
formatted_edges = [
@@ -198,38 +199,51 @@ def extract_data(self, dimension_node_names=[]):
198199
self.edges_dict.setdefault(edge_name, []).extend(formatted_edges)
199200

200201
def _extract_dimension_data(
201-
self, paper_id: str, folder: str, node_name: str, edge_name: str = None
202+
self,
203+
paper_id: str,
204+
paper_data: dict,
205+
folder: str,
206+
node_name: str,
207+
edge_name: str = None,
202208
):
203-
if node_name not in self.dimensions_dict:
204-
dimensions_dict = self.dimensions_dict.setdefault(node_name, {})
205-
else:
206-
dimensions_dict = self.dimensions_dict[node_name]
207-
if not edge_name:
208-
edge_name = f"Paper_Has_{node_name}"
209-
if edge_name not in self.edges_dict:
210-
edges = self.edges_dict.setdefault(edge_name, [])
211-
else:
212-
edges = self.edges_dict[edge_name]
213-
214209
data_items = self.persist_store.get_state(folder, node_name)
215210

216211
if data_items:
217212
data_items = data_items.get("data", {})
218213
if not isinstance(data_items, list):
219-
data_items = [data_items]
220-
221-
for idx, item in enumerate(data_items):
222-
data_id = hash_id(f"{paper_id}_{node_name}_{idx}")
223-
224-
if data_id not in dimensions_dict:
225-
dimensions_dict[data_id] = {"id": data_id, "node_type": "Dimension"}
226-
dimensions_dict[data_id].update(sanitize_data(item))
227-
edges.append(
228-
{
229-
"source": paper_id,
230-
"target": data_id,
214+
if len(data_items) == 1: # If there's only one item, update directly
215+
paper_data.update({node_name: next(iter(data_items.values()))})
216+
elif isinstance(data_items, dict): # Otherwise, handle dictionary
217+
for key, val in data_items.items():
218+
paper_data.update({f"{node_name}_{key}": val})
219+
else:
220+
print(f"Error: Invalid data format for {node_name} in {folder}")
221+
else:
222+
if node_name not in self.dimensions_dict:
223+
dimensions_dict = self.dimensions_dict.setdefault(node_name, {})
224+
else:
225+
dimensions_dict = self.dimensions_dict[node_name]
226+
if not edge_name:
227+
edge_name = f"Paper_Has_{node_name}"
228+
if edge_name not in self.edges_dict:
229+
edges = self.edges_dict.setdefault(edge_name, [])
230+
else:
231+
edges = self.edges_dict[edge_name]
232+
for idx, item in enumerate(data_items):
233+
data_id = hash_id(f"{paper_id}_{node_name}_{idx}")
234+
235+
if data_id not in dimensions_dict:
236+
dimensions_dict[data_id] = {
237+
"id": data_id,
238+
"node_type": "Dimension",
231239
}
232-
)
240+
dimensions_dict[data_id].update(sanitize_data(item))
241+
edges.append(
242+
{
243+
"source": paper_id,
244+
"target": data_id,
245+
}
246+
)
233247

234248
def build_graph(self, output_path=None):
235249
if output_path:

0 commit comments

Comments
 (0)