Delete a row by index from a PyTorch tensor
Question:
I have a PyTorch tensor of size torch.Size([4, 3, 2])
tensor([[[0.4003, 0.2742],
[0.9414, 0.1222],
[0.9624, 0.3063]],
[[0.9600, 0.5381],
[0.5758, 0.8458],
[0.6342, 0.5872]],
[[0.5891, 0.9453],
[0.8859, 0.6552],
[0.5120, 0.5384]],
[[0.3017, 0.9407],
[0.4887, 0.8097],
[0.9454, 0.6027]]])
I would like to delete the 2nd row so that the tensor becomes torch.Size([3, 3, 2])
tensor([[[0.4003, 0.2742],
[0.9414, 0.1222],
[0.9624, 0.3063]],
[[0.5891, 0.9453],
[0.8859, 0.6552],
[0.5120, 0.5384]],
[[0.3017, 0.9407],
[0.4887, 0.8097],
[0.9454, 0.6027]]])
How can I delete the nth row of the 3D tensor?
Answers:
The operation below selects all but one "row" (here `row` is 1-based, so `row = 2` removes the second slice along dim 0):
import torch

torch.manual_seed(2021)
row = 2  # 1-based position of the slice to drop along dim 0
x = torch.rand((4, 3, 2))

# Number the dim-0 slices 1..N and keep every one whose number differs from `row`.
positions = torch.arange(1, x.shape[0] + 1)
keep = positions != row
new_x = x[keep]

print(new_x.shape)
# >>> torch.Size([3, 3, 2])
print(x)
# > tensor([[[0.1304, 0.5134],
# > [0.7426, 0.7159],
# > [0.5705, 0.1653]],
# >
# > [[0.0443, 0.9628],
# > [0.2943, 0.0992],
# > [0.8096, 0.0169]],
# >
# > [[0.8222, 0.1242],
# > [0.7489, 0.3608],
# > [0.5131, 0.2959]],
# >
# > [[0.7834, 0.7405],
# > [0.8050, 0.3036],
# > [0.9942, 0.5025]]])
print(new_x)
# > tensor([[[0.1304, 0.5134],
# > [0.7426, 0.7159],
# > [0.5705, 0.1653]],
# >
# > [[0.8222, 0.1242],
# > [0.7489, 0.3608],
# > [0.5131, 0.2959]],
# >
# > [[0.7834, 0.7405],
# > [0.8050, 0.3036],
# > [0.9942, 0.5025]]])
import torch
x = torch.randn(size=(4,3,2))
row_exclude = 2  # 0-based index of the slice to drop along dim 0
# Concatenate everything before the excluded slice with everything after it.
x = torch.cat((x[:row_exclude], x[row_exclude+1:]))
print(x.shape)
# >>> torch.Size([3, 3, 2])
For now, I have this slow method (it works for me because I call this function infrequently).
def delete_row_tensor(a, del_row, device):
    """Return ``a`` without the row(s) ``del_row`` along dim 0, moved to ``device``.

    Pure-torch replacement for a NumPy round-trip: it avoids copying the data
    to the host and back, needs no NumPy import, and keeps the result in the
    autograd graph (going through ``.detach().numpy()`` silently dropped it).

    Args:
        a: Input tensor with at least one dimension.
        del_row: What to delete along dim 0: an int (negative counts from the
            end), a sequence of ints, or a slice.
        device: Target device for the returned tensor.

    Returns:
        A new tensor on ``device`` with the selected rows removed.
    """
    import torch

    # Boolean mask over dim 0: start with "keep everything", then clear the
    # rows to delete. Tensor indexing accepts an int, a list of ints, or a slice.
    keep = torch.ones(a.shape[0], dtype=torch.bool, device=a.device)
    keep[del_row] = False
    return a[keep].to(device)
I am still looking for efficient torch methods.
import torch
import numpy as np

x = torch.randn(size=(100, 200, 300))
# Positions along dim 0 that survive: all of 0..N-1 except the rows to delete.
keep = np.setdiff1d(np.arange(x.shape[0]), [29, 31, 49])
new_x = x[keep]
# code as a function
def delete_tensor_row(x, delete_raw_list, dim):
    """Return a copy of ``x`` with the given indices removed along ``dim``.

    Args:
        x: Input tensor of any rank.
        delete_raw_list: Index or sequence of indices along ``dim`` to drop
            (the name means "row list"). Negative indices count from the end.
            On CPU, an out-of-range index raises ``IndexError``.
        dim: Dimension to delete from. Negative values count from the last dim.

    Returns:
        A new tensor of the same rank as ``x``. Its size along ``dim`` shrinks by
        the number of distinct positions deleted.
    """
    import torch

    # Mark every position along `dim` as kept, then clear the ones to delete.
    keep = torch.ones(x.size(dim), dtype=torch.bool, device=x.device)
    # dtype=torch.long so an empty list still yields a valid integer index tensor.
    keep[torch.as_tensor(delete_raw_list, dtype=torch.long, device=x.device)] = False
    new_index = torch.nonzero(keep).flatten()
    # index_select handles any rank and any (including negative) dim. The
    # hard-coded x[new_index, :, :] / x[:, new_index, :] / ... cases only
    # covered 3-D tensors.
    return x.index_select(dim, new_index)
@Vinson Ciawandy’s answer raises an error if the row to be excluded is at the start or the end. That’s why I wrote this:
import torch
x = torch.randn(size=(4,3,2))
row_exclude = 2  # 0-based index of the slice to drop along dim 0
x_before_row = x[:row_exclude]
x_after_row = x[row_exclude+1:]
# Check the length of dim 0, not numel(). numel() is also 0 when some *other*
# dimension is empty (e.g. shape (4, 0, 2)), which would pick the wrong branch.
if x_before_row.shape[0] == 0:  # row was at the start
    x_without_row = x_after_row
elif x_after_row.shape[0] == 0:  # row was at the end
    x_without_row = x_before_row
else:
    x_without_row = torch.cat((x_before_row, x_after_row))
# Report the shape of the result (x itself still has 4 rows).
print(x_without_row.shape)
# >>> torch.Size([3, 3, 2])
I have a PyTorch tensor of size torch.Size([4, 3, 2])
tensor([[[0.4003, 0.2742],
[0.9414, 0.1222],
[0.9624, 0.3063]],
[[0.9600, 0.5381],
[0.5758, 0.8458],
[0.6342, 0.5872]],
[[0.5891, 0.9453],
[0.8859, 0.6552],
[0.5120, 0.5384]],
[[0.3017, 0.9407],
[0.4887, 0.8097],
[0.9454, 0.6027]]])
I would like to delete the 2nd row so that the tensor becomes torch.Size([3, 3, 2])
tensor([[[0.4003, 0.2742],
[0.9414, 0.1222],
[0.9624, 0.3063]],
[[0.5891, 0.9453],
[0.8859, 0.6552],
[0.5120, 0.5384]],
[[0.3017, 0.9407],
[0.4887, 0.8097],
[0.9454, 0.6027]]])
How can I delete the nth row of the 3D tensor?
The operation below selects all but one "row" (here `row` is 1-based, so `row = 2` removes the second slice along dim 0):
import torch

torch.manual_seed(2021)
row = 2  # 1-based position of the slice to drop along dim 0
x = torch.rand((4, 3, 2))

# Number the dim-0 slices 1..N and keep every one whose number differs from `row`.
positions = torch.arange(1, x.shape[0] + 1)
keep = positions != row
new_x = x[keep]

print(new_x.shape)
# >>> torch.Size([3, 3, 2])
print(x)
# > tensor([[[0.1304, 0.5134],
# > [0.7426, 0.7159],
# > [0.5705, 0.1653]],
# >
# > [[0.0443, 0.9628],
# > [0.2943, 0.0992],
# > [0.8096, 0.0169]],
# >
# > [[0.8222, 0.1242],
# > [0.7489, 0.3608],
# > [0.5131, 0.2959]],
# >
# > [[0.7834, 0.7405],
# > [0.8050, 0.3036],
# > [0.9942, 0.5025]]])
print(new_x)
# > tensor([[[0.1304, 0.5134],
# > [0.7426, 0.7159],
# > [0.5705, 0.1653]],
# >
# > [[0.8222, 0.1242],
# > [0.7489, 0.3608],
# > [0.5131, 0.2959]],
# >
# > [[0.7834, 0.7405],
# > [0.8050, 0.3036],
# > [0.9942, 0.5025]]])
import torch
x = torch.randn(size=(4,3,2))
row_exclude = 2  # 0-based index of the slice to drop along dim 0
# Concatenate everything before the excluded slice with everything after it.
x = torch.cat((x[:row_exclude], x[row_exclude+1:]))
print(x.shape)
# >>> torch.Size([3, 3, 2])
For now, I have this slow method (it works for me because I call this function infrequently).
def delete_row_tensor(a, del_row, device):
    """Return ``a`` without the row(s) ``del_row`` along dim 0, moved to ``device``.

    Pure-torch replacement for a NumPy round-trip: it avoids copying the data
    to the host and back, needs no NumPy import, and keeps the result in the
    autograd graph (going through ``.detach().numpy()`` silently dropped it).

    Args:
        a: Input tensor with at least one dimension.
        del_row: What to delete along dim 0: an int (negative counts from the
            end), a sequence of ints, or a slice.
        device: Target device for the returned tensor.

    Returns:
        A new tensor on ``device`` with the selected rows removed.
    """
    import torch

    # Boolean mask over dim 0: start with "keep everything", then clear the
    # rows to delete. Tensor indexing accepts an int, a list of ints, or a slice.
    keep = torch.ones(a.shape[0], dtype=torch.bool, device=a.device)
    keep[del_row] = False
    return a[keep].to(device)
I am still looking for efficient torch methods.
import torch
import numpy as np

x = torch.randn(size=(100, 200, 300))
# Positions along dim 0 that survive: all of 0..N-1 except the rows to delete.
keep = np.setdiff1d(np.arange(x.shape[0]), [29, 31, 49])
new_x = x[keep]
# code as a function
def delete_tensor_row(x, delete_raw_list, dim):
    """Return a copy of ``x`` with the given indices removed along ``dim``.

    Args:
        x: Input tensor of any rank.
        delete_raw_list: Index or sequence of indices along ``dim`` to drop
            (the name means "row list"). Negative indices count from the end.
            On CPU, an out-of-range index raises ``IndexError``.
        dim: Dimension to delete from. Negative values count from the last dim.

    Returns:
        A new tensor of the same rank as ``x``. Its size along ``dim`` shrinks by
        the number of distinct positions deleted.
    """
    import torch

    # Mark every position along `dim` as kept, then clear the ones to delete.
    keep = torch.ones(x.size(dim), dtype=torch.bool, device=x.device)
    # dtype=torch.long so an empty list still yields a valid integer index tensor.
    keep[torch.as_tensor(delete_raw_list, dtype=torch.long, device=x.device)] = False
    new_index = torch.nonzero(keep).flatten()
    # index_select handles any rank and any (including negative) dim. The
    # hard-coded x[new_index, :, :] / x[:, new_index, :] / ... cases only
    # covered 3-D tensors.
    return x.index_select(dim, new_index)
@Vinson Ciawandy’s answer raises an error if the row to be excluded is at the start or the end. That’s why I wrote this:
import torch
x = torch.randn(size=(4,3,2))
row_exclude = 2  # 0-based index of the slice to drop along dim 0
x_before_row = x[:row_exclude]
x_after_row = x[row_exclude+1:]
# Check the length of dim 0, not numel(). numel() is also 0 when some *other*
# dimension is empty (e.g. shape (4, 0, 2)), which would pick the wrong branch.
if x_before_row.shape[0] == 0:  # row was at the start
    x_without_row = x_after_row
elif x_after_row.shape[0] == 0:  # row was at the end
    x_without_row = x_before_row
else:
    x_without_row = torch.cat((x_before_row, x_after_row))
# Report the shape of the result (x itself still has 4 rows).
print(x_without_row.shape)
# >>> torch.Size([3, 3, 2])