Slice a multidimensional PyTorch tensor based on values in other tensors

Question:

I have 4 PyTorch tensors:

  • data of shape (l, m, n)
  • a of shape (k,) and datatype long
  • b of shape (k,) and datatype long
  • c of shape (k,) and datatype long

I want to slice the tensor data so that it picks the element addressed by a in the 0th dimension. In the 1st and 2nd dimensions, I want to pick a patch of values around the element addressed by b and c. Specifically, I want to pick 9 values: a 3x3 patch centered on the value at (b, c). Thus my sliced tensor should have shape (k, 3, 3).

MWE:

import torch

data = torch.arange(200).reshape((2, 10, 10))
a = torch.Tensor([1, 0, 1, 1, 0]).long()
b = torch.Tensor([5, 6, 3, 4, 7]).long()
c = torch.Tensor([4, 3, 7, 6, 5]).long()

data1 = data[a, b-1:b+2, c-1:c+2]  # gives error

>>> TypeError: only integer tensors of a single element can be converted to an index
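Python slice syntax start:stop needs scalar bounds, so the slice fails as soon as b and c hold more than one value. With one index per dimension the slice is valid. Since the stop bound is exclusive, b+2 (not b+1) is needed to get three rows and columns. A small sketch checking the first expected patch:

```python
import torch

data = torch.arange(200).reshape((2, 10, 10))
a = torch.tensor([1, 0, 1, 1, 0])
b = torch.tensor([5, 6, 3, 4, 7])
c = torch.tensor([4, 3, 7, 6, 5])

# Scalar bounds work: take the first entry of each index tensor as a Python int
i, r, s = a[0].item(), b[0].item(), c[0].item()
patch = data[i, r - 1:r + 2, s - 1:s + 2]  # stop is exclusive, so +2 gives 3 rows/cols
print(patch)
```

Doing this for every i is exactly the for loop the question wants to avoid.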

Expected output

data1[0] = [[143,144,145],[153,154,155],[163,164,165]]
data1[1] = [[52,53,54],[62,63,64],[72,73,74]]
data1[2] = [[126,127,128],[136,137,138],[146,147,148]]
and so on

How can I do this without using a for loop?

PS:

  • I’ve padded data to make sure that the 3x3 patches around the locations addressed by a, b, c stay within bounds.
  • I don’t need gradients to flow through this operation. So, I can convert these to NumPy and slice if that is faster. But I would prefer a solution in PyTorch.
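The padding mentioned above can be done with torch.nn.functional.pad. A minimal sketch, assuming a 3x3 patch (so a pad width of 1 on each side of the last two dimensions) and zero padding; the border indices in b and c below are hypothetical. Padding shifts every row and column position by the pad width, so b and c must be offset as well.

```python
import torch
import torch.nn.functional as F

data = torch.arange(200).reshape((2, 10, 10))
b = torch.tensor([0, 9])  # hypothetical row indices on the border of the unpadded data
c = torch.tensor([0, 9])  # hypothetical column indices on the border

pad = 1  # half of the 3x3 patch size
# The pad tuple gives (left, right) for the last dim, then (top, bottom) for dim 1
padded = F.pad(data, (pad, pad, pad, pad))  # zeros on the border

# The original position (r, s) now lives at (r + pad, s + pad)
b_padded = b + pad
c_padded = c + pad
print(padded.shape)  # torch.Size([2, 12, 12])
```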
Asked By: Nagabhushan S N


Answers:

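I would first expand the indices and then add shifts to the repeated indices. Note that the row and column shifts are built in opposite orders: the row shift repeats each offset three times (repeat_interleave), while the column shift tiles the whole offset pattern (repeat), so that together they enumerate the 3x3 patch in row-major order. For example,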

import torch

data = torch.arange(200).reshape((2, 10, 10))
a = torch.Tensor([1, 0, 1, 1, 0]).long()
b = torch.Tensor([5, 6, 3, 4, 7]).long()
c = torch.Tensor([4, 3, 7, 6, 5]).long()

k = a.numel()  # number of patches (5 here)

# Batch index: repeat each entry of a 9 times (kernel_size^2), one per patch element
index1 = a.repeat_interleave(9)

# Row shift: each of -1, 0, 1 repeated 3 times, then tiled k times (length 9 * k)
# One period: [-1, -1, -1,  0,  0,  0,  1,  1,  1]
index2 = b.repeat_interleave(9)
shift = torch.arange(-1, 2).repeat_interleave(3).repeat(k)
shifted_index2 = index2 + shift

# Column shift: -1, 0, 1 tiled 3 times, then tiled k times (length 9 * k)
# One period: [-1,  0,  1, -1,  0,  1, -1,  0,  1]
index3 = c.repeat_interleave(9)
shift = torch.arange(-1, 2).repeat(3).repeat(k)
shifted_index3 = index3 + shift

# Gather 9 * k elements with the index tensors, then reshape into k patches of 3x3
data1 = data[index1, shifted_index2, shifted_index3].view(k, 3, 3)

print(data1[0])
print(data1[1])
print(data1[2])

The output:

tensor([[143, 144, 145],
        [153, 154, 155],
        [163, 164, 165]])
tensor([[52, 53, 54],
        [62, 63, 64],
        [72, 73, 74]])
tensor([[126, 127, 128],
        [136, 137, 138],
        [146, 147, 148]])
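The same gather can also be written without repeat/repeat_interleave by letting advanced indexing broadcast the index tensors. A minimal sketch, where extract_patches and its size parameter are hypothetical names, assuming an odd patch size and that data is padded so every patch is in bounds:

```python
import torch

def extract_patches(data, a, b, c, size=3):
    """Hypothetical helper: out[i] = data[a[i], b[i]-h:b[i]+h+1, c[i]-h:c[i]+h+1]."""
    h = size // 2
    offsets = torch.arange(-h, h + 1, device=b.device)  # [-1, 0, 1] for size 3
    rows = b[:, None, None] + offsets[None, :, None]    # shape (k, size, 1)
    cols = c[:, None, None] + offsets[None, None, :]    # shape (k, 1, size)
    batch = a[:, None, None]                             # shape (k, 1, 1)
    # The three index tensors broadcast to (k, size, size)
    return data[batch, rows, cols]

data = torch.arange(200).reshape((2, 10, 10))
a = torch.tensor([1, 0, 1, 1, 0])
b = torch.tensor([5, 6, 3, 4, 7])
c = torch.tensor([4, 3, 7, 6, 5])

data1 = extract_patches(data, a, b, c)
print(data1.shape)  # torch.Size([5, 3, 3])
```

Broadcasting avoids materializing the 9 * k repeated index vectors by hand and yields the (k, 3, 3) shape directly, without a final view.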
Answered By: hpwww

I was able to do it with slices, although there is a list comprehension at the end. However, it only loops over k elements.

import torch

a = torch.tensor([1, 0, 1, 1, 0])  # int64 (long) by default
b = torch.tensor([5, 6, 3, 4, 7])
c = torch.tensor([4, 3, 7, 6, 5])
data = torch.arange(200).reshape((2, 10, 10))

# One Python slice per index; val:val+1 keeps the 0th dimension
a = [slice(val, val + 1) for val in a.tolist()]
b = [slice(val - 1, val + 2) for val in b.tolist()]
c = [slice(val - 1, val + 2) for val in c.tolist()]

# Each data[a_, b_, c_] has shape (1, 3, 3); concatenate into shape (k, 3, 3)
data1 = torch.cat([data[a_, b_, c_] for a_, b_, c_ in zip(a, b, c)])
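A loop-free alternative to the list comprehension is Tensor.unfold, which exposes every 3x3 window of the last two dimensions as a view; the window whose top-left corner is at (b-1, c-1) is the patch centered on (b, c). A minimal sketch, assuming the same data, a, b, c as above and in-bounds patches:

```python
import torch

data = torch.arange(200).reshape((2, 10, 10))
a = torch.tensor([1, 0, 1, 1, 0])
b = torch.tensor([5, 6, 3, 4, 7])
c = torch.tensor([4, 3, 7, 6, 5])

# unfold(dim, size, step): sliding windows of length 3 with stride 1 along dims 1 and 2
# windows[n, i, j] is the block data[n, i:i+3, j:j+3]; shape (2, 8, 8, 3, 3)
windows = data.unfold(1, 3, 1).unfold(2, 3, 1)

# The patch centered on (b, c) starts at (b - 1, c - 1)
data1 = windows[a, b - 1, c - 1]
print(data1.shape)  # torch.Size([5, 3, 3])
```

The unfold calls only create a view; the final advanced index copies just the k selected patches.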
Answered By: Brener Ramos