
Commit 95502af

offer a naive unoptimized attention w/ stop graddable queries, keys, values, and fix the self reasoning transformer to only stop grad keys and values for the reasoning tokens
1 parent 3c6a5e7 commit 95502af

4 files changed, +160 -15 lines changed
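For orientation, a sketch of the gradients the new attention file computes by hand (notation assumed here, per attention head; not part of the commit): with $S = QK^\top / \sqrt{d}$, $P = \operatorname{softmax}(S)$ taken row-wise, and $O = PV$, the backward pass is

$$D_i = \sum_d dO_{id}\, O_{id}, \qquad dV = P^\top dO, \qquad dP = dO\, V^\top, \qquad dS = P \odot (dP - D),$$

$$dQ = dS\, K / \sqrt{d}, \qquad dK = dS^\top Q / \sqrt{d}.$$

The stop-grad masks zero entries of $P$ before forming $dV$, and entries of $dS$ before forming $dQ$ or $dK$, so no gradient reaches the queries, keys, or values through the masked attention entries.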
Lines changed: 5 additions & 0 deletions
@@ -1,3 +1,8 @@
 from self_reasoning_tokens_pytorch.self_reasoning_tokens import (
     Transformer
 )
+
+from self_reasoning_tokens_pytorch.attention_with_stop_graddable_qkv import (
+    stop_graddable_attn_,
+    stop_graddable_attn
+)
Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
import torch
from torch.autograd.function import Function

from einops import einsum, rearrange

def exists(val):
    return val is not None

# custom function

class StopGraddableAttentionFunction(Function):

    @staticmethod
    @torch.no_grad()
    def forward(
        ctx,
        q,
        k,
        v,
        mask,
        attn_mask,
        causal: bool,
        q_stop_grad_mask,
        k_stop_grad_mask,
        v_stop_grad_mask,
    ):
        scale = q.shape[-1] ** -0.5

        sim = einsum(q, k, 'b h i d, b h j d -> b h i j') * scale

        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim.masked_fill_(~mask, max_neg_value)

        if exists(attn_mask):
            sim.masked_fill_(~attn_mask, max_neg_value)

        if causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, max_neg_value)

        attn = sim.softmax(dim = -1)

        out = einsum(attn, v, 'b h i j, b h j d -> b h i d')

        ctx.args = (
            causal,
            scale,
            mask,
            q_stop_grad_mask,
            k_stop_grad_mask,
            v_stop_grad_mask
        )

        ctx.save_for_backward(
            q, k, v,
            attn,
            out
        )

        return out

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):

        (
            causal,
            scale,
            mask,
            q_stop_grad_mask,
            k_stop_grad_mask,
            v_stop_grad_mask
        ) = ctx.args

        q, k, v, p, o = ctx.saved_tensors

        # softmax D

        D = (do * o).sum(dim = -1, keepdims = True)

        # stop grad for values

        p_v = p

        if exists(v_stop_grad_mask):
            p_v = p_v.masked_fill(v_stop_grad_mask, 0.)

        # dv

        dv = einsum(p_v, do, 'b h i j, b h i d -> b h j d')

        # prep for dq and dk

        dp = einsum(do, v, 'b h i d, b h j d -> b h i j')
        ds = p * scale * (dp - D)

        # handle stop grad masking for queries and keys

        ds_q = ds_k = ds

        if exists(q_stop_grad_mask):
            ds_q = ds_q.masked_fill(q_stop_grad_mask, 0.)

        if exists(k_stop_grad_mask):
            ds_k = ds_k.masked_fill(k_stop_grad_mask, 0.)

        # dq and dk

        dq = einsum(ds_q, k, 'b h i j, b h j d -> b h i d')
        dk = einsum(ds_k, q, 'b h i j, b h i d -> b h j d')

        return dq, dk, dv, None, None, None, None, None, None

# convenience method with defaults

stop_graddable_attn_ = StopGraddableAttentionFunction.apply

def stop_graddable_attn(
    q, k, v,
    mask = None,
    attn_mask = None,
    causal = False,
    q_stop_grad_mask = None,
    k_stop_grad_mask = None,
    v_stop_grad_mask = None
):
    return stop_graddable_attn_(q, k, v, mask, attn_mask, causal, q_stop_grad_mask, k_stop_grad_mask, v_stop_grad_mask)
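
As a sanity check of the new function, here is a minimal usage sketch (not part of the commit; the tensor sizes and the mask, which treats the last four positions as reasoning tokens, are assumptions for illustration). Keys and values at masked attention entries should come back with exactly zero gradient:

import torch
from self_reasoning_tokens_pytorch.attention_with_stop_graddable_qkv import stop_graddable_attn

b, h, n, d = 1, 2, 8, 16

q = torch.randn(b, h, n, d, requires_grad = True)
k = torch.randn(b, h, n, d, requires_grad = True)
v = torch.randn(b, h, n, d, requires_grad = True)

# hypothetical mask over attention entries: block gradients into the
# keys / values of the last 4 positions (the "reasoning" tokens)
stop_grad_mask = torch.zeros(b, h, n, n, dtype = torch.bool)
stop_grad_mask[..., 4:] = True

out = stop_graddable_attn(
    q, k, v,
    causal = True,
    k_stop_grad_mask = stop_grad_mask,
    v_stop_grad_mask = stop_grad_mask
)

out.sum().backward()

# gradients for the masked key / value positions are exactly zero,
# while the queries still receive gradient everywhere
assert torch.all(k.grad[..., 4:, :] == 0)
assert torch.all(v.grad[..., 4:, :] == 0)
assert q.grad.abs().sum() > 0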

self_reasoning_tokens_pytorch/self_reasoning_tokens.py

Lines changed: 23 additions & 14 deletions
@@ -11,6 +11,10 @@
     FeedForward
 )
 
+from self_reasoning_tokens_pytorch.attention_with_stop_graddable_qkv import (
+    stop_graddable_attn
+)
+
 # helper functions
 
 def exists(v):
@@ -53,27 +57,32 @@ def forward(
 
         q, k, v = self.to_qkv(x)
 
-        q = q * self.scale
-        sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
+        if exists(stop_grad_attn_mask):
+            out = stop_graddable_attn(
+                q, k, v,
+                attn_mask = attn_mask,
+                k_stop_grad_mask = stop_grad_attn_mask,
+                v_stop_grad_mask = stop_grad_attn_mask
+            )
+
+        else:
+            q = q * self.scale
+            sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
 
-        causal_mask = torch.ones((seq, seq), device = device, dtype = torch.bool).triu(1)
+            causal_mask = torch.ones((seq, seq), device = device, dtype = torch.bool).triu(1)
 
-        mask_value = -torch.finfo(sim.dtype).max
-        sim = sim.masked_fill(causal_mask, mask_value)
+            mask_value = -torch.finfo(sim.dtype).max
+            sim = sim.masked_fill(causal_mask, mask_value)
 
-        if exists(stop_grad_attn_mask):
-            # this approach isn't quite right, as the values are not stop gradient
-            # but will run some experiments just to see
+            if exists(attn_mask):
+                sim = sim.masked_fill(~attn_mask, mask_value)
 
-            detached_sim = sim.detach()
-            sim = torch.where(stop_grad_attn_mask, detached_sim, sim)
+            attn = sim.softmax(dim = -1)
 
-        if exists(attn_mask):
-            sim = sim.masked_fill(~attn_mask, mask_value)
+            out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
 
-        attn = sim.softmax(dim = -1)
+        # combine heads
 
-        out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
         return self.to_out(out)
 
 # transformer
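
The removed comment above ("the values are not stop gradient") is the motivation for the switch. A minimal sketch of the difference (mask and tensor sizes are assumptions for illustration), contrasting the old detach-the-logits trick with the new function:

import torch
from einops import einsum
from self_reasoning_tokens_pytorch.attention_with_stop_graddable_qkv import stop_graddable_attn

b, h, n, d = 1, 1, 6, 8
q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d, requires_grad = True)

# hypothetical mask: the last two positions are reasoning tokens
stop_grad_mask = torch.zeros(b, h, n, n, dtype = torch.bool)
stop_grad_mask[..., -2:] = True

# old approach (removed above): detach only the similarity logits
sim = einsum(q, k, 'b h i d, b h j d -> b h i j') * d ** -0.5
sim = torch.where(stop_grad_mask, sim.detach(), sim)
out = einsum(sim.softmax(dim = -1), v, 'b h i j, b h j d -> b h i d')
out.sum().backward()
print(v.grad[..., -2:, :].abs().sum())  # > 0: gradient still leaks into the values

v.grad = None

# new approach: the values at masked entries are stop gradded as well
out = stop_graddable_attn(q, k, v, v_stop_grad_mask = stop_grad_mask)
out.sum().backward()
print(v.grad[..., -2:, :].abs().sum())  # 0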

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'self-reasoning-tokens-pytorch',
   packages = find_packages(exclude = []),
-  version = '0.0.1',
+  version = '0.0.2',
   license='MIT',
   description = 'Self Reasoning Tokens',
   author = 'Phil Wang',
