ERROR: No matching distribution found for jaxlib==0.1.67

Question:

I need jaxlib==0.1.67 for a project I’m working on, but I can’t downgrade. At the moment I have jaxlib==0.1.75, and my program keeps failing with an error I can’t find a solution to either. I compared the versions of all the important packages against another machine where my program runs without problems, and the only difference is the jaxlib version (it’s still 0.1.67 on the machine where it runs). I suspect jaxlib is the issue because, when it’s not 0.1.67, I get the following error:

    from haiku import data_structures
  File "/net/home/justen/.local/lib/python3.10/site-packages/haiku/data_structures.py", line 17, in <module>
    from haiku._src.data_structures import to_immutable_dict
  File "/net/home/justen/.local/lib/python3.10/site-packages/haiku/_src/data_structures.py", line 30, in <module>
    from haiku._src import utils
  File "/net/home/justen/.local/lib/python3.10/site-packages/haiku/_src/utils.py", line 24, in <module>
    import jax
  File "/net/home/justen/.local/lib/python3.10/site-packages/jax/__init__.py", line 108, in <module>
    from .experimental.maps import soft_pmap
  File "/net/home/justen/.local/lib/python3.10/site-packages/jax/experimental/maps.py", line 25, in <module>
    from .. import numpy as jnp
  File "/net/home/justen/.local/lib/python3.10/site-packages/jax/numpy/__init__.py", line 16, in <module>
    from . import fft
  File "/net/home/justen/.local/lib/python3.10/site-packages/jax/numpy/fft.py", line 17, in <module>
    from jax._src.numpy.fft import (
  File "/net/home/justen/.local/lib/python3.10/site-packages/jax/_src/numpy/fft.py", line 19, in <module>
    from jax import lax
  File "/net/home/justen/.local/lib/python3.10/site-packages/jax/lax/__init__.py", line 334, in <module>
    from jax._src.lax.parallel import (
  File "/net/home/justen/.local/lib/python3.10/site-packages/jax/_src/lax/parallel.py", line 36, in <module>
    from jax._src.numpy import lax_numpy
  File "/net/home/justen/.local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 51, in <module>
    from jax import ops
  File "/net/home/justen/.local/lib/python3.10/site-packages/jax/ops/__init__.py", line 16, in <module>
    from jax._src.ops.scatter import (
  File "/net/home/justen/.local/lib/python3.10/site-packages/jax/_src/ops/scatter.py", line 31, in <module>
    from typing import EllipsisType
ImportError: cannot import name 'EllipsisType' from 'typing' (/usr/lib/python3.10/typing.py)

haiku and typing are the same versions on both machines, so I guess it must be jaxlib. On both machines I’m using pip==20.0.2 in a Python 3.9.9 virtualenv.
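One way to make a machine-to-machine comparison like this repeatable is to print the installed versions from inside the interpreter that actually runs the program. The sketch below assumes Python 3.8+ (for importlib.metadata) and that Haiku is installed under its PyPI distribution name "dm-haiku"; the helper name installed_versions is hypothetical.

```python
# Report installed versions of the packages involved, so two machines can be
# compared line by line. Assumption: Haiku's PyPI distribution is "dm-haiku".
import sys
from importlib import metadata

PACKAGES = ("jax", "jaxlib", "dm-haiku")


def installed_versions(packages=PACKAGES):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # The interpreter version matters as much as the package versions.
    print("python", sys.version.split()[0], sys.executable)
    for name, version in installed_versions().items():
        print(name, version or "not installed")
```

Running it with the same command used to start the program also shows which Python is really in use, which is easy to overlook when comparing only package versions.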

When I try to downgrade to jaxlib==0.1.67 I get:

ERROR: Could not find a version that satisfies the requirement jaxlib==0.1.67 (from versions: 0.1.75, 0.1.76, 0.3.0, 0.3.2, 0.3.5, 0.3.7, 0.3.10, 0.3.14, 0.3.15)
ERROR: No matching distribution found for jaxlib==0.1.67

I even tried pip install jaxlib==0.1.67 -f https://storage.googleapis.com/jax-releases/jax_releases.html, but that doesn’t work either.

Has anyone experienced the same problem, or does anyone have a clue what the issue could be?

Asked By: haselnussbier


Answers:

Based on the path in the exception (/usr/lib/python3.10), it looks like you are using Python 3.10. There are no Python 3.10 wheels for jaxlib==0.1.67 (see PyPI), so you will have to use Python 3.6-3.9.
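The check described here can be sketched in a few lines, assuming (as this answer states) that jaxlib 0.1.67 only ships wheels for CPython 3.6 through 3.9. The helper name is hypothetical, and it only looks at the Python version; the operating system and CPU architecture also decide whether a matching wheel exists.

```python
# Sketch: does the running interpreter's major.minor version have a
# jaxlib 0.1.67 wheel? Assumes wheels exist only for CPython 3.6-3.9.
import sys

JAXLIB_0_1_67_PYTHONS = {(3, 6), (3, 7), (3, 8), (3, 9)}


def can_install_jaxlib_0_1_67(version_info=sys.version_info):
    """Return True if this major.minor version has a jaxlib 0.1.67 wheel."""
    return tuple(version_info[:2]) in JAXLIB_0_1_67_PYTHONS


if __name__ == "__main__":
    current = "{}.{}".format(*sys.version_info[:2])
    if can_install_jaxlib_0_1_67():
        print(f"Python {current}: jaxlib==0.1.67 wheels should be available")
    else:
        print(f"Python {current}: no jaxlib==0.1.67 wheel; use Python 3.6-3.9")
```

This mirrors what pip does when it reports "from versions: 0.1.75, ...": versions without a wheel compatible with the running interpreter are simply not offered.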

If you think you are using Python 3.9, here’s a way to avoid confusion when installing packages. Use

python3.9 -m pip install

to install packages into your Python 3.9 environment. Replace python3.9 with whichever Python interpreter you want to use.
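The same idea works from a script: sys.executable is the interpreter currently running, so building the pip command from it guarantees that pip installs into that interpreter's environment. This is a sketch; the helper name pip_install_command is hypothetical, and the install itself is left commented out.

```python
# Build (and optionally run) a pip command tied to a specific interpreter,
# the programmatic equivalent of "python3.9 -m pip install ...".
import shlex
import subprocess
import sys


def pip_install_command(requirement, python=sys.executable):
    """Return the argv list that installs `requirement` for `python`."""
    return [python, "-m", "pip", "install", requirement]


if __name__ == "__main__":
    cmd = pip_install_command("jaxlib==0.1.67")
    print("Would run:", shlex.join(cmd))
    # Uncomment to actually install into the interpreter running this script:
    # subprocess.run(cmd, check=True)
```

Using python -m pip rather than a bare pip avoids the case where the pip on PATH belongs to a different Python than the one that runs your program.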

Answered By: jkr

The answer by @jkr is the correct answer for your question as written (how to install jaxlib 0.1.67), but I don’t think it will fix the initial error you reported.

This looks like a Python 3.10-only bug that briefly existed in the JAX source code on October 5, 2021, but was fixed and never actually made it into a jax release. If you’re seeing this, I suspect it means you installed/imported JAX from unreleased source. Further, installing a different version of jaxlib will not fix this error, because the code is in jax itself. If you’re using jaxlib 0.1.75, you might try installing jax v0.2.7 or v0.2.8, which were released around the same time, and shouldn’t contain the problematic EllipsisType import.
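For context on why that import fails: the standard library defines EllipsisType in the types module (added in Python 3.10), not in typing, so from typing import EllipsisType raises ImportError on every Python version. The sketch below only illustrates a version-safe way to get the type of ...; it is not how JAX fixed its code, and is_ellipsis is a hypothetical helper.

```python
# EllipsisType lives in `types` (Python 3.10+), never in `typing`.
import sys

if sys.version_info >= (3, 10):
    from types import EllipsisType
else:
    # Before 3.10 there is no public name, but the type is still reachable.
    EllipsisType = type(Ellipsis)


def is_ellipsis(index):
    """Return True if an indexing argument is the `...` literal."""
    return isinstance(index, EllipsisType)
```

For example, is_ellipsis(...) is True while is_ellipsis(slice(None)) is False, which is the kind of check array-indexing code needs.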

Another potential issue: you reported using a Python 3.9.9 virtualenv, but your traceback indicates you’re executing Python 3.10, so you probably need to check your executable paths to make sure you’re executing what you think you are.
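A quick way to check is to ask the running interpreter directly. This sketch assumes that a venv or virtualenv 20+ makes sys.prefix differ from sys.base_prefix, and that older virtualenv releases set sys.real_prefix instead; interpreter_report is a hypothetical helper.

```python
# Print which interpreter is actually running and whether it is inside a
# virtual environment. Run it the same way you run your program.
import sys


def interpreter_report():
    """Return basic facts about the running Python interpreter."""
    return {
        "version": "{}.{}.{}".format(*sys.version_info[:3]),
        "executable": sys.executable,
        # venv / virtualenv>=20: sys.prefix differs from sys.base_prefix;
        # legacy virtualenv (<20): sys.real_prefix is set instead.
        "in_virtualenv": hasattr(sys, "real_prefix") or sys.prefix != sys.base_prefix,
    }


if __name__ == "__main__":
    for key, value in interpreter_report().items():
        print(f"{key}: {value}")
```

If this reports version 3.10.x or in_virtualenv: False, the program is not running inside the Python 3.9.9 virtualenv, which would be consistent with the /usr/lib/python3.10 and ~/.local/lib/python3.10 paths in the traceback.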

Answered By: jakevdp