Delete a row by index from a PyTorch tensor

Question:

I have a PyTorch tensor of size torch.Size([4, 3, 2]):

tensor([[[0.4003, 0.2742],
         [0.9414, 0.1222],
         [0.9624, 0.3063]],

        [[0.9600, 0.5381],
         [0.5758, 0.8458],
         [0.6342, 0.5872]],

        [[0.5891, 0.9453],
         [0.8859, 0.6552],
         [0.5120, 0.5384]],

        [[0.3017, 0.9407],
         [0.4887, 0.8097],
         [0.9454, 0.6027]]])

I would like to delete the 2nd row (index 1 along dimension 0) so that the tensor becomes torch.Size([3, 3, 2]):

tensor([[[0.4003, 0.2742],
         [0.9414, 0.1222],
         [0.9624, 0.3063]],

        [[0.5891, 0.9453],
         [0.8859, 0.6552],
         [0.5120, 0.5384]],

        [[0.3017, 0.9407],
         [0.4887, 0.8097],
         [0.9454, 0.6027]]])

How can I delete the nth row of the 3D tensor?

Asked By: Someone


Answers:

The operation below selects all but one "row":

import torch

torch.manual_seed(2021)

row = 2
x = torch.rand((4, 3, 2))

# Boolean mask over dim 0; row is 1-based here, so row = 2 drops x[1]
new_x = x[torch.arange(1, x.shape[0]+1) != row, ...]

print(new_x.shape)
# >>> torch.Size([3, 3, 2])

print(x)
# > tensor([[[0.1304, 0.5134],
# >          [0.7426, 0.7159],
# >          [0.5705, 0.1653]],
# > 
# >         [[0.0443, 0.9628],
# >          [0.2943, 0.0992],
# >          [0.8096, 0.0169]],
# > 
# >         [[0.8222, 0.1242],
# >          [0.7489, 0.3608],
# >          [0.5131, 0.2959]],
# > 
# >         [[0.7834, 0.7405],
# >          [0.8050, 0.3036],
# >          [0.9942, 0.5025]]])

print(new_x)
# > tensor([[[0.1304, 0.5134],
# >          [0.7426, 0.7159],
# >          [0.5705, 0.1653]],
# > 
# >         [[0.8222, 0.1242],
# >          [0.7489, 0.3608],
# >          [0.5131, 0.2959]],
# > 
# >         [[0.7834, 0.7405],
# >          [0.8050, 0.3036],
# >          [0.9942, 0.5025]]])
Answered By: Berriel
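The boolean-mask idea in the answer above can be generalized to any dimension. The following is an editorial sketch, not part of any answer here; the helper name delete_index is hypothetical, and it takes a 0-based index (so the question's "2nd row" is idx=1), unlike the 1-based row used above.

```python
import torch

def delete_index(x: torch.Tensor, idx: int, dim: int = 0) -> torch.Tensor:
    """Return x without position idx (0-based) along dimension dim (hypothetical helper)."""
    keep = torch.ones(x.size(dim), dtype=torch.bool, device=x.device)
    keep[idx] = False               # mark the single slice to drop
    index = [slice(None)] * x.dim() # ':' for every dimension...
    index[dim] = keep               # ...except dim, which gets the boolean mask
    return x[tuple(index)]

x = torch.arange(24.0).reshape(4, 3, 2)
print(delete_index(x, 1).shape)         # torch.Size([3, 3, 2])
print(delete_index(x, 0, dim=1).shape)  # torch.Size([4, 2, 2])
```

Boolean-mask indexing returns a new tensor (a copy), not a view of x.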
import torch
x = torch.randn(size=(4, 3, 2))

row_exclude = 2  # 0-based index: drops x[2], the third "row"
x = torch.cat((x[:row_exclude], x[row_exclude+1:]))  # join the parts before and after it

print(x.shape)
# >>> torch.Size([3, 3, 2])
Answered By: Vinson Ciawandy
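The slicing-and-concatenation approach above extends to an arbitrary dimension. This is a minimal editorial sketch assuming a recent PyTorch version, where torch.cat accepts a zero-length piece (for example x[:0]) as long as its other dimensions match; the helper name delete_index_cat is hypothetical.

```python
import torch

def delete_index_cat(x: torch.Tensor, idx: int, dim: int = 0) -> torch.Tensor:
    """Drop position idx along dim by concatenating the parts before and after it."""
    idx = idx % x.size(dim)            # allow negative indices such as -1
    before = [slice(None)] * x.dim()
    after = [slice(None)] * x.dim()
    before[dim] = slice(None, idx)     # everything up to (not including) idx
    after[dim] = slice(idx + 1, None)  # everything after idx
    return torch.cat((x[tuple(before)], x[tuple(after)]), dim=dim)

x = torch.randn(4, 3, 2)
print(delete_index_cat(x, 2).shape)         # torch.Size([3, 3, 2])
print(delete_index_cat(x, 0, dim=1).shape)  # torch.Size([4, 2, 2])
```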

For now, I have this slow method (it works for me because I call this function infrequently).

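import numpy as np
import torch

def delete_row_tensor(a, del_row, device):
    # Round trip through NumPy: copies to CPU and drops autograd history (detach)
    n = a.cpu().detach().numpy()
    n = np.delete(n, del_row, 0)  # remove index del_row along axis 0
    n = torch.from_numpy(n).to(device)
    return n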

I am still looking for efficient torch methods.

Answered By: Someone
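For comparison with the NumPy round trip above, here is a sketch of a pure-PyTorch counterpart. The name delete_row_tensor_torch is hypothetical; it assumes a 0-based, non-negative del_row and takes the device from the input instead of a separate argument. Because it never calls .cpu(), .detach() or .numpy(), the result stays on the same device and remains part of the autograd graph.

```python
import torch

def delete_row_tensor_torch(a: torch.Tensor, del_row: int) -> torch.Tensor:
    """Pure-PyTorch counterpart of delete_row_tensor (hypothetical name).

    del_row is a 0-based, non-negative index along dim 0.
    """
    keep = torch.arange(a.size(0), device=a.device) != del_row  # True = keep this row
    return a[keep]

a = torch.randn(4, 3, 2, requires_grad=True)
b = delete_row_tensor_torch(a, 1)
b.sum().backward()
print(b.shape)                # torch.Size([3, 3, 2])
print(a.grad[1].abs().sum())  # tensor(0.): the deleted row receives no gradient
```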
import torch
import numpy as np

x = torch.randn(size=(100, 200, 300))

index = np.arange(x.size(0))                     # all indices along dim 0

del_index = np.array([29, 31, 49])               # the rows you want to delete

new_index = np.delete(index, del_index, axis=0)  # indices that remain

new_x = x[new_index, :, :]                       # x without the deleted rows

# code as a function (works along any dim, not only 0/1/2 of a 3D tensor)
def delete_tensor_row(x, delete_row_list, dim):

    index = np.arange(x.size(dim))

    del_index = np.array(delete_row_list)

    # index is 1-D, so the deletion is always along axis 0 of index
    new_index = np.delete(index, del_index, axis=0)

    # select the remaining positions along dimension dim
    keep = torch.from_numpy(new_index).to(x.device)
    return torch.index_select(x, dim, keep)

Answered By: Looka
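The multi-row deletion above can also be done without NumPy. This is an editorial sketch assuming PyTorch 1.10 or later (where torch.isin is available); the helper name delete_tensor_rows is hypothetical and rows are 0-based, non-negative positions.

```python
import torch

def delete_tensor_rows(x: torch.Tensor, rows, dim: int = 0) -> torch.Tensor:
    """Drop several 0-based positions along dim using only PyTorch ops (hypothetical helper)."""
    positions = torch.arange(x.size(dim), device=x.device)
    drop = torch.as_tensor(rows, device=x.device)
    keep = positions[~torch.isin(positions, drop)]  # positions not listed in rows
    return x.index_select(dim, keep)

x = torch.randn(100, 4, 5)
print(delete_tensor_rows(x, [29, 31, 49]).shape)   # torch.Size([97, 4, 5])
print(delete_tensor_rows(x, [0, 2], dim=1).shape)  # torch.Size([100, 2, 5])
```

Everything stays on x's device, so there is no CPU round trip for CUDA tensors.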

@Vinson Ciawandy’s answer raises an error if the row to be excluded is at the start or the end. That’s why I wrote this:

import torch
x = torch.randn(size=(4,3,2))

row_exclude = 2
x_before_row = x[:row_exclude]
x_after_row = x[row_exclude+1:]
if x_before_row.numel() == 0:  # row was at the start
    x_without_row = x_after_row
elif x_after_row.numel() == 0:  # row was at the end
    x_without_row = x_before_row
else:
    x_without_row = torch.cat((x_before_row, x_after_row))

print(x_without_row.shape)
# >>> torch.Size([3, 3, 2])
Answered By: EricT
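Whether the branches above are needed may depend on the PyTorch version. The torch.cat documentation requires inputs to share a shape except in the concatenating dimension, and an empty slice such as x[:0] (shape (0, 3, 2)) satisfies that, so on recent releases the plain torch.cat version may already handle the first and last row. This page does not confirm the behaviour of older releases; a quick check you can run on your own install:

```python
import torch

x = torch.randn(4, 3, 2)

for row_exclude in range(x.size(0)):  # 0 and 3 are the first and last rows
    # Plain slicing + cat; at the edges one piece has shape (0, 3, 2)
    via_cat = torch.cat((x[:row_exclude], x[row_exclude + 1:]))
    # Reference result built with a boolean mask over dim 0
    keep = torch.arange(x.size(0)) != row_exclude
    assert torch.equal(via_cat, x[keep]), f"mismatch at row {row_exclude}"
    print(row_exclude, tuple(via_cat.shape))
```

If the loop completes without an exception, every row (including the first and last) can be removed with the unbranched torch.cat call on that installation.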