File tree: 1 file changed, +7 -6 lines changed
1 file changed, +7 -6 lines changed
Original file line number | Diff line number | Diff line change
@@ -66,12 +66,13 @@ def __call__(
66 66 self ._ref_net , input_ids , past_model_kwargs
67 67 )
68 68
69- output = self ._ref_net (output_hidden_states = True , ** model_inputs )
70- output ["past_key_values" ] = None
71- next_token_logits = output .logits [:, - 1 , :]
72- dist = self ._action_dist .proba_distribution (action_logits = next_token_logits )
73- action_input = actions .to (next_token_logits .device )
74- ref_log_prob = dist .log_prob (action_input )
69+ with torch .no_grad ():
70+ output = self ._ref_net (output_hidden_states = True , ** model_inputs )
71+ output ["past_key_values" ] = None
72+ next_token_logits = output .logits [:, - 1 , :]
73+ dist = self ._action_dist .proba_distribution (action_logits = next_token_logits )
74+ action_input = actions .to (next_token_logits .device )
75+ ref_log_prob = dist .log_prob (action_input )
75 76
76 77 ref_log_prob = ref_log_prob .reshape (action_log_probs .shape )
77 78 kl_div = action_log_probs .copy () - ref_log_prob .detach ().cpu ().numpy ()
You can’t perform that action at this time.
0 commit comments