Skip to content

Commit b70dcdf

Browse files
committed
FIX: utilize dlapi to simplify codes
1 parent 2952493 commit b70dcdf

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

recbole/data/dataloader/abstract_dataloader.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,6 @@ def __init__(self, config, dataset, batch_size=1, dl_format=InputType.POINTWISE,
5555
if self.real_time is None:
5656
self.real_time = True
5757

58-
self.join = self.dataset.join
59-
self.history_item_matrix = self.dataset.history_item_matrix
60-
self.history_user_matrix = self.dataset.history_user_matrix
61-
self.inter_matrix = self.dataset.inter_matrix
62-
self.get_user_feature = self.dataset.get_user_feature
63-
self.get_item_feature = self.dataset.get_item_feature
64-
6558
for dataset_attr in self.dataset._dataloader_apis:
6659
try:
6760
flag = hasattr(self.dataset, dataset_attr)

recbole/data/dataset/dataset.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,6 +1173,7 @@ def _check_field(self, *field_names):
11731173
if getattr(self, field_name, None) is None:
11741174
raise ValueError(f'{field_name} isn\'t set.')
11751175

1176+
@dlapi.set()
11761177
def join(self, df):
11771178
"""Given interaction feature, join user/item feature into it.
11781179
@@ -1429,6 +1430,7 @@ def save(self, filepath):
14291430
if df is not None:
14301431
df.to_csv(os.path.join(filepath, f'{name}.csv'))
14311432

1433+
@dlapi.set()
14321434
def get_user_feature(self):
14331435
"""
14341436
Returns:
@@ -1440,6 +1442,7 @@ def get_user_feature(self):
14401442
else:
14411443
return self.user_feat
14421444

1445+
@dlapi.set()
14431446
def get_item_feature(self):
14441447
"""
14451448
Returns:
@@ -1536,6 +1539,7 @@ def _create_graph(self, tensor_feat, source_field, target_field, form='dgl', val
15361539
else:
15371540
raise NotImplementedError(f'Graph format [{form}] has not been implemented.')
15381541

1542+
@dlapi.set()
15391543
def inter_matrix(self, form='coo', value_field=None):
15401544
"""Get sparse matrix that describe interactions between user_id and item_id.
15411545
@@ -1617,6 +1621,7 @@ def _history_matrix(self, row, value_field=None):
16171621

16181622
return torch.LongTensor(history_matrix), torch.FloatTensor(history_value), torch.LongTensor(history_len)
16191623

1624+
@dlapi.set()
16201625
def history_item_matrix(self, value_field=None):
16211626
"""Get dense matrix describe user's history interaction records.
16221627
@@ -1641,6 +1646,7 @@ def history_item_matrix(self, value_field=None):
16411646
"""
16421647
return self._history_matrix(row='user', value_field=value_field)
16431648

1649+
@dlapi.set()
16441650
def history_user_matrix(self, value_field=None):
16451651
"""Get dense matrix describe item's history interaction records.
16461652

0 commit comments

Comments
 (0)