Error running latent_ode.py #6

Open
tkamthroche opened this issue Jul 15, 2021 · 2 comments

tkamthroche commented Jul 15, 2021

I tried running the script on the PhysioNet data and got the following error after a few iterations. Could you comment on this, and also say a bit more about what the expected output is?
TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?

When I ran it again, it just hung here:

 python latent_ode.py --reg r3 --lam 1e-2
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
~/.conda/envs/Neural_ODE/lib/python3.8/site-packages/jax/_src/random.py:511: FutureWarning: jax.random.shuffle is deprecated and will be removed in a future release. Use jax.random.permutation
  warnings.warn(msg, FutureWarning)

conda environment:

_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       1_gnu    conda-forge
absl-py                   0.13.0                    <pip>
backcall                  0.2.0                     <pip>
ca-certificates           2021.5.30            ha878542_0    conda-forge
certifi                   2021.5.30        py38h578d9bd_0    conda-forge
cycler                    0.10.0                    <pip>
Cython                    0.29.19                   <pip>
debugpy                   1.3.0                     <pip>
dm-haiku                  0.0.5.dev0                <pip>
flatbuffers               2.0                       <pip>
future                    0.18.2                    <pip>
ipykernel                 6.0.0                     <pip>
ipython                   7.25.0                    <pip>
ipython-genutils          0.2.0                     <pip>
jax                       0.2.17                    <pip>
jaxlib                    0.1.68                    <pip>
jedi                      0.18.0                    <pip>
jmp                       0.0.2                     <pip>
joblib                    0.15.1                    <pip>
jupyter-client            6.1.12                    <pip>
jupyter-core              4.7.1                     <pip>
kiwisolver                1.2.0                     <pip>
ld_impl_linux-64          2.36.1               hea4e1c9_0    conda-forge
libffi                    3.3                  h58526e2_2    conda-forge
libgcc-ng                 9.3.0               h2828fa1_19    conda-forge
libgomp                   9.3.0               h2828fa1_19    conda-forge
libstdcxx-ng              9.3.0               h6de172a_19    conda-forge
matplotlib                3.2.1                     <pip>
matplotlib-inline         0.1.2                     <pip>
ncurses                   6.2                  h58526e2_4    conda-forge
numpy                     1.21.0                    <pip>
openssl                   1.1.1k               h7f98852_0    conda-forge
opt-einsum                3.3.0                     <pip>
parso                     0.8.2                     <pip>
pexpect                   4.8.0                     <pip>
phate                     1.0.7                     <pip>
pickleshare               0.7.5                     <pip>
pip                       21.1.3             pyhd8ed1ab_0    conda-forge
POT                       0.7.0                     <pip>
prompt-toolkit            3.0.19                    <pip>
ptyprocess                0.7.0                     <pip>
Pygments                  2.9.0                     <pip>
pyparsing                 2.4.7                     <pip>
python                    3.8.10          h49503c6_1_cpython    conda-forge
python-dateutil           2.8.1                     <pip>
python_abi                3.8                      2_cp38    conda-forge
pyzmq                     22.1.0                    <pip>
readline                  8.1                  h46c0cb4_0    conda-forge
s-gd2                     1.8                       <pip>
scikit-learn              0.23.1                    <pip>
scipy                     1.4.1                     <pip>
setuptools                49.6.0           py38h578d9bd_3    conda-forge
six                       1.15.0                    <pip>
sklearn                   0.0                       <pip>
sqlite                    3.36.0               h9cd32fc_0    conda-forge
tabulate                  0.8.9                     <pip>
threadpoolctl             2.1.0                     <pip>
tk                        8.6.10               h21135ba_1    conda-forge
torch                     1.5.0                     <pip>
torchdiffeq               0.0.1                     <pip>
tornado                   6.1                       <pip>
traitlets                 5.0.5                     <pip>
wcwidth                   0.2.5                     <pip>
wheel                     0.36.2             pyhd3deb0d_0    conda-forge
xz                        5.2.5                h516909a_1    conda-forge
zlib                      1.2.11            h516909a_1010    conda-forge

@itamblyn

Bumping this. I get exactly the same error.

@jacobjinkelly
Owner

Hello! Sorry for the delayed reply. I'm having some trouble reproducing this error actually. I used the preprocessed data available in the release. My conda environment export is:

channels:
  - defaults
dependencies:
  - ca-certificates=2021.7.5=hecd8cb5_1
  - certifi=2021.5.30=py38hecd8cb5_0
  - libcxx=12.0.0=h2f01273_0
  - libffi=3.3=hb1e8313_2
  - ncurses=6.2=h0a44026_1
  - openssl=1.1.1l=h9ed2024_0
  - python=3.8.11=h88f2d9e_1
  - readline=8.1=h9ed2024_0
  - setuptools=58.0.4=py38hecd8cb5_0
  - sqlite=3.36.0=hce871da_0
  - tk=8.6.10=hb0a8c7a_0
  - wheel=0.37.0=pyhd3eb1b0_1
  - xz=5.2.5=h1de35cc_0
  - zlib=1.2.11=h1de35cc_3
  - pip:
    - absl-py==0.14.0
    - dm-haiku==0.0.5.dev0
    - flatbuffers==2.0
    - jax==0.2.20
    - jaxlib==0.1.71
    - jmp==0.0.2
    - numpy==1.21.2
    - opt-einsum==3.3.0
    - pip==21.2.4
    - scipy==1.7.1
    - six==1.16.0
    - tabulate==0.8.9

I ran the command python latent_ode.py --reg r2 --lam 1e-2 --test_freq 1 on my MacBook. After ~10 minutes of running, the output so far is:

Iter 0001 | Loss 798.138111 | Likelihood -808.377092 | KL 2.490536 | MSE 0.165348 | Enc. r 0.000000 | Dec. r 0.001278 | Enc. NFE 0.000000 | Dec. NFE 31.824688
Iter 0002 | Loss 551.387941 | Likelihood -566.549929 | KL 1.965105 | MSE 0.116983 | Enc. r 0.000000 | Dec. r 0.005880 | Enc. NFE 0.000000 | Dec. NFE 31.839688
Iter 0003 | Loss 495.621342 | Likelihood -497.331870 | KL 1.669389 | MSE 0.103139 | Enc. r 0.000000 | Dec. r 0.020152 | Enc. NFE 0.000000 | Dec. NFE 34.642188
Iter 0004 | Loss 332.830424 | Likelihood -335.500099 | KL 1.934213 | MSE 0.070773 | Enc. r 0.000000 | Dec. r 0.016797 | Enc. NFE 0.000000 | Dec. NFE 32.999062
Iter 0005 | Loss 222.494621 | Likelihood -230.846079 | KL 2.180931 | MSE 0.049842 | Enc. r 0.000000 | Dec. r 0.010237 | Enc. NFE 0.000000 | Dec. NFE 35.735312

In particular, I used r2 since it uses less memory. Using r3 is possible, but I typically only ran this on a remote cluster where I had access to more RAM.
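
If you want to compare your own run against the log lines above, here is a small sketch that turns one line into numbers. The helper name parse_log_line is hypothetical, and the field layout is inferred only from the output shown in this comment, not from the logging code in latent_ode.py.

```python
def parse_log_line(line):
    """Split 'Iter 0001 | Loss 798.1 | ...' into {'Iter': 1, 'Loss': 798.1, ...}."""
    metrics = {}
    for field in line.split("|"):
        # Each field is '<name> <number>'; names such as 'Enc. NFE' contain spaces,
        # so split on the last space only.
        name, value = field.strip().rsplit(" ", 1)
        metrics[name] = int(value) if name == "Iter" else float(value)
    return metrics


# First log line from this comment, copied verbatim.
line = (
    "Iter 0001 | Loss 798.138111 | Likelihood -808.377092 | KL 2.490536 | "
    "MSE 0.165348 | Enc. r 0.000000 | Dec. r 0.001278 | "
    "Enc. NFE 0.000000 | Dec. NFE 31.824688"
)
m = parse_log_line(line)
print(m["Iter"], m["Loss"], m["Dec. NFE"])
```

Parsing a few consecutive lines this way makes it easy to check that Loss is falling from one iteration to the next, as it does in the five iterations above.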

When you ran the first time and got the error TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?, was this after you ran the data processing code yourself?
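
For reference, that message is what JAX raises when code does NumPy-style in-place assignment on a JAX array. The snippet below is a minimal sketch of the failure and of the functional update that replaces it. It is not code from latent_ode.py or from the data processing script; the array x and the index are made up for illustration, and it assumes a JAX version that provides the .at[...] indexing property.

```python
import jax.numpy as jnp

x = jnp.zeros(3)

try:
    x[0] = 1.0  # NumPy-style in-place write; JAX arrays are immutable
except TypeError as err:
    print("TypeError:", err)

# Functional update: returns a new array and leaves x unchanged.
y = x.at[0].set(1.0)
print(x, y)
```

The same code runs fine on a plain NumPy array, so where exactly an array becomes a JAX array matters when tracking this error down.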

In summary, my suggestions are:

  1. Check whether your conda environment differs from mine, and whether matching it fixes the error.
  2. Set --test_freq 1 to confirm the code is running (the default is --test_freq 640).
  3. Try --reg r2, since it's faster and uses less memory.
  4. Try running on a remote machine with more RAM, especially if you want to use --reg r3, e.g. try Google Colab.
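
For the first suggestion, here is a rough sketch of comparing the two environments. The version strings are copied from the two listings in this thread; the names reported_env, working_env, and version_diff are hypothetical, and only a handful of packages are included.

```python
# Versions copied from the two environment listings in this thread.
reported_env = {
    "jax": "0.2.17", "jaxlib": "0.1.68", "numpy": "1.21.0",
    "scipy": "1.4.1", "dm-haiku": "0.0.5.dev0",
}
working_env = {
    "jax": "0.2.20", "jaxlib": "0.1.71", "numpy": "1.21.2",
    "scipy": "1.7.1", "dm-haiku": "0.0.5.dev0",
}


def version_diff(a, b):
    """Return {package: (version_in_a, version_in_b)} for packages whose versions differ."""
    return {
        name: (a.get(name), b.get(name))
        for name in sorted(set(a) | set(b))
        if a.get(name) != b.get(name)
    }


diff = version_diff(reported_env, working_env)
for name, (reported, working) in diff.items():
    print(f"{name}: {reported} -> {working}")
```

This only shows which versions differ. It does not say which difference, if any, is responsible for the error.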

Please let me know if any of this is helpful, or if you have any other issues!
