How to construct a network with two inputs in PyTorch
Question:
Suppose I want to have the general neural network architecture:
```
Input1 --> CNNLayer
                    \
                     ---> FCLayer ---> Output
                    /
Input2 --> FCLayer
```
Input1 is image data and Input2 is non-image data. I have already implemented this architecture in TensorFlow.
All the PyTorch examples I have found pass a single input through each layer. How can I define the forward function so that it processes the two inputs separately and then combines them in a middle layer?
Answers:
By “combine them” I assume you mean concatenating the outputs of the two branches.
Assuming you concatenate along the second dimension (the feature dimension, right after the batch dimension):
```python
import torch
from torch import nn

class TwoInputsNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d( ... )  # set up your layer here
        self.fc1 = nn.Linear( ... )   # set up first FC layer
        self.fc2 = nn.Linear( ... )   # set up the other FC layer

    def forward(self, input1, input2):
        c = self.conv(input1)
        f = self.fc1(input2)
        # now we can reshape `c` and `f` to 2D and concat them
        combined = torch.cat((c.view(c.size(0), -1),
                              f.view(f.size(0), -1)), dim=1)
        out = self.fc2(combined)
        return out
```
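To make the skeleton above concrete, here is a runnable sketch with hypothetical sizes filled in: `input1` is assumed to be a batch of 3×32×32 images and `input2` a batch of 10-dimensional feature vectors. The channel counts, kernel size, stride, and the ReLU activations are illustrative choices, not part of the original answer. The key point is that `fc2` must accept the flattened conv output (8 × 16 × 16 = 2048) plus the output of `fc1` (32), i.e. 2080 features.

```python
import torch
from torch import nn


class TwoInputsNet(nn.Module):
    # Hypothetical sizes for illustration: input1 is a batch of 3x32x32 images,
    # input2 is a batch of 10-dimensional non-image feature vectors.
    def __init__(self):
        super().__init__()
        # Image branch: 3 -> 8 channels; stride 2 with padding 1 maps 32x32 to 16x16.
        self.conv = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
        # Non-image branch: 10 -> 32 features.
        self.fc1 = nn.Linear(10, 32)
        # Head: flattened conv output (8 * 16 * 16 = 2048) plus fc1 output (32).
        self.fc2 = nn.Linear(8 * 16 * 16 + 32, 1)
        self.relu = nn.ReLU()  # activation added for illustration

    def forward(self, input1, input2):
        c = self.relu(self.conv(input1))  # (N, 8, 16, 16)
        f = self.relu(self.fc1(input2))   # (N, 32)
        # Flatten c to (N, 2048) and concatenate with f along dim 1 -> (N, 2080).
        combined = torch.cat((torch.flatten(c, 1), f), dim=1)
        return self.fc2(combined)         # (N, 1)


net = TwoInputsNet()
images = torch.randn(4, 3, 32, 32)
features = torch.randn(4, 10)
out = net(images, features)
print(out.shape)  # torch.Size([4, 1])
```

The two inputs are simply passed as two arguments to `forward` (and to the module call), so each branch can be as deep as needed before the concatenation.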
Note that when you define the number of input features of `self.fc2`, you need to take into account both the `out_channels` of `self.conv` and the output spatial dimensions of `c`, plus the number of output features of `self.fc1`.
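Working out that number by hand is error-prone, so a common approach is to run a dummy tensor through the conv layer once and count the elements. The sketch below does that and, as a cross-check, also applies the output-shape formula from the `nn.Conv2d` documentation, `H_out = floor((H_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1` (and likewise for the width). The helper names and the layer/image sizes are hypothetical, and the formula helper assumes integer (not string) padding.

```python
import torch
from torch import nn


def flattened_conv_size(conv: nn.Conv2d, image_shape):
    """Flattened output size of `conv` for one (C, H, W) image, via a dummy pass."""
    with torch.no_grad():
        dummy = torch.zeros(1, *image_shape)  # batch of one all-zero image
        return conv(dummy).numel()            # out_channels * H_out * W_out


def flattened_conv_size_formula(conv: nn.Conv2d, height, width):
    """Same quantity from the Conv2d output-shape formula (integer padding only)."""
    sizes = []
    for i, n in enumerate((height, width)):
        k, s = conv.kernel_size[i], conv.stride[i]
        p, d = conv.padding[i], conv.dilation[i]
        sizes.append((n + 2 * p - d * (k - 1) - 1) // s + 1)
    return conv.out_channels * sizes[0] * sizes[1]


# Hypothetical layers: 3-channel 64x48 images and 10 non-image features.
conv = nn.Conv2d(3, 8, kernel_size=5, stride=2, padding=1)
fc1 = nn.Linear(10, 32)

n_conv = flattened_conv_size(conv, (3, 64, 48))
n_formula = flattened_conv_size_formula(conv, 64, 48)
# fc2 takes the flattened conv features plus fc1's output features.
fc2 = nn.Linear(n_conv + fc1.out_features, 1)
print(n_conv, n_formula, fc2.in_features)  # 5704 5704 5736
```

Recent PyTorch releases also offer `nn.LazyLinear`, which infers `in_features` on the first forward call; the dummy-pass approach above has the advantage of making the size explicit when the model is built.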