Constrain elements in a PyTorch tensor to be equal

Question:

I have a PyTorch tensor and would like to impose equality constraints on its elements while optimizing. An example 2 * 9 tensor is shown below, where elements sharing the same color should always be equal.

[Image: example tensor]

Let’s make a minimal 1 * 4 example and initialize the first two and the last two elements to be equal, respectively.

import torch
x1 = torch.tensor([1.2, 1.2, -0.3, -0.3], requires_grad=True)
print(x1)
# tensor([ 1.2000,  1.2000, -0.3000, -0.3000])

If I perform a simple least-squares step directly on x1, the equality no longer holds.

y = torch.arange(4)
opt_1 = torch.optim.SGD([x1], lr=0.1)
opt_1.zero_grad()
loss = (y - x1).pow(2).sum()
loss.backward()
opt_1.step()
print(x1)
# tensor([0.9600, 1.1600, 0.1600, 0.3600], requires_grad=True)

I tried to express this tensor as a weighted sum of masks.

def weighted_sum(c, masks):
    return torch.sum(torch.stack([c[0] * masks[0], c[1] * masks[1]]), dim=0)

c = torch.tensor([1.2, -0.3], requires_grad=True)
masks = torch.tensor([[1, 1, 0, 0], [0, 0, 1, 1]])
x2 = weighted_sum(c, masks)
print(x2)
# tensor([ 1.2000,  1.2000, -0.3000, -0.3000])

In this way, the equality remains after optimization.

opt_c = torch.optim.SGD([c], lr=0.1)
opt_c.zero_grad()
y = torch.arange(4)
x2 = weighted_sum(c, masks)
loss = (y - x2).pow(2).sum()
loss.backward()
opt_c.step()
print(c)
# tensor([0.9200, 0.8200], requires_grad=True)
print(weighted_sum(c, masks))
# tensor([0.9200, 0.9200, 0.8200, 0.8200], grad_fn=<SumBackward1>)

However, the biggest issue with this solution is that I have to maintain a large set of masks when the input dimension is high, which will surely run out of memory. Suppose the shape of the input tensor is d_0 * d_1 * ... * d_m and the number of equality blocks is k; then the masks form a huge tensor of shape k * d_0 * d_1 * ... * d_m, which is unacceptable.
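To get a rough sense of how big the mask becomes, here is a quick back-of-the-envelope calculation (the sizes below are made up purely for illustration):

k = 100                  # number of equality blocks (assumed)
dims = (64, 128, 128)    # d_0, d_1, d_2 (assumed)

n = k
for d in dims:
    n *= d
print(f"{n:,} mask elements ~= {n * 4 / 1e9:.2f} GB at float32")
# 104,857,600 mask elements ~= 0.42 GB at float32, growing linearly with k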


Another solution might be to upsample a low-resolution tensor, like this one. However, upsampling cannot be applied to irregular equality blocks, e.g.,

tensor([[ 1.2000,  1.2000,  1.2000, -3.1000, -3.1000],
        [-0.1000,  2.0000,  2.0000,  2.0000,  2.0000]])

So… is there a smarter way of implementing such equality constraints in a PyTorch tensor?

Asked By: Cheng


Answers:

If you want them to always be equal, why not just remove both the first and the last value from x and y? The extra values can be derived from the model output when needed after training, since they’re expected to be equal to their neighbors anyway; there’s no need to learn two copies of the same value.

If you only want to softly encourage them to be the same, you could add some_weight * (torch.abs(x[0]-x[1]) + torch.abs(x[-1] - x[-2])) to your loss function. The loss would then learn that these elements are expected to be the same.
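For instance, a rough sketch on the 1 * 4 example from the question (some_weight is an arbitrary value you would have to tune):

import torch

x1 = torch.tensor([1.2, 1.2, -0.3, -0.3], requires_grad=True)
y = torch.arange(4)
some_weight = 1.0
opt = torch.optim.SGD([x1], lr=0.1)

for _ in range(100):
    opt.zero_grad()
    loss = (y - x1).pow(2).sum() + some_weight * (
        torch.abs(x1[0] - x1[1]) + torch.abs(x1[-1] - x1[-2]))
    loss.backward()
    opt.step()

print(x1)
# the pairs (x1[0], x1[1]) and (x1[-2], x1[-1]) end up close together,
# but the equality is only encouraged, never enforced exactly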

Or, instead of masks, if you have counts for each value, maybe you’re looking for something like this?

def convert(tensor, counts):
    return torch.cat([v.repeat(count) for (v, count) in zip(tensor, counts)])

convert(torch.arange(4), [3, 2, 1, 3])
# tensor([0, 0, 0, 1, 1, 2, 3, 3, 3])
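Since repeat and cat are differentiable, you can also optimize the compact tensor directly with this convert; a rough sketch using the 1 * 4 example from the question:

import torch

c = torch.tensor([1.2, -0.3], requires_grad=True)   # one entry per equality block
counts = [2, 2]                                      # block sizes
y = torch.arange(4)

opt = torch.optim.SGD([c], lr=0.1)
opt.zero_grad()
x = convert(c, counts)            # expands to [1.2, 1.2, -0.3, -0.3]
loss = (y - x).pow(2).sum()
loss.backward()
opt.step()
print(convert(c, counts).detach())
# tensor([0.9200, 0.9200, 0.8200, 0.8200])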
Answered By: nairbv

Building on @nairbv’s repeat-based solution, I’ve also come up with another way to reduce the memory consumption of the masking approach (i.e., the first tentative solution stated in my question).

The basic idea is to use reduce so that torch.stack only ever stacks two tensors at a time, instead of materializing all of the weighted masks at once.

from functools import reduce
from operator import mul

def weighted_sum_v2(c, masks):
    # Accumulate pairwise: stack the running sum with the next weighted mask,
    # so only two tensors are stacked at any time instead of all k at once.
    stacksum = lambda s, cm: torch.sum(torch.stack([s, cm]), dim=0)
    return reduce(stacksum, map(mul, c, masks))

The results are also as expected.

c = torch.tensor([1.2, -0.3], requires_grad=True)
masks = torch.tensor([[1, 1, 0, 0], [0, 0, 1, 1]])
x2 = weighted_sum_v2(c, masks)
print(x2)
# tensor([ 1.2000,  1.2000, -0.3000, -0.3000])

opt_c = torch.optim.SGD([c], lr=0.1)
opt_c.zero_grad()
y = torch.arange(4)
x2 = weighted_sum_v2(c, masks)
loss = (y - x2).pow(2).sum()
loss.backward()
opt_c.step()
print(c)
# tensor([0.9200, 0.8200], requires_grad=True)
print(weighted_sum_v2(c, masks))
# tensor([0.9200, 0.9200, 0.8200, 0.8200], grad_fn=<SumBackward1>)
Answered By: Cheng