take the first step towards variable length states starting with language
lucidrains committed Nov 13, 2024
1 parent 1395440 commit a4027ed
Showing 2 changed files with 79 additions and 37 deletions.
114 changes: 78 additions & 36 deletions pi_zero_pytorch/pi_zero.py
@@ -53,16 +53,17 @@
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
flex_attention = torch.compile(flex_attention)

def create_pizero_attn_mask(prefix_causal_length):
def create_pizero_attn_mask(prefix_causal_length, mask: Bool['b n']):
# the pi-zero attention is a triangular causal mask, but bidirectional attention for the actions at the very right hand side

def inner(batch_index, head_index, query_index, key_index):
def mask_fn(batch_index, head_index, query_index, key_index):
return (
mask[batch_index, key_index] & # variable length states
((query_index >= key_index) | # causal
(key_index >= prefix_causal_length)) # bidirectional
)

return inner
return mask_fn
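
The intended pattern here — causal over the state prefix, fully bidirectional over the trailing action block, padded state keys dropped — is the same one the dense masked_fill path implements further down in this file. A standalone dense sketch for reference (the helper name and sizes below are illustrative only, not from the repository):

import torch

def dense_pizero_mask(prefix_causal_length, key_padding_mask):
    # key_padding_mask: (batch, seq) bool - True for real tokens, False for padding
    batch, seq_len = key_padding_mask.shape
    q_idx = torch.arange(seq_len).view(-1, 1)
    k_idx = torch.arange(seq_len).view(1, -1)

    causal = q_idx >= k_idx                      # each query sees itself and the past
    action_keys = k_idx >= prefix_causal_length  # action keys are visible to every query

    allowed = causal | action_keys                    # (seq, seq)
    allowed = allowed & key_padding_mask[:, None, :]  # (batch, seq, seq) - drop padded keys
    return allowed                                    # True where attention is permitted

# 3 state tokens (one of them padding) followed by 2 action tokens
key_padding_mask = torch.tensor([[True, True, False, True, True]])
print(dense_pizero_mask(3, key_padding_mask)[0].int())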

def softclamp_score_mod(value):
def identity(score, b, h, q, k):
@@ -84,12 +85,17 @@ def exists(v):
def default(v, d):
return v if exists(v) else d

# tensor helpers

def softclamp(t, value):
if value <= 0.:
return t

return (t / value).tanh() * value
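
softclamp is tanh soft-capping of the attention logits (the trick popularized by Gemma 2): scores are squeezed smoothly into (-value, value) rather than hard-clipped, so the function stays differentiable everywhere and is nearly the identity for small scores. A quick standalone check (the function is repeated so the snippet runs on its own):

import torch

def softclamp(t, value):
    if value <= 0.:                    # a non-positive cap disables clamping
        return t
    return (t / value).tanh() * value

scores = torch.tensor([-100., -5., 0., 5., 100.])
print(softclamp(scores, 50.))          # ~[-48.2, -4.98, 0.00, 4.98, 48.2]: near-identity around zero, capped at +/-50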

def max_neg_value(t):
return -torch.finfo(t.dtype).max

def pack_with_inverse(t, pattern):
packed, packed_shape = pack(t, pattern)

@@ -142,6 +148,7 @@ def forward_actions_with_cached_state(
actions,
cached_state_keys_values: tuple[Tensor, Tensor],
rotary_emb = None,
mask: Bool['b n'] | None = None,
actions_value_residual: Tensor | None = None,
return_keys_values = False,
flex_attn_fn: Callable | None = None
@@ -159,7 +166,7 @@ def forward_actions_with_cached_state(
k, v = tuple(torch.cat(tensors, dim = -2) for tensors in zip((mk, mv), (ak, av)))

if exists(rotary_emb):
q = apply_rotary_emb(rotary_emb, q)
q = apply_rotary_emb(rotary_emb, q, freqs_seq_dim = -2)
k = apply_rotary_emb(rotary_emb, k)

elif exists(self.rotary_emb):
@@ -176,6 +183,9 @@ def forward_actions_with_cached_state(

sim = softclamp(sim, self.softclamp_value)

if exists(mask):
sim = einx.where('b j, b h i j, -> b h i j', mask, sim, max_neg_value(sim))

attn = sim.softmax(dim = -1)

out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
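
The new mask argument lets the cached-state decoding path ignore padded state keys. The einx.where call broadcasts a per-key boolean over heads and queries; an equivalent plain-PyTorch formulation (shapes here are made up for illustration) looks like this:

import torch

batch, heads, q_len, k_len = 2, 4, 3, 7
sim = torch.randn(batch, heads, q_len, k_len)    # attention logits
mask = torch.rand(batch, k_len) > 0.2            # True for real keys, False for padding

max_neg = -torch.finfo(sim.dtype).max            # same convention as max_neg_value above
sim = sim.masked_fill(~mask[:, None, None, :], max_neg)

# equivalent to: einx.where('b j, b h i j, -> b h i j', mask, sim, max_neg_value(sim))
attn = sim.softmax(dim = -1)                     # padded keys receive (numerically) zero attention
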
@@ -196,6 +206,7 @@ def forward(
multimodal_seq,
actions,
rotary_emb = None,
mask: Bool['b n'] | None = None,
actions_value_residual: Tensor | None = None,
return_keys_values = False,
flex_attn_fn: Callable | None = None
@@ -238,9 +249,12 @@ def forward(

causal_mask = torch.ones(sim.shape[-2:], dtype = torch.bool, device = device).triu(1)

if exists(mask):
causal_mask = einx.logical_or('b j, i j -> b 1 i j', ~mask, causal_mask)

causal_mask[..., seq_len:] = False # actions have bidirectional attention, lining up with Transfusion paper

sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
sim = sim.masked_fill(causal_mask, max_neg_value(sim))

attn = sim.softmax(dim = -1)

@@ -320,6 +334,10 @@ def __init__(
self.to_beta = LinearNoBias(dim_cond, dim)

def forward(self, actions, cond):

if cond.ndim == 2:
cond = rearrange(cond, 'b d -> b 1 d')

normed = self.norm(actions)
gamma = self.to_gamma(cond)
beta = self.to_beta(cond)
@@ -340,6 +358,10 @@ def __init__(
self.to_adaln_zero_gamma = adaln_zero_gamma_linear

def forward(self, actions, cond):

if cond.ndim == 2:
cond = rearrange(cond, 'b d -> b 1 d')

gamma = self.to_adaln_zero_gamma(cond)
return actions * gamma.sigmoid()
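
Both conditioning modules (the ada-RMSNorm style one above and the adaLN-zero gating here) now accept a per-sample condition of shape (batch, dim) and insert a length-1 sequence axis so the projected gamma/beta broadcast across every action token. A minimal illustration of why that axis is needed (shapes are made up):

import torch
from einops import rearrange

batch, seq, dim = 2, 6, 16
actions = torch.randn(batch, seq, dim)
cond = torch.randn(batch, dim)            # one conditioning vector per sample, e.g. the flow time embedding

# actions * cond would fail: (2, 6, 16) cannot broadcast with (2, 16)
cond = rearrange(cond, 'b d -> b 1 d')    # (batch, 1, dim)
print((actions * cond).shape)             # torch.Size([2, 6, 16]) - the condition now broadcasts over the sequence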

@@ -366,6 +388,7 @@ def __init__(
num_action_register_tokens = 4,
attn_kwargs: dict = dict(),
ff_kwargs: dict = dict(),
lm_pad_id = -1,
lm_loss_weight = 1.,
flow_loss_weight = 1.,
direction_loss_weight = 0.,
@@ -447,6 +470,10 @@ def __init__(
self.state_to_logits = LinearNoBias(dim, num_tokens)
self.actions_to_pred_flow = LinearNoBias(dim, dim_action_input)

# the language token id padding id, for fine-tuning as well as taking care of the masking on top of causal mask

self.lm_pad_id = lm_pad_id

# loss related

self.lm_loss_weight = lm_loss_weight
@@ -465,6 +492,13 @@ def __init__(
def device(self):
return next(self.parameters()).device

@beartype
def pretrained_vlm_weights(
self,
weights: dict[str, Tensor]
):
raise NotImplementedError

@torch.inference_mode()
def sample(
self,
Expand All @@ -474,9 +508,10 @@ def sample(
trajectory_length: int,
reward_tokens = None,
steps = 18,
batch_size = 1,
show_pbar = True
):
batch_size = token_ids.shape[0]

was_training = self.training
self.eval()

@@ -562,7 +597,7 @@ def forward(
flow = actions - noise
padded_times = rearrange(times, 'b -> b 1 1')

actions = noise * (1. - padded_times) + padded_times * actions
actions = noise.lerp(actions, padded_times)
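
noise.lerp(actions, t) is a terser spelling of the same straight-line interpolation the previous line wrote out explicitly; the regression target above remains the constant velocity actions - noise along that line. A quick equivalence check (shapes are made up):

import torch

noise = torch.randn(2, 5, 8)           # starting point of the flow
actions = torch.randn(2, 5, 8)         # ground-truth action chunk
padded_times = torch.rand(2, 1, 1)     # one flow time per sample, broadcastable over the action chunk

manual = noise * (1. - padded_times) + padded_times * actions
lerped = noise.lerp(actions, padded_times)

print(torch.allclose(manual, lerped, atol = 1e-6))   # True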

# actions

@@ -617,34 +652,13 @@

state_tokens, inverse_packed_states = pack_with_inverse([visual_tokens, language_tokens, joint_state_tokens, reward_tokens], 'b * d')

# prepare maybe flex attention

flex_attn_fn = None

if self.use_flex_attn and state_tokens.is_cuda:

block_mask = None

if not inferencing:
prefix_length = state_tokens.shape[-2]
seq_len = prefix_length + action_tokens.shape[-2]

block_mask = create_block_mask(
create_pizero_attn_mask(prefix_length),
Q_LEN = seq_len,
KV_LEN = seq_len,
device = state_tokens.device
)
# take care of masking for variable lengthed states, starting with the language tokens

score_mod_fn = softclamp_score_mod(self.attn_softclamp_value)
# which then leads to proper rotary embeddings

flex_attn_fn = partial(
flex_attention,
block_mask = block_mask,
score_mod = score_mod
)
command_length = token_ids.shape[-1]

# prepare rotary embeddings
language_mask = token_ids != self.lm_pad_id

action_with_registers_length = action_tokens.shape[-2]

@@ -654,12 +668,39 @@ def forward(
state_length = state_tokens.shape[-2]

total_seq_length = action_with_registers_length + state_length
mask = F.pad(language_mask, (state_length - command_length - 1, 1 + action_with_registers_length), value = True) # assume fixed number of images for now, but address variable length modality states later

# rotary embeddings

seq = torch.arange(total_seq_length, device = self.device)
seq = torch.cumsum(mask.float(), dim = -1)
rotary_emb = self.rotary_emb(seq)

rotary_emb = rearrange(rotary_emb, 'b n d -> b 1 n d')
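
The pad mask built from lm_pad_id is extended to the full state-plus-action sequence (everything other than the language tokens is assumed present for now), and rotary positions are then taken as the running sum of that mask, so positions simply do not advance across padded language tokens. A small sketch with hypothetical modality lengths:

import torch
import torch.nn.functional as F

lm_pad_id = -1
token_ids = torch.tensor([[5, 9, 2, lm_pad_id, lm_pad_id]])   # a command of true length 3, padded to 5

language_mask = token_ids != lm_pad_id                        # (batch, command_length)

num_before, num_after = 4, 4                                  # made-up counts of visual / joint + action tokens
mask = F.pad(language_mask, (num_before, num_after), value = True)

positions = torch.cumsum(mask.float(), dim = -1)              # padded language tokens repeat the previous position
print(positions)   # tensor([[1., 2., 3., 4., 5., 6., 7., 7., 7., 8., 9., 10., 11.]])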

# prepare maybe flex attention

flex_attn_fn = None

if not inferencing and self.use_flex_attn and state_tokens.is_cuda:

prefix_length = state_tokens.shape[-2]
seq_len = prefix_length + action_tokens.shape[-2]

block_mask = create_block_mask(
create_pizero_attn_mask(prefix_length, mask = mask),
Q_LEN = seq_len,
KV_LEN = seq_len,
device = state_tokens.device
)

score_mod_fn = softclamp_score_mod(self.attn_softclamp_value)

flex_attn_fn = partial(
flex_attention,
block_mask = block_mask,
score_mod = score_mod_fn
)

# state keys and values for caching during inference

cached_state_key_values_iter = iter(default(cached_state_keys_values, []))
Expand All @@ -680,7 +721,7 @@ def forward(

action_tokens = attn_ada_rmsnorm(action_tokens, time_cond)

(state_attn_out, actions_attn_out), (state_keys, state_values, action_keys, action_values) = attn(state_tokens, action_tokens, rotary_emb = rotary_emb, flex_attn_fn = flex_attn_fn, actions_value_residual = actions_value_residual, return_keys_values = True)
(state_attn_out, actions_attn_out), (state_keys, state_values, action_keys, action_values) = attn(state_tokens, action_tokens, rotary_emb = rotary_emb, flex_attn_fn = flex_attn_fn, actions_value_residual = actions_value_residual, mask = mask, return_keys_values = True)

state_cached_keys_values.append((state_keys, state_values))

@@ -708,7 +749,7 @@ def forward(

action_tokens = attn_ada_rmsnorm(action_tokens, time_cond)

actions_attn_out, (state_keys, state_values, action_keys, action_values) = attn.forward_actions_with_cached_state(action_tokens, cached_state_keys_values = next(cached_state_key_values_iter), rotary_emb = rotary_emb, return_keys_values = True)
actions_attn_out, (state_keys, state_values, action_keys, action_values) = attn.forward_actions_with_cached_state(action_tokens, cached_state_keys_values = next(cached_state_key_values_iter), rotary_emb = rotary_emb, mask = mask, return_keys_values = True)

state_cached_keys_values.append((state_keys, state_values))

@@ -768,7 +809,8 @@ def forward(

language_loss = F.cross_entropy(
rearrange(language_logits[:, :-1], 'b n l -> b l n'),
labels
labels,
ignore_index = self.lm_pad_id
)
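
With ignore_index = self.lm_pad_id, targets equal to the pad id contribute nothing to the language loss, so variable-length commands are not penalized on their padding. A small standalone illustration (sizes are made up):

import torch
import torch.nn.functional as F

lm_pad_id = -1
num_tokens = 7
language_logits = torch.randn(2, num_tokens, 4)     # (batch, vocab, seq) - the layout cross_entropy expects
labels = torch.tensor([[3, 1, 0, lm_pad_id],
                       [2, lm_pad_id, lm_pad_id, lm_pad_id]])

loss = F.cross_entropy(language_logits, labels, ignore_index = lm_pad_id)
# only the four real targets (three in the first row, one in the second) are averaged into the loss
print(loss)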

# loss breakdown
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "pi-zero-pytorch"
version = "0.0.12"
version = "0.0.14"
description = "π0 in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
