Skip to content

Commit

Permalink
add linting and fix score mod function in flex attention pathway
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 20, 2024
1 parent 79ce31d commit 4628c28
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 4 deletions.
21 changes: 21 additions & 0 deletions .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
name: Ruff
on: [push, pull_request]

jobs:
build:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install uv
python -m uv pip install ruff
- name: Lint with Ruff
run: |
ruff check pi_zero_pytorch/
5 changes: 2 additions & 3 deletions pi_zero_pytorch/pi_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ def forward(
images, inverse_pack_image_frames = pack_with_inverse([images], '* c h w')

with torch.no_grad():
self.vit.eval()
self.vit.eval()
visual_tokens = self.vit(images)

if is_multiple_images:
Expand Down Expand Up @@ -880,7 +880,6 @@ def forward(
else:
state_length = state_tokens.shape[-2]

total_seq_length = action_with_registers_length + state_length
mask = F.pad(language_mask, (state_length - command_length - 1, 1 + action_with_registers_length), value = True) # assume fixed number of images for now, but address variable length modality states later

# rotary embeddings
Expand Down Expand Up @@ -911,7 +910,7 @@ def forward(
flex_attn_fn = partial(
flex_attention,
block_mask = block_mask,
score_mod = score_mod
score_mod = score_mod_fn
)

# state keys and values for caching during inference
Expand Down
16 changes: 15 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "pi-zero-pytorch"
version = "0.0.26"
version = "0.0.27"
description = "π0 in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down Expand Up @@ -45,6 +45,7 @@ Repository = "https://github.com/lucidrains/pi-zero-pytorch"
examples = []
test = [
"pytest",
"ruff>=0.4.2",
"vit-pytorch>=1.8.7"
]

Expand All @@ -53,6 +54,19 @@ pythonpath = [
"."
]

[tool.ruff]
line-length = 1000

lint.ignore = [
"F722", # for jaxtyping shape annotation
"F401",
"F821"
]

lint.extend-select = [
"W291"
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
Expand Down

0 comments on commit 4628c28

Please sign in to comment.