
Commit eb48cda

Merge pull request #51 from Chen001117/dev
Add `with torch.no_grad()` while calculating the reward function
2 parents 0757ab7 + b67a31b commit eb48cda

1 file changed: +7, −6 lines


openrl/envs/nlp/rewards/kl_penalty.py

Lines changed: 7 additions & 6 deletions
@@ -66,12 +66,13 @@ def __call__(
             self._ref_net, input_ids, past_model_kwargs
         )
 
-        output = self._ref_net(output_hidden_states=True, **model_inputs)
-        output["past_key_values"] = None
-        next_token_logits = output.logits[:, -1, :]
-        dist = self._action_dist.proba_distribution(action_logits=next_token_logits)
-        action_input = actions.to(next_token_logits.device)
-        ref_log_prob = dist.log_prob(action_input)
+        with torch.no_grad():
+            output = self._ref_net(output_hidden_states=True, **model_inputs)
+            output["past_key_values"] = None
+            next_token_logits = output.logits[:, -1, :]
+            dist = self._action_dist.proba_distribution(action_logits=next_token_logits)
+            action_input = actions.to(next_token_logits.device)
+            ref_log_prob = dist.log_prob(action_input)
 
         ref_log_prob = ref_log_prob.reshape(action_log_probs.shape)
         kl_div = action_log_probs.copy() - ref_log_prob.detach().cpu().numpy()
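For context, `torch.no_grad()` disables gradient tracking inside its block, so the reference model's forward pass builds no autograd graph. That is the right behavior here: the KL penalty is a reward signal, a constant with respect to the policy's parameters, so retaining activations for backpropagation would only waste memory. Below is a minimal, self-contained sketch of the same pattern; the `ref_net`, `hidden`, and `actions` names are illustrative stand-ins, not the repository's API.

import torch

# Illustrative stand-in for a frozen reference model: hidden states -> vocab logits.
ref_net = torch.nn.Linear(16, 100)

hidden = torch.randn(4, 16)             # batch of last-layer hidden states
actions = torch.randint(0, 100, (4,))   # token ids sampled by the policy
action_log_probs = torch.randn(4)       # policy log-probs for those tokens

# Without no_grad, this forward pass would record an autograd graph for
# ref_net even though the reward never backpropagates through it.
with torch.no_grad():
    ref_logits = ref_net(hidden)
    ref_log_prob = torch.distributions.Categorical(logits=ref_logits).log_prob(actions)

# Per-token KL penalty term, matching the shape of the policy log-probs.
kl_div = action_log_probs.detach() - ref_log_prob

Calling `.detach()` on the result afterwards (as the original code did) only cuts the graph after it has been built; wrapping the reference-model block in `no_grad` avoids building it at all, which is the point of this commit.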
