How to apply a function element-wise to a 2D tensor
Question:
Very simple question, but I have been struggling with this forever now.
import torch
t = torch.tensor([[2,3],[4,6]])
overlap = [2, 6]
f = lambda x: x in overlap
I want:
torch.tensor([[True,False],[False,True]])
Both the tensor and overlap are very big, so efficiency matters here.
Answers:
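The native way to do this is the torch.Tensor.apply_ method. It works in place, overwriting t and storing the result of f (True/False) as 1/0 in t's integer dtype: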
t.apply_(f)
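Because apply_ overwrites t, a minimal non-destructive sketch runs it on a copy and converts the 0/1 result to bool. Turning overlap into a set (overlap_set is a name introduced here, not from the question) makes each membership check constant time instead of a scan of the list, which helps when overlap is very big. This still calls a Python function once per element, so it remains CPU-only and slow for large tensors.

```python
import torch

t = torch.tensor([[2, 3], [4, 6]])
overlap = [2, 6]
overlap_set = set(overlap)  # hypothetical helper: O(1) lookups instead of scanning the list

# apply_ is in-place and CPU-only, so run it on a clone to keep t intact;
# the callable's True/False is stored as 1/0 in the integer tensor, then cast to bool.
mask = t.clone().apply_(lambda x: x in overlap_set).bool()
print(mask)
```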
However, according to the official documentation, it only works with CPU tensors and should not be used in code sections that require high performance.
Besides, there seems to be no native torch function that tests whether the values of a tensor are in a list (this was true when the answer was written; PyTorch 1.10 added torch.isin), so the only option is to iterate over the list overlap. See here and here. Thus you can try:
sum(t==i for i in overlap).bool()
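The same loop can accumulate directly into a boolean mask instead of summing integer tensors. Below is a sketch with a hypothetical helper name, isin_loop. Each iteration is one vectorized comparison over the whole tensor, so the Python loop runs len(overlap) times rather than once per element, and it also works for tensors on GPU. Unlike the sum version, it also handles an empty overlap (where sum returns the plain integer 0, which has no .bool()).

```python
import torch

def isin_loop(t: torch.Tensor, values) -> torch.Tensor:
    """Return a bool mask that is True where t equals any value in `values`."""
    mask = torch.zeros_like(t, dtype=torch.bool)  # same shape and device as t
    for v in values:
        mask |= t == v  # one full-tensor comparison per value
    return mask

t = torch.tensor([[2, 3], [4, 6]])
print(isin_loop(t, [2, 6]))
```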
I found that the second approach is faster for large t and overlap, and the first one for small t and overlap.
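Since these answers were written, PyTorch (1.10 and later) added torch.isin, which performs this membership test in a single vectorized call on CPU or GPU. A minimal sketch, assuming a PyTorch version that includes it:

```python
import torch

t = torch.tensor([[2, 3], [4, 6]])
overlap = torch.tensor([2, 6])  # test elements as a tensor on the same device as t

mask = torch.isin(t, overlap)  # bool tensor with the same shape as t
print(mask)
```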
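I found an easy way. A CPU tensor can be converted to a NumPy array, so np.vectorize can be applied to it directly. Note that np.vectorize is meant for convenience rather than speed: according to the NumPy documentation, its implementation is essentially a for loop.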
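import torch
import numpy as np
t = torch.tensor([[2,3],[4,6]])
overlap = [2, 6]
f = lambda x: x in overlap
mask = np.vectorize(f)(t)  # NumPy bool array; only works for CPU tensors
mask = torch.from_numpy(mask)  # convert back to a torch bool tensor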
Found here.
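When both t and overlap are large, a different vectorized technique (not from the answers above) avoids per-element Python calls: sort the unique values of overlap once, binary-search every element of t with torch.searchsorted, and check whether the value at the found position matches. The helper name isin_sorted is hypothetical, and the sketch assumes t and values share a dtype. The cost is roughly O(n log m) for n elements and m unique values, without building an n×m intermediate.

```python
import torch

def isin_sorted(t: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    """Bool mask: True where an element of t appears in values (hypothetical helper)."""
    vals = torch.unique(values)  # sorted, duplicate-free 1-D tensor
    if vals.numel() == 0:
        return torch.zeros_like(t, dtype=torch.bool)
    # Position where each element of t would be inserted to keep vals sorted.
    idx = torch.searchsorted(vals, t)
    # Elements larger than every value get idx == len(vals); clamp to stay in range.
    idx = idx.clamp(max=vals.numel() - 1)
    # A hit means the value stored at that position equals the element itself.
    return vals[idx] == t

t = torch.tensor([[2, 3], [4, 6]])
print(isin_sorted(t, torch.tensor([6, 2, 2])))
```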