PyTorch Lightning ModelCheckpoint: get paths of saved checkpoints

Question:

I am using PyTorch Lightning with, among other callbacks, a ModelCheckpoint that saves models with a formatted filename like filename="model_{epoch}-{val_acc:.2f}".

Later in the process I want to load these checkpoints again. For simplicity, let's say I only want the best N, kept via save_top_k=N.
Since the filenames are dynamic, I wonder how I can retrieve these checkpoints easily.
Is there a built-in attribute, or something on the trainer, that gives the paths of the saved checkpoints?
For example, something like

checkpoint_callback.get_top_k_paths()

I know I can do it with glob and model_dir, but I am wondering whether there is a built-in one-line solution somewhere.

Asked By: Daraan

||

Answers:

You can retrieve the path of the best model from the checkpoint callback after training:

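from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# retrieve the best checkpoint after training
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])
model = ...  # your LightningModule
trainer.fit(model)
checkpoint_callback.best_model_path  # path (str) of the best checkpoint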
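Besides best_model_path, ModelCheckpoint also has a best_k_models attribute: a dict mapping each kept checkpoint path to its monitored score, which is close to the get_top_k_paths() asked about in the question. The dict is not sorted by score, so a small helper can order it. This is a minimal sketch, assuming the callback was created with a monitor (e.g. monitor="val_acc", mode="max") and save_top_k; the helper name top_k_checkpoint_paths is hypothetical, not part of Lightning.

```python
def top_k_checkpoint_paths(checkpoint_callback, mode="max"):
    """Return the checkpoint paths tracked by a ModelCheckpoint, best first.

    Hypothetical helper. It reads ModelCheckpoint.best_k_models, a dict of
    {checkpoint path: monitored score}. `mode` should match the mode given
    to ModelCheckpoint ("max" for accuracy, "min" for a loss).
    """
    scores = checkpoint_callback.best_k_models
    # float() accepts plain floats as well as 0-dim torch tensors
    return sorted(scores, key=lambda path: float(scores[path]), reverse=(mode == "max"))
```

After trainer.fit(model), calling top_k_checkpoint_paths(checkpoint_callback, mode="max") gives the kept paths with the best one first, so the first entry should agree with best_model_path.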

To find all the checkpoints, you can list the files in the dirpath where the checkpoints are saved.
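A minimal glob-based sketch of that approach, assuming the filename format from the question together with the ModelCheckpoint default auto_insert_metric_name=True, which yields names like model_epoch=3-val_acc=0.87.ckpt. The helpers rank_checkpoints and top_k_from_dir are hypothetical. Parsing the score back out of the filename only works while that naming holds, and the score is only as precise as the two decimals written into the name.

```python
import re
from pathlib import Path


def rank_checkpoints(paths, metric="val_acc", k=3, mode="max"):
    """Return up to k checkpoint paths, best first, by a metric parsed from the filename."""
    # matches e.g. "val_acc=0.87" anywhere in the file stem
    pattern = re.compile(rf"{re.escape(metric)}=([-+]?\d+(?:\.\d+)?)")
    scored = []
    for path in map(Path, paths):
        match = pattern.search(path.stem)
        if match:  # files without the metric in their name (e.g. last.ckpt) are skipped
            scored.append((float(match.group(1)), path))
    scored.sort(key=lambda item: item[0], reverse=(mode == "max"))
    return [path for _, path in scored[:k]]


def top_k_from_dir(dirpath, metric="val_acc", k=3, mode="max"):
    """Glob the .ckpt files in dirpath and rank them with rank_checkpoints."""
    return rank_checkpoints(Path(dirpath).glob("*.ckpt"), metric, k, mode)
```

For example, top_k_from_dir("my/path/", k=3) returns the three highest-val_acc checkpoints in that directory as Path objects. Keeping the parsing in rank_checkpoints, separate from the filesystem access, makes it easy to reuse on any list of paths.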

Answered By: Aniket Maurya