Slice a multidimensional pytorch tensor based on values in other tensors
Question:
I have 4 PyTorch tensors:
- data, of shape (l, m, n)
- a, of shape (k,) and dtype long
- b, of shape (k,) and dtype long
- c, of shape (k,) and dtype long
I want to slice the tensor data such that it picks the element addressed by a in the 0th dimension. In the 1st and 2nd dimensions, I want to pick a patch of values around the position addressed by b and c. Specifically, I want to pick 9 values: a 3x3 patch centered at (b, c). Thus my sliced tensor should have shape (k, 3, 3).
MWE:
data = torch.arange(200).reshape((2, 10, 10))
a = torch.Tensor([1, 0, 1, 1, 0]).long()
b = torch.Tensor([5, 6, 3, 4, 7]).long()
c = torch.Tensor([4, 3, 7, 6, 5]).long()
data1 = data[a, b-1:b+1, c-1:c+1] # gives error
>>> TypeError: only integer tensors of a single element can be converted to an index
Expected output
data1[0] = [[143,144,145],[153,154,155],[163,164,165]]
data1[1] = [[52,53,54],[62,63,64],[72,73,74]]
data1[2] = [[126,127,128],[136,137,138],[146,147,148]]
and so on
How can I do this without using a for loop?
PS:
- I've padded data to make sure that the locations addressed by a, b, c are within bounds.
- I don't need gradients to flow through this operation, so I can convert these to NumPy and slice there if that is faster. But I would prefer a solution in PyTorch.
Answers:
I would first repeat the indices and then add shifts to the repeated indices. Note that the shift patterns for the row and the column indices differ: the row shift repeats each offset 3 times, while the column shift repeats the whole offset pattern 3 times. For example,
import torch
data = torch.arange(200).reshape((2, 10, 10))
a = torch.Tensor([1, 0, 1, 1, 0]).long()
b = torch.Tensor([5, 6, 3, 4, 7]).long()
c = torch.Tensor([4, 3, 7, 6, 5]).long()
index1 = a.repeat_interleave(9)  # repeat each index kernel_size^2 = 9 times -> shape (45,)
index2 = b.repeat_interleave(9)
row_shift = torch.arange(-1, 2).repeat_interleave(3).repeat(5)  # [-1,-1,-1, 0,0,0, 1,1,1] repeated k=5 times
shifted_index2 = index2 + row_shift
index3 = c.repeat_interleave(9)
col_shift = torch.arange(-1, 2).repeat(3).repeat(5)  # [-1,0,1, -1,0,1, -1,0,1] repeated k=5 times
shifted_index3 = index3 + col_shift
# Use the index arrays to gather the 9 values per patch, then reshape to (k, 3, 3)
data1 = data[index1, shifted_index2, shifted_index3].view(5, 3, 3)
print(data1[0])
print(data1[1])
print(data1[2])
The output:
tensor([[143, 144, 145],
        [153, 154, 155],
        [163, 164, 165]])
tensor([[52, 53, 54],
        [62, 63, 64],
        [72, 73, 74]])
tensor([[126, 127, 128],
        [136, 137, 138],
        [146, 147, 148]])
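A variant of the same idea, not from the original answer: instead of flattening the indices with repeat_interleave, the row and column offsets can be broadcast directly against the (k,) index tensors, which keeps the (k, 3, 3) shape throughout and avoids the final view. This is a sketch of that approach:

```python
import torch

data = torch.arange(200).reshape((2, 10, 10))
a = torch.tensor([1, 0, 1, 1, 0])
b = torch.tensor([5, 6, 3, 4, 7])
c = torch.tensor([4, 3, 7, 6, 5])

# 3x3 window offsets, broadcast against the (k,) index tensors.
offs = torch.arange(-1, 2)
rows = b[:, None, None] + offs[None, :, None]   # shape (k, 3, 1)
cols = c[:, None, None] + offs[None, None, :]   # shape (k, 1, 3)

# Advanced indexing broadcasts the three index tensors to (k, 3, 3).
patches = data[a[:, None, None], rows, cols]
```

Since all three index tensors broadcast to a common (k, 3, 3) shape, the result directly has the desired shape with no reshaping.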
I was able to do it with slice objects, although there is a list comprehension at the end. However, it is a loop over only k elements.
import torch
a = torch.IntTensor([1, 0, 1, 1, 0]).long()
b = torch.IntTensor([5, 6, 3, 4, 7]).long()
c = torch.IntTensor([4, 3, 7, 6, 5]).long()
data = torch.arange(200).reshape((2, 10, 10))
a = list(slice(val, val+1) for val in a)
b = list(slice(val-1, val+2) for val in b)
c = list(slice(val-1, val+2) for val in c)
data1 = [data[a_, b_, c_] for a_, b_, c_ in zip(a,b,c)]
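Note that each data[a_, b_, c_] above yields a (1, 3, 3) tensor, so data1 is a Python list rather than the (k, 3, 3) tensor the question asks for. If a single tensor is needed, the patches can be concatenated along dim 0. A minimal sketch, assuming the same example tensors:

```python
import torch

data = torch.arange(200).reshape((2, 10, 10))
a = torch.tensor([1, 0, 1, 1, 0])
b = torch.tensor([5, 6, 3, 4, 7])
c = torch.tensor([4, 3, 7, 6, 5])

# Build one (slice, slice, slice) triple per patch, as in the answer above.
slices = [(slice(av, av + 1), slice(bv - 1, bv + 2), slice(cv - 1, cv + 2))
          for av, bv, cv in zip(a.tolist(), b.tolist(), c.tolist())]

# Each data[s] is (1, 3, 3); concatenating along dim 0 gives (k, 3, 3).
data1 = torch.cat([data[s] for s in slices], dim=0)
```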