allow for attending to nothing
lucidrains committed Nov 13, 2024
1 parent 19f2367 commit 5ccd33b
Showing 2 changed files with 30 additions and 7 deletions.
35 changes: 29 additions & 6 deletions pi_zero_pytorch/pi_zero.py
@@ -106,6 +106,17 @@ def inverse(out, inv_pattern = None):

return packed, inverse

+ def pad_at_dim(
+     t,
+     pad: tuple[int, int],
+     *,
+     dim = -1,
+     value = 0.
+ ):
+     dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+     zeros = ((0, 0) * dims_from_right)
+     return F.pad(t, (*zeros, *pad), value = value)

# losses

def direction_loss(pred, target, dim = -1):
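For reference, a standalone sketch of what the new `pad_at_dim` helper computes; the shapes here are illustrative rather than taken from the repository:

import torch
import torch.nn.functional as F

# pad one zero-slice before dim = -2 of a (2, 3, 4) tensor,
# equivalent to pad_at_dim(t, (1, 0), dim = -2)
t = torch.randn(2, 3, 4)
dims_from_right = 1                 # (- dim - 1) for dim = -2
zeros = (0, 0) * dims_from_right    # no padding for dims right of the target
padded = F.pad(t, (*zeros, 1, 0))   # F.pad consumes pads from the last dim inward
assert padded.shape == (2, 4, 4)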
@@ -138,7 +149,7 @@ def __init__(
self.to_qkv = LinearNoBias(dim, 3 * dim_inner)
self.to_out = LinearNoBias(dim_inner, dim)

- self.to_actions_qkv = LinearNoBias(dim, 3 * dim_inner)
+ self.to_actions_qkvg = LinearNoBias(dim, 4 * dim_inner)
self.to_actions_out = LinearNoBias(dim_inner, dim)

self.softclamp_value = softclamp_value
@@ -153,9 +164,9 @@ def forward_actions_with_cached_state(
return_keys_values = False,
flex_attn_fn: Callable | None = None
):
- aq, ak, av = self.to_actions_qkv(actions).chunk(3, dim = -1)
+ aq, ak, av, ag = self.to_actions_qkvg(actions).chunk(4, dim = -1)

- aq, ak, av = tuple(self.split_heads(t) for t in (aq, ak, av))
+ aq, ak, av, ag = tuple(self.split_heads(t) for t in (aq, ak, av, ag))

if exists(actions_value_residual):
av = 0.5 * (av + actions_value_residual)
@@ -190,6 +201,10 @@ def forward_actions_with_cached_state(

out = einsum(attn, v, 'b h i j, b h j d -> b h i d')

+ # gate
+
+ out = out * ag.sigmoid()

# merge attention heads

out = self.merge_heads(out)
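For intuition, a minimal sketch of the sigmoid gating introduced above: each attention output is multiplied elementwise by a learned gate in (0, 1), so a query can suppress its output toward zero when there is nothing worth attending to, per the commit title. Shapes here are illustrative:

import torch

b, h, n, d = 1, 8, 16, 64       # batch, heads, action tokens, head dim
out = torch.randn(b, h, n, d)   # attention output
ag = torch.randn(b, h, n, d)    # gate logits from the widened qkvg projection

gated = out * ag.sigmoid()      # very negative logits drive the output toward 0
assert gated.shape == out.shape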
@@ -219,9 +234,9 @@ def forward(

mq, mk, mv = self.to_qkv(multimodal_seq).chunk(3, dim = -1)

- aq, ak, av = self.to_actions_qkv(actions).chunk(3, dim = -1)
+ aq, ak, av, ag = self.to_actions_qkvg(actions).chunk(4, dim = -1)

- mq, mk, mv, aq, ak, av = tuple(self.split_heads(t) for t in (mq, mk, mv, aq, ak, av))
+ mq, mk, mv, aq, ak, av, ag = tuple(self.split_heads(t) for t in (mq, mk, mv, aq, ak, av, ag))

if exists(actions_value_residual):
av = 0.5 * (av + actions_value_residual)
@@ -260,6 +275,12 @@ def forward(

out = einsum(attn, v, 'b h i j, b h j d -> b h i d')

+ # gating of values, used in alphafold line of work
+
+ gates = pad_at_dim(ag.sigmoid(), (out.shape[-2] - ag.shape[-2], 0), value = 1., dim = -2)
+
+ out = out * gates

# merge attention heads

out = self.merge_heads(out)
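In the joint forward pass, only the action tokens produce gates, so the gate tensor is left-padded with ones (`value = 1.`) along the sequence dimension and the multimodal state tokens pass through ungated. A sketch under assumed shapes:

import torch
import torch.nn.functional as F

b, h, d = 1, 8, 64
state_len, action_len = 10, 6
out = torch.randn(b, h, state_len + action_len, d)  # joint attention output
ag = torch.randn(b, h, action_len, d)               # gate logits for action tokens only

# same effect as pad_at_dim(ag.sigmoid(), (state_len, 0), value = 1., dim = -2)
gates = F.pad(ag.sigmoid(), (0, 0, state_len, 0), value = 1.)
assert gates.shape == out.shape
out = out * gates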
@@ -413,7 +434,9 @@ def __init__(

self.vit = vit

- self.maybe_to_image_tokens = nn.Linear(vit_dim, dim) if vit_dim != dim else nn.Identity()
+ assert not (exists(vit) and not exists(vit_dim)), '`vit_dim` must be passed in if `vit` is made available for image encoding at forward'
+
+ self.maybe_to_image_tokens = nn.Linear(vit_dim, dim) if exists(vit_dim) and vit_dim != dim else nn.Identity()

# embedding

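The image-token projection is now instantiated only when `vit_dim` is given and differs from the model width; otherwise it is an identity. A standalone sketch of the pattern, with hypothetical dimensions:

import torch
from torch import nn

def exists(v):
    return v is not None

dim, vit_dim = 512, 768  # hypothetical model width and ViT width

maybe_to_image_tokens = nn.Linear(vit_dim, dim) if exists(vit_dim) and vit_dim != dim else nn.Identity()

image_tokens = torch.randn(1, 196, vit_dim)
assert maybe_to_image_tokens(image_tokens).shape[-1] == dim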
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "pi-zero-pytorch"
- version = "0.0.14"
+ version = "0.0.15"
description = "π0 in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
