How to input a NumPy array to a neural network in PyTorch?
Question:
This is the neural network I defined:
import numpy
import torch
import torch.nn as nn

class generator(nn.Module):
    def __init__(self, n_dim, io_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(n_dim, 64),
            nn.LeakyReLU(.01),
            nn.Linear(64, io_dim),
        )

    def forward(self, x):
        return self.gen(x)

# The input x is:
x = numpy.random.dirichlet([10, 6, 3], 3)
Now I want the neural network to take Dirichlet-distributed samples (drawn with numpy.random.dirichlet([10,6,3],10)) as input. How do I do that?
Answers:
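You need to convert the numpy.array to a torch.Tensor:
input_tensor = torch.from_numpy(x).float()
torch.from_numpy keeps the array's float64 dtype, while nn.Linear layers use float32 weights by default, so without .float() the forward pass fails with a dtype-mismatch RuntimeError.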
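A minimal end-to-end sketch, assuming the generator class from the question and a batch of 10 samples as the question describes. The names input_tensor and output are illustrative. The array is converted with torch.from_numpy and cast to float32 before the forward pass, and each of the 10 rows produces one output row.

```python
import numpy as np
import torch
import torch.nn as nn


class generator(nn.Module):
    # Same architecture as in the question: n_dim -> 64 -> io_dim
    def __init__(self, n_dim, io_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(n_dim, 64),
            nn.LeakyReLU(.01),
            nn.Linear(64, io_dim),
        )

    def forward(self, x):
        return self.gen(x)


# 10 Dirichlet samples, each a 3-vector: shape (10, 3), dtype float64
x = np.random.dirichlet([10, 6, 3], 10)

# from_numpy keeps float64; .float() casts to float32 to match nn.Linear's weights
input_tensor = torch.from_numpy(x).float()

gen = generator(3, 3)
output = gen(input_tensor)  # one output row per input sample
print(output.shape)  # torch.Size([10, 3])
```

The batch dimension comes first, so the network processes all 10 samples in a single call.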
To input a NumPy array to a neural network in PyTorch, you need to convert the numpy.array to a torch.Tensor. Use torch.from_numpy, and cast the result to float32 so it matches the dtype of the network's weights:
input_tensor = torch.from_numpy(x).float()
After this, your numpy.array has been converted to a torch.Tensor that can be passed directly to the network.
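To make the conversion concrete, here is a small sketch of what torch.from_numpy returns (the variable names t64, t32 and t32_alt are illustrative). It keeps the NumPy dtype and shares memory with the array, whereas .float() or torch.as_tensor(..., dtype=torch.float32) produce a separate float32 copy.

```python
import numpy as np
import torch

x = np.random.dirichlet([10, 6, 3], 10)

t64 = torch.from_numpy(x)                          # shares memory with x, dtype float64
t32 = torch.from_numpy(x).float()                  # .float() makes a float32 copy
t32_alt = torch.as_tensor(x, dtype=torch.float32)  # equivalent one-step conversion

print(t64.dtype, t32.dtype, t32_alt.dtype)  # torch.float64 torch.float32 torch.float32

# Because t64 shares memory with x, writing to x is visible through t64,
# but not through the float32 copies
x[0, 0] = 0.0
print(t64[0, 0].item())  # 0.0
```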
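Instead of using NumPy to sample from a Dirichlet distribution, use PyTorch. Here is the code:
y = torch.Tensor([[10, 6, 3]])  # concentration parameters, shape (1, 3)
m = torch.distributions.dirichlet.Dirichlet(y)
z = m.sample()  # one float32 sample of shape (1, 3), so no dtype conversion is needed
gen = generator(3, 3)
gen(z)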
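To draw the 10 samples mentioned in the question directly in PyTorch, a sample shape can be passed to sample(). This sketch assumes the generator class from the question; concentration, z and out are illustrative names. The samples are already float32, and each row sums to 1.

```python
import torch
import torch.nn as nn


class generator(nn.Module):
    # Same architecture as in the question: n_dim -> 64 -> io_dim
    def __init__(self, n_dim, io_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(n_dim, 64),
            nn.LeakyReLU(.01),
            nn.Linear(64, io_dim),
        )

    def forward(self, x):
        return self.gen(x)


concentration = torch.tensor([10.0, 6.0, 3.0])
m = torch.distributions.Dirichlet(concentration)

z = m.sample((10,))  # 10 samples: shape (10, 3), dtype float32
gen = generator(3, 3)
out = gen(z)

print(z.shape, out.shape)  # torch.Size([10, 3]) torch.Size([10, 3])
```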