How do I load a local model with torch.hub.load?
Question:
I need to avoid downloading the model from the web (due to restrictions on the machine installed).
This works, but it downloads the model from the Internet:
model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=True)
I have placed the .pth file and the hubconf.py file in the /tmp/ folder and changed my code to:
model = torch.hub.load('/tmp/', 'deeplabv3_resnet101', pretrained=True, source='local')
but to my surprise, it still downloads the model from the Internet. What am I doing wrong? How can I load the model locally?
Just to give you a bit more details, I’m doing all this in a Docker container that has a read-only volume at runtime, so that’s why the download of new files fails.
Answers:
There are two approaches you can take to get a shippable model on a machine without an Internet connection.
-
Load DeepLab with pretrained weights on a machine that has Internet access, export it as a TorchScript graph with torch.jit.trace, and copy the resulting file to the offline machine. The script is easy to follow:
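import torch
H, W = 520, 520  # example input size; use the resolution you will run inference at
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# To export (on a machine with Internet access)
model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=True).eval()
# DeepLabV3 returns a dict ('out' and 'aux'), so tracing needs strict=False
traced_graph = torch.jit.trace(model, torch.randn(1, 3, H, W), strict=False)
traced_graph.save('DeepLab.pth')
# To load (on the offline machine; only DeepLab.pth is needed)
model = torch.jit.load('DeepLab.pth', map_location=device).eval()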
In this case, the weights and the network structure are saved together as a TorchScript graph, so you won’t need hubconf.py or a separate weights file on the offline machine.
-
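Take a look at torchvision’s GitHub repository: the segmentation model code there contains a download URL for the DeepLabV3 weights with a ResNet-101 backbone.
You can download those weights once on a machine with Internet access, then build DeepLab directly from torchvision (loading it through torch.hub with a GitHub repo may itself need network access to fetch the hub code) and load the weights manually.
Note that in torchvision 0.9, pretrained=False alone still downloads the ImageNet weights for the backbone, so also pass pretrained_backbone=False. Pass aux_loss=True as well, because the pretrained checkpoint includes the auxiliary classifier’s weights:
import torch
import torchvision
model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=False, pretrained_backbone=False, aux_loss=True)
model.load_state_dict(torch.load('downloaded weights path', map_location='cpu'))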
Keep in mind that some checkpoints wrap the weights in a parent key such as 'state_dict', in which case you would use:
model.load_state_dict(torch.load('downloaded weights path')['state_dict'])
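For the TorchScript approach, here is a minimal sketch of the export/load round trip that runs without any download. It uses a hypothetical stand-in module, TinySegNet, instead of DeepLab; like DeepLab, it returns a dict with an 'out' key, which is why strict=False is passed to torch.jit.trace. The file name tiny.pt and the 64x64 input size are arbitrary choices for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinySegNet(nn.Module):
    """Hypothetical stand-in for DeepLabV3: returns a dict with an 'out' tensor."""

    def __init__(self, num_classes: int = 21):
        super().__init__()
        self.conv = nn.Conv2d(3, num_classes, kernel_size=3, padding=1)

    def forward(self, x):
        return {"out": self.conv(x)}


def export_and_reload(path: str, h: int = 64, w: int = 64) -> torch.jit.ScriptModule:
    """Trace the model, save the graph to `path`, and load it back."""
    model = TinySegNet().eval()
    example = torch.randn(1, 3, h, w)
    # strict=False lets the tracer record the dict output
    traced = torch.jit.trace(model, example, strict=False)
    traced.save(path)
    # Loading needs only torch and the saved file, not TinySegNet's source
    return torch.jit.load(path, map_location="cpu").eval()


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        loaded = export_and_reload(os.path.join(d, "tiny.pt"))
        with torch.no_grad():
            out = loaded(torch.randn(1, 3, 64, 64))["out"]
        print(out.shape)  # torch.Size([1, 21, 64, 64])
```

With the real model, the loaded graph is indexed the same way: model(batch)['out'] gives the segmentation logits.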
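For the state_dict approach, a small helper can cope with checkpoints that are or are not wrapped in a parent key. The helper name load_weights_offline and the candidate keys ('state_dict', 'model') are hypothetical choices, not part of PyTorch; the demo uses a tiny nn.Linear instead of DeepLab so it runs without downloading anything.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical wrapper keys; adjust to whatever your checkpoint actually uses
WRAPPER_KEYS = ("state_dict", "model")


def load_weights_offline(model: nn.Module, path: str) -> nn.Module:
    """Load weights from a local file, unwrapping a parent key if present."""
    checkpoint = torch.load(path, map_location="cpu")
    for key in WRAPPER_KEYS:
        if isinstance(checkpoint, dict) and key in checkpoint:
            checkpoint = checkpoint[key]
            break
    # strict=True (the default) raises on missing or unexpected keys
    model.load_state_dict(checkpoint)
    return model


if __name__ == "__main__":
    source = nn.Linear(4, 2)
    with tempfile.TemporaryDirectory() as d:
        plain = os.path.join(d, "plain.pth")
        wrapped = os.path.join(d, "wrapped.pth")
        torch.save(source.state_dict(), plain)
        torch.save({"state_dict": source.state_dict(), "epoch": 3}, wrapped)
        for p in (plain, wrapped):
            target = load_weights_offline(nn.Linear(4, 2), p)
            print(torch.equal(target.weight, source.weight))  # True
```

Keeping strict loading on is deliberate: if the architecture does not match the checkpoint (for example, a missing auxiliary classifier), you get an error instead of silently unloaded layers.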
import os
import torch
model_name = 'best.pt'
model = torch.hub.load(os.getcwd(), 'custom', source='local', path=model_name, force_reload=True)
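This worked for me. The default source is 'github', so you need source='local' to load from a local directory (here the current working directory, which must contain the repo’s hubconf.py).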
model = torch.hub.load('path/to/yolov5', 'custom', path='path/to/best.pt', source='local')  # local repo
Here 'path/to/yolov5' is the local directory that contains hubconf.py.
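To illustrate what source='local' expects, here is a minimal sketch that writes a throwaway directory containing a hypothetical hubconf.py with a 'custom' entrypoint, then loads it with torch.hub.load without any network access. This entrypoint just builds an nn.Linear and loads weights from path; YOLOv5’s real 'custom' entrypoint does much more, and the helper names make_local_repo and load_local are illustrative only.

```python
import os
import tempfile
import textwrap

import torch
import torch.nn as nn

# Contents of a hypothetical hubconf.py; YOLOv5's real one is far more elaborate
HUBCONF_SOURCE = textwrap.dedent(
    """
    dependencies = ["torch"]

    import torch
    import torch.nn as nn


    def custom(path):
        # Build the architecture and load weights from a local file only
        model = nn.Linear(4, 2)
        model.load_state_dict(torch.load(path, map_location="cpu"))
        return model.eval()
    """
)


def make_local_repo(repo_dir: str) -> None:
    """Write hubconf.py into repo_dir so torch.hub can import it."""
    with open(os.path.join(repo_dir, "hubconf.py"), "w") as f:
        f.write(HUBCONF_SOURCE)


def load_local(repo_dir: str, weights_path: str) -> nn.Module:
    # source='local' imports repo_dir/hubconf.py; extra kwargs go to the entrypoint
    return torch.hub.load(repo_dir, "custom", source="local", path=weights_path)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as repo:
        make_local_repo(repo)
        reference = nn.Linear(4, 2)
        weights = os.path.join(repo, "best.pth")
        torch.save(reference.state_dict(), weights)
        model = load_local(repo, weights)
        print(torch.equal(model.weight, reference.weight))  # True
```

This also shows why the code in the question still downloads: with source='local', torch.hub only runs the entrypoint from hubconf.py, so whether anything is fetched depends entirely on what that entrypoint does with pretrained=True.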