Skip to content

Commit 42bf06d

Browse files
authored
Merge pull request #666 from hyp1231/master
FIX: code format
2 parents 9591235 + b70dcdf commit 42bf06d

File tree

9 files changed

+17
-22
lines changed

9 files changed

+17
-22
lines changed

recbole/data/dataloader/abstract_dataloader.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,6 @@ def __init__(self, config, dataset, batch_size=1, dl_format=InputType.POINTWISE,
5555
if self.real_time is None:
5656
self.real_time = True
5757

58-
self.join = self.dataset.join
59-
self.history_item_matrix = self.dataset.history_item_matrix
60-
self.history_user_matrix = self.dataset.history_user_matrix
61-
self.inter_matrix = self.dataset.inter_matrix
62-
self.get_user_feature = self.dataset.get_user_feature
63-
self.get_item_feature = self.dataset.get_item_feature
64-
6558
for dataset_attr in self.dataset._dataloader_apis:
6659
try:
6760
flag = hasattr(self.dataset, dataset_attr)

recbole/data/dataset/dataset.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,6 +1173,7 @@ def _check_field(self, *field_names):
11731173
if getattr(self, field_name, None) is None:
11741174
raise ValueError(f'{field_name} isn\'t set.')
11751175

1176+
@dlapi.set()
11761177
def join(self, df):
11771178
"""Given interaction feature, join user/item feature into it.
11781179
@@ -1429,6 +1430,7 @@ def save(self, filepath):
14291430
if df is not None:
14301431
df.to_csv(os.path.join(filepath, f'{name}.csv'))
14311432

1433+
@dlapi.set()
14321434
def get_user_feature(self):
14331435
"""
14341436
Returns:
@@ -1440,6 +1442,7 @@ def get_user_feature(self):
14401442
else:
14411443
return self.user_feat
14421444

1445+
@dlapi.set()
14431446
def get_item_feature(self):
14441447
"""
14451448
Returns:
@@ -1536,6 +1539,7 @@ def _create_graph(self, tensor_feat, source_field, target_field, form='dgl', val
15361539
else:
15371540
raise NotImplementedError(f'Graph format [{form}] has not been implemented.')
15381541

1542+
@dlapi.set()
15391543
def inter_matrix(self, form='coo', value_field=None):
15401544
"""Get sparse matrix that describe interactions between user_id and item_id.
15411545
@@ -1617,6 +1621,7 @@ def _history_matrix(self, row, value_field=None):
16171621

16181622
return torch.LongTensor(history_matrix), torch.FloatTensor(history_value), torch.LongTensor(history_len)
16191623

1624+
@dlapi.set()
16201625
def history_item_matrix(self, value_field=None):
16211626
"""Get dense matrix describe user's history interaction records.
16221627
@@ -1641,6 +1646,7 @@ def history_item_matrix(self, value_field=None):
16411646
"""
16421647
return self._history_matrix(row='user', value_field=value_field)
16431648

1649+
@dlapi.set()
16441650
def history_user_matrix(self, value_field=None):
16451651
"""Get dense matrix describe item's history interaction records.
16461652

recbole/data/utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from recbole.config import EvalSetting
2020
from recbole.data.dataloader import *
2121
from recbole.sampler import KGSampler, Sampler, RepeatableSampler
22-
from recbole.utils import ModelType
22+
from recbole.utils import ModelType, ensure_dir
2323

2424

2525
def create_dataset(config):
@@ -216,8 +216,7 @@ def save_datasets(save_path, name, dataset):
216216
raise ValueError(f'Length of name {name} should equal to length of dataset {dataset}.')
217217
for i, d in enumerate(dataset):
218218
cur_path = os.path.join(save_path, name[i])
219-
if not os.path.isdir(cur_path):
220-
os.makedirs(cur_path)
219+
ensure_dir(cur_path)
221220
d.save(cur_path)
222221

223222

recbole/model/abstract_recommender.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ def __init__(self, config, dataset):
9191
self.n_items = dataset.num(self.ITEM_ID)
9292

9393
# load parameters info
94-
self.batch_size = config['train_batch_size']
9594
self.device = config['device']
9695

9796

@@ -145,7 +144,6 @@ def __init__(self, config, dataset):
145144
self.n_relations = dataset.num(self.RELATION_ID)
146145

147146
# load parameters info
148-
self.batch_size = config['train_batch_size']
149147
self.device = config['device']
150148

151149

recbole/model/general_recommender/dgcf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(self, config, dataset):
7474
self.n_layers = config['n_layers']
7575
self.reg_weight = config['reg_weight']
7676
self.cor_weight = config['cor_weight']
77-
n_batch = dataset.dataset.inter_num // self.batch_size + 1
77+
n_batch = dataset.dataset.inter_num // config['train_batch_size'] + 1
7878
self.cor_batch_size = int(max(self.n_users / n_batch, self.n_items / n_batch))
7979
# ensure embedding can be divided into <n_factors> intent
8080
assert self.embedding_size % self.n_factors == 0

recbole/model/sequential_recommender/gru4rec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def _init_weights(self, module):
7070
if isinstance(module, nn.Embedding):
7171
xavier_normal_(module.weight)
7272
elif isinstance(module, nn.GRU):
73-
xavier_uniform_(self.gru_layers.weight_hh_l0)
74-
xavier_uniform_(self.gru_layers.weight_ih_l0)
73+
xavier_uniform_(module.weight_hh_l0)
74+
xavier_uniform_(module.weight_ih_l0)
7575

7676
def forward(self, item_seq, item_seq_len):
7777
item_seq_emb = self.item_embedding(item_seq)

recbole/model/sequential_recommender/ksr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def _init_weights(self, module):
8484
if isinstance(module, nn.Embedding):
8585
xavier_normal_(module.weight)
8686
elif isinstance(module, nn.GRU):
87-
xavier_uniform_(self.gru_layers.weight_hh_l0)
88-
xavier_uniform_(self.gru_layers.weight_ih_l0)
87+
xavier_uniform_(module.weight_hh_l0)
88+
xavier_uniform_(module.weight_ih_l0)
8989

9090
def _get_kg_embedding(self, head):
9191
"""Difference:

recbole/trainer/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from logging import getLogger
1717
from time import time
1818

19-
import matplotlib.pyplot as plt
2019
import numpy as np
2120
import torch
2221
import torch.optim as optim
@@ -64,7 +63,7 @@ class Trainer(AbstractTrainer):
6463
6564
Initializing the Trainer needs two parameters: `config` and `model`. `config` records the parameters information
6665
for controlling training and evaluation, such as `learning_rate`, `epochs`, `eval_step` and so on.
67-
More information can be found in [placeholder]. `model` is the instantiated object of a Model Class.
66+
`model` is the instantiated object of a Model Class.
6867
6968
"""
7069

@@ -422,6 +421,7 @@ def plot_train_loss(self, show=True, save_path=None):
422421
save_path (str, optional): The data path to save the figure, default: None.
423422
If it's None, it will not be saved.
424423
"""
424+
import matplotlib.pyplot as plt
425425
epochs = list(self.train_loss_dict.keys())
426426
epochs.sort()
427427
values = [float(self.train_loss_dict[epoch]) for epoch in epochs]

recbole/utils/logger.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import logging
1212
import os
1313

14-
from recbole.utils.utils import get_local_time
14+
from recbole.utils.utils import get_local_time, ensure_dir
1515

1616

1717
def init_logger(config):
@@ -30,8 +30,7 @@ def init_logger(config):
3030
"""
3131
LOGROOT = './log/'
3232
dir_name = os.path.dirname(LOGROOT)
33-
if not os.path.exists(dir_name):
34-
os.makedirs(dir_name)
33+
ensure_dir(dir_name)
3534

3635
logfilename = '{}-{}.log'.format(config['model'], get_local_time())
3736

0 commit comments

Comments
 (0)