Slice 1D tensor in PyTorch with tensors of start and end indexes

Question:

I am trying to create a 2D tensor of even slices from a 1D tensor in PyTorch. Say we have a 1D data tensor and tensors of indexes as:

>>> data = torch.arange(10)
>>> data
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> starts = torch.tensor([0, 3, 4, 1])
>>> ends = starts + 2
>>> starts
tensor([0, 3, 4, 1])
>>> ends
tensor([2, 5, 6, 3])

How can I index the data tensor, without looping over each pair of indexes and slicing, to get a result like this:


>>> dataSlices
tensor([[0, 1],
        [3, 4],
        [4, 5],
        [1, 2]])

My first thought was to pass starts and ends the way you would with individual indexes, but that just raises an error:

>>> data[starts:ends]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: only integer tensors of a single element can be converted to an index

I’ve looked through parts of the documentation but can’t find a way. Am I missing something obvious?

Asked By: EternalTrail


Answers:

If these were lists, zip would solve your problem.

For tensors, it looks like you need torch.transpose(), combined with the solution from this answer by @bachr:
https://stackoverflow.com/a/60367265/3456886
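
To illustrate the zip remark for plain Python lists (not tensors), here is a minimal sketch. The name data_list is made up for this example, and the loop it uses is exactly what the question wants to avoid for tensors:

```python
# Plain-list version of the question's data (data_list is a hypothetical name)
data_list = list(range(10))
starts = [0, 3, 4, 1]
ends = [s + 2 for s in starts]

# zip pairs each start with its end; Python slices exclude the end index
slices = [data_list[s:e] for s, e in zip(starts, ends)]
print(slices)  # [[0, 1], [3, 4], [4, 5], [1, 2]]
```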

Answered By: user3456886

EDIT:

Since then I found a more Pythonic way to handle the ranges, using slice (it still relies on a list comprehension). Like Python ranges, a slice does not include its end index, so each end must be one past the last element you want, which is already the case for ends = starts + 2 in the question.

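# pair each start with its end: shape (4, 2)
indices = torch.stack((starts, ends), dim=1)
# slice data once per (start, end) pair; all slices must have the same length to be stacked
newtensor = torch.stack([data[slice(idx[0], idx[1])] for idx in indices])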
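
If you want to avoid the Python-level loop entirely, one possible alternative (not part of the original answer) is Tensor.unfold, which builds every sliding window of a fixed size as a view; indexing that view with starts then picks the wanted rows. This sketch assumes all slices share one length, called width here, which is a name introduced only for illustration:

```python
import torch

data = torch.arange(10)
starts = torch.tensor([0, 3, 4, 1])
width = 2  # assumption: every slice has this length (ends = starts + width)

# All sliding windows of length `width` with step 1:
# shape (len(data) - width + 1, width)
windows = data.unfold(0, width, 1)

# Row i of `windows` is data[i:i + width], so indexing by starts selects the slices
newtensor = windows[starts]
print(newtensor)
```

This requires every start to satisfy start + width <= len(data); otherwise the row index is out of range.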

OLD ANSWER:

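You can do this with torch.take. To get your desired output, subtract 1 from your ends, because torch.take takes exact indices, not intervals. (Alternatively, you can generate ends that way in the first place.) Note that this only picks the first and last index of each pair, so it returns complete slices only when each slice has exactly two elements.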

indices = torch.stack((starts, ends - 1), dim=1)
newtensor = torch.take(data, indices)

tensor([[0, 1],
        [3, 4],
        [4, 5],
        [1, 2]])
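
The torch.take call above uses only the first and last index of each slice, so it covers two-element slices only. A minimal sketch of a loop-free variant for any fixed slice length, using broadcasting to build the full index grid, could look like this. The names width and offsets are assumptions introduced for the example, and it presumes every slice has the same length:

```python
import torch

data = torch.arange(10)
starts = torch.tensor([0, 3, 4, 1])
width = 2  # assumption: fixed slice length, i.e. ends = starts + width

# Column of starts (4, 1) plus row of offsets (width,) broadcasts to (4, width)
offsets = torch.arange(width)
indices = starts.unsqueeze(1) + offsets

# Advanced indexing with a 2D index tensor returns a 2D result
newtensor = data[indices]
print(newtensor)
```

With width = 3 the same code returns three-element rows such as [0, 1, 2], as long as every start + width stays within the length of data.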

If you want inclusive intervals instead (since you named the indices starts and ends), this would be a solution for that:

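indices = torch.stack((starts, ends), dim=1)
# torch.range is deprecated; torch.arange excludes its end, so add 1 to keep the end inclusive
rangeindices = [torch.arange(int(i[0]), int(i[1]) + 1) for i in indices]
tensorindices = torch.stack(rangeindices)  # integer arguments give int64, so no cast is needed
newtensor = torch.take(data, tensorindices)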

tensor([[0, 1, 2],
        [3, 4, 5],
        [4, 5, 6],
        [1, 2, 3]])

But this would (understandably) result in a different tensor than your expected output.

Answered By: Franciska