ImportError: cannot import name 'Checkpoint' from 'ray.air'


I’m trying to follow this tutorial to tune hyperparameters in PyTorch using Ray. I copy-pasted everything, but I get the following error:

ImportError: cannot import name 'Checkpoint' from 'ray.air'

from this line of import:

from ray.air import Checkpoint

I installed ray using pip install -U "ray[tune]" as suggested on the official website. After getting the error, to be sure, I also tried a more general pip install ray, which did not fix anything.
I have version ray==2.9.0 installed.

Any help, please?

Asked By: Irene Ferfoglia



Try installing the older version 2.7.0:

pip install "ray[tune]==2.7.0"

Update:

In newer versions of Ray, the Ray AIR session has been replaced by a Ray Train context object, and Checkpoint now lives in ray.train.
You can import Checkpoint with:

from ray.train import Checkpoint
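If the same script has to run on both older Ray releases (such as the 2.7.0 pin above) and newer ones, one option is to try the new import path first and fall back to the old one. The sketch below is a small standard-library helper for that; `import_first` is a hypothetical name, not part of Ray.

```python
import importlib
from typing import Any, Tuple


def import_first(*candidates: Tuple[str, str]) -> Any:
    """Return the attribute from the first (module, name) pair that imports cleanly.

    Hypothetical helper, not part of Ray: it simply tries each location in order.
    """
    errors = []
    for module_name, attr in candidates:
        try:
            module = importlib.import_module(module_name)
            return getattr(module, attr)
        except (ImportError, AttributeError) as exc:
            # Remember why this location failed, then try the next candidate.
            errors.append(f"{module_name}.{attr}: {exc}")
    raise ImportError("no candidate location worked:\n" + "\n".join(errors))
```

Usage would then look like `Checkpoint = import_first(("ray.train", "Checkpoint"), ("ray.air", "Checkpoint"))`. Keep in mind that this only resolves the import: the newer `ray.train.Checkpoint` may not offer the same methods as the old AIR class, so copied tutorial code can still need changes beyond the import line.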

You need to adjust your code as follows:

from ray import air, train

# Ray Train methods and classes:
air.session.get_dataset_shard    -> train.get_dataset_shard
air.session.get_checkpoint       -> train.get_checkpoint
air.Checkpoint                   -> train.Checkpoint
air.Result                       -> train.Result

# Ray Train configurations:
air.config.CheckpointConfig      -> train.CheckpointConfig
air.config.FailureConfig         -> train.FailureConfig
air.config.RunConfig             -> train.RunConfig
air.config.ScalingConfig         -> train.ScalingConfig

# Ray TrainContext methods:
air.session.get_experiment_name  -> train.get_context().get_experiment_name
air.session.get_trial_name       -> train.get_context().get_trial_name
air.session.get_trial_id         -> train.get_context().get_trial_id
air.session.get_trial_resources  -> train.get_context().get_trial_resources
air.session.get_trial_dir        -> train.get_context().get_trial_dir
air.session.get_world_size       -> train.get_context().get_world_size
air.session.get_world_rank       -> train.get_context().get_world_rank
air.session.get_local_rank       -> train.get_context().get_local_rank
air.session.get_local_world_size -> train.get_context().get_local_world_size
air.session.get_node_rank        -> train.get_context().get_node_rank
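For a longer script, the renames listed above can be applied mechanically. The sketch below is a hypothetical helper (not a Ray tool) that rewrites the old names to the new ones with a regular expression, using exactly the mapping above. It assumes the code refers to these names through the `air` module (for example after `from ray import air`); the rewritten code then needs `from ray import train`, and import statements such as `from ray.air import Checkpoint` are not covered and must be updated separately.

```python
import re

# Old ray.air name -> new ray.train name, copied from the mapping above.
RENAMES = {
    "air.session.get_dataset_shard": "train.get_dataset_shard",
    "air.session.get_checkpoint": "train.get_checkpoint",
    "air.Checkpoint": "train.Checkpoint",
    "air.Result": "train.Result",
    "air.config.CheckpointConfig": "train.CheckpointConfig",
    "air.config.FailureConfig": "train.FailureConfig",
    "air.config.RunConfig": "train.RunConfig",
    "air.config.ScalingConfig": "train.ScalingConfig",
}

# The TrainContext methods all follow the same naming pattern.
for _name in ("experiment_name", "trial_name", "trial_id", "trial_resources",
              "trial_dir", "world_size", "world_rank", "local_rank",
              "local_world_size", "node_rank"):
    RENAMES[f"air.session.get_{_name}"] = f"train.get_context().get_{_name}"

# Longest names first in the alternation; the lookarounds stop matches inside
# longer identifiers (e.g. "mair.Result" or "air.ResultGrid" are left alone).
_PATTERN = re.compile(
    r"(?<![\w])("
    + "|".join(re.escape(k) for k in sorted(RENAMES, key=len, reverse=True))
    + r")(?![\w])"
)


def migrate_air_names(source: str) -> str:
    """Replace old ray.air names in `source` with their ray.train equivalents."""
    return _PATTERN.sub(lambda m: RENAMES[m.group(1)], source)
```

For instance, `migrate_air_names("ckpt = air.session.get_checkpoint()")` returns `"ckpt = train.get_checkpoint()"`. A plain text substitution like this cannot see aliases or other usage patterns, so the result should still be reviewed by hand.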

For more information, see:

Answered By: حمزة نبيل