How to convert a custom PyTorch model to TorchScript (.pth to .pt model)?

Question:

I trained a custom model with PyTorch in the Colab environment and successfully saved it to Google Drive as model_final.pth. I want to convert model_final.pth to model_final.pt so that it can be used on mobile devices.

The code I use to train the model is as follows:

import os

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("mouse_train",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml") 
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025 
cfg.SOLVER.MAX_ITER = 1000   
cfg.SOLVER.STEPS = []        
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512   
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  
cfg.OUTPUT_DIR="drive/Detectron2/"

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg) 
trainer.resume_or_load(resume=False)
trainer.train()

The code I used to convert the model is as follows:

from detectron2.modeling import build_model
import torch
import torchvision

print("cfg.MODEL.WEIGHTS: ",cfg.MODEL.WEIGHTS)   ## RETURNS : cfg.MODEL.WEIGHTS:  drive/Detectron2/model_final.pth
model = build_model(cfg)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("drive/Detectron2/model-final.pt")

But I am getting the error IndexError: too many indices for tensor of dimension 3:

cfg.MODEL.WEIGHTS:  drive/Detectron2/model_final.pth
/usr/local/lib/python3.6/dist-packages/torch/tensor.py:593: RuntimeWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  'incorrect results).', category=RuntimeWarning)
---------------------------------------------------------------------------

IndexError                                Traceback (most recent call last)
<ipython-input-17-8e544c0f39c8> in <module>()
      7 model.eval()
      8 example = torch.rand(1, 3, 224, 224)
----> 9 traced_script_module = torch.jit.trace(model, example)
     10 traced_script_module.save("drive/Detectron2/model_final.pt")

7 frames
/usr/local/lib/python3.6/dist-packages/detectron2/modeling/meta_arch/rcnn.py in <listcomp>(.0)
    219         Normalize, pad and batch the input images.
    220         """
--> 221         images = [x["image"].to(self.device) for x in batched_inputs]
    222         images = [(x - self.pixel_mean) / self.pixel_std for x in images]
    223         images = ImageList.from_tensors(images, self.backbone.size_divisibility)

IndexError: too many indices for tensor of dimension 3
Asked By: Murat Öter


Answers:

Detectron2 models expect a dictionary, or a list of dictionaries, as input by default.

So you cannot use the torch.jit.trace function directly. However, Detectron2 provides a wrapper called TracingAdapter that lets a model take a tensor or a tuple of tensors as input; you can see how it is used in Detectron2's TorchScript tests.
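
To see why the plain trace fails: in eval mode a Detectron2 model is called with a list of dicts, each holding a (C, H, W) image tensor, rather than a bare batched tensor. A minimal sketch of a direct (non-traced) call, assuming the cfg from the question:

import torch
from detectron2.modeling import build_model

model = build_model(cfg)  # cfg as configured in the question
model.eval()

with torch.no_grad():
    # each element is a dict with an "image" tensor in (C, H, W) format
    outputs = model([{"image": torch.rand(3, 224, 224)}])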

The code for tracing your Mask R-CNN model could look like this (I have not tried it):

import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.export.flatten import TracingAdapter
from detectron2.modeling import build_model

def inference_func(model, image):
    inputs = [{"image": image}]
    return model.inference(inputs, do_postprocess=False)[0]

print("cfg.MODEL.WEIGHTS: ", cfg.MODEL.WEIGHTS)   ## RETURNS : cfg.MODEL.WEIGHTS:  drive/Detectron2/model_final.pth
model = build_model(cfg)
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)  # load the trained weights, otherwise the traced model is randomly initialized
example = torch.rand(3, 224, 224)  # Detectron2 expects a single (C, H, W) image, not a batched (1, C, H, W) tensor
wrapper = TracingAdapter(model, example, inference_func)
wrapper.eval()
traced_script_module = torch.jit.trace(wrapper, (example,))
traced_script_module.save("drive/Detectron2/model-final.pt")
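
Once saved, the .pt file can be loaded back with torch.jit.load and called on a tensor of the same shape used for tracing. Note that the traced wrapper returns flattened output tensors rather than Detectron2's usual Instances objects. A quick, untested sketch:

import torch

# load the saved TorchScript module (same path as above)
loaded = torch.jit.load("drive/Detectron2/model-final.pt")
loaded.eval()

with torch.no_grad():
    # same (C, H, W) shape that was used at trace time
    outputs = loaded(torch.rand(3, 224, 224))
print(outputs)  # a flat tuple of output tensors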

More information on Detectron2 deployment with tracing can be found in the Detectron2 deployment documentation.

Answered By: Rémi Chauvenne

This example may also help. It follows the same torch.jit.trace approach, but shows explicitly how the trained .pth weights are loaded first.

import torch
from unet import UNet  # custom model class defined in a local unet.py

model = UNet(3, 2)
model.load_state_dict(torch.load("best_weights.pth"))  # load the trained weights from the .pth checkpoint
model.eval()
example = torch.rand(1, 3, 320, 480)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")

Code from this site.
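
Since the goal is to run the converted model on mobile devices, the traced module can additionally be passed through PyTorch's mobile optimizer before saving. A short sketch, assuming the traced_script_module from the block above and a recent PyTorch version:

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

# optimize the traced module for the PyTorch Mobile / Lite interpreter
optimized = optimize_for_mobile(traced_script_module)
optimized._save_for_lite_interpreter("model.ptl")  # loadable with the Lite interpreter on Android/iOS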

Answered By: Alex Titov