Append model checkpoints to existing file in PyTorch

Question:

In PyTorch, it is possible to save model checkpoints as follows:

import torch

# Create a model
model = torch.nn.Sequential(
    torch.nn.Linear(1, 50),
    torch.nn.Tanh(),
    torch.nn.Linear(50, 1)
)

# ... some training here
# Save checkpoint
torch.save(model.state_dict(), 'checkpoint.pt')

During my training procedure, I save a checkpoint every 100 epochs or so. Currently this results in a folder with many files, e.g.

checkpoint0.pt
checkpoint100.pt
checkpoint200.pt

I was wondering if it was possible to append checkpoints to an existing file, so I don’t clutter my disk with small files but instead have only a single file called checkpoints.pt. I currently have implemented this as follows:

import torch

# Create a model
model = torch.nn.Sequential(
    torch.nn.Linear(1, 50),
    torch.nn.Tanh(),
    torch.nn.Linear(50, 1)
)

# ... some training here
# Save 1st checkpoint
data = {'0': model.state_dict()}
torch.save(data, 'checkpoints.pt')

# ... some training here
# Save 2nd checkpoint
data = torch.load('checkpoints.pt')
data['100'] = model.state_dict()
torch.save(data, 'checkpoints.pt')

print(torch.load('checkpoints.pt'))

The problem is that this requires loading the entire existing file into memory before appending a new checkpoint, which is memory-intensive, especially since I have hundreds of checkpoints. Is there a way to do this (or something similar) without having to load the existing checkpoints back into memory?

Asked By: Thomas Wagenaar


Answers:

See this post on multiple pickled objects in the same file. In short, PyTorch checkpointing is built on pickle, so if you use a thin pickle wrapper instead of the default torch.save, you can append each checkpoint to a single file:

import pickle  # in Python 3, pickle already uses the C implementation (_pickle) when available

def append_save(network, path):
    # "ab" opens the file in binary append mode, so existing checkpoints are left untouched
    with open(path, "ab") as f:
        pickle.dump(network.state_dict(), f)
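The question keyed each checkpoint by its epoch ('0', '100'), while append_save above stores only the bare state dict. One possible way to keep the label is to pickle an (epoch, state) pair per record. The following is a minimal sketch, not part of the original answer: append_checkpoint is a hypothetical helper, and a plain dict stands in for model.state_dict() so the snippet runs without PyTorch installed.

```python
import os
import pickle
import tempfile

def append_checkpoint(epoch, state, path):
    """Append one (epoch, state) record to the end of the file at `path`."""
    with open(path, "ab") as f:
        pickle.dump((epoch, state), f)

# A temporary file keeps the demo from appending to an existing checkpoint file.
path = os.path.join(tempfile.mkdtemp(), "checkpoints.pkl")

# Plain dicts stand in for model.state_dict() (assumption for illustration).
append_checkpoint(0, {"weight": [0.1, 0.2]}, path)
append_checkpoint(100, {"weight": [0.3, 0.4]}, path)

# Records come back in the order they were written, one per pickle.load call.
with open(path, "rb") as f:
    first = pickle.load(f)
    second = pickle.load(f)

print(first[0], second[0])  # prints: 0 100
```

Each call only opens the file in append mode and writes the new record, so nothing already stored has to be read back.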

To load the checkpoints again, you have to read the state dicts from the file one after another, in the order they were written.

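def read_checkpoints(path):
    """Return all state dicts stored in the file, in the order they were appended."""
    checkpoints = []
    with open(path, "rb") as f:
        while True:
            try:
                # Each load() consumes exactly one pickled object from the stream
                checkpoints.append(pickle.load(f))
            except EOFError:
                # No more objects left in the file
                break
    return checkpoints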
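Note that read_checkpoints collects every state dict into a list, so reading still holds all checkpoints in memory at once. If that is a concern, a generator can hand them out one at a time instead. This is a rough sketch under the same assumptions as the answer: iter_checkpoints is a hypothetical name, and small dicts stand in for real state dicts so the example runs without PyTorch.

```python
import os
import pickle
import tempfile

def iter_checkpoints(path):
    """Yield pickled records from `path` one at a time, in the order written."""
    with open(path, "rb") as f:
        while True:
            try:
                yield pickle.load(f)
            except EOFError:
                # End of file reached: stop the generator
                return

# Build a small demo file with three appended records (stand-ins for state dicts).
path = os.path.join(tempfile.mkdtemp(), "checkpoints.pkl")
for step in (0, 100, 200):
    with open(path, "ab") as f:
        pickle.dump({"step": step}, f)

# Walk the file, keeping a reference only to the most recent record.
last = None
for state in iter_checkpoints(path):
    last = state

print(last)  # prints: {'step': 200}
```

As long as the caller does not keep references to earlier records, only one checkpoint needs to be alive at a time. Reaching a late checkpoint still means unpickling the ones before it, since the records are stored back to back.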
      
Answered By: DerekG