Polynomial features using numpy or torch

Question:

Given a tensor
[a, b]
I want to create a tensor of the form

[a, b, ab, a^2, b^2]

or even of higher order

[a, b, ab, a^2, b^2, (a^2)b, a(b^2), a^3, b^3]
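When sizing the output tensor, it helps to know how many columns to expect. With n input features, let N_d be the number of distinct monomials of exactly degree d; summing over degrees 1 to D gives the total width (a standard stars-and-bars count plus the hockey-stick identity):

$$N_d = \binom{n+d-1}{d}, \qquad \sum_{d=1}^{D} N_d = \binom{n+D}{D} - 1$$

For n = 2 this gives 3 second-order terms (ab, a^2, b^2), 4 third-order terms (a^3, a^2 b, a b^2, b^3), and 5 columns in total for D = 2. The d = 2 case, binom(n+1, 2) = binom(n, 2) + n, is exactly the width `math.comb(t.shape[1], 2) + t.shape[1]` used in the attempt posted in the edit.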

I want an efficient solution. I can solve it with loops, but that is not how I would like to do it. Dynamic programming is fine, though, so using the 2nd-order terms to compute the 3rd-order ones is acceptable.
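The dynamic-programming idea can be sketched in NumPy: every degree-d monomial is a degree-(d-1) monomial multiplied by one more feature, and only multiplying by features whose index is at least the last one used keeps each monomial unique (ab is produced, ba is not). This is a minimal sketch under my own assumptions: `poly_features_dp` is a hypothetical name, the input is a 2-D array of shape (n_samples, n_features), and degree >= 1.

```python
import numpy as np


def poly_features_dp(x: np.ndarray, degree: int) -> np.ndarray:
    """Hypothetical helper: all monomials of degree 1..degree, built level by level."""
    if degree < 1:
        raise ValueError("degree must be >= 1")
    n_features = x.shape[1]
    # Current level: (index of the last feature multiplied in, column of values).
    level = [(j, x[:, j]) for j in range(n_features)]
    blocks = [x]  # degree-1 terms are the inputs themselves
    for _ in range(degree - 1):
        next_level = []
        for last, col in level:
            # Reuse the previous level's column; j >= last avoids duplicates such as ab vs ba.
            for j in range(last, n_features):
                next_level.append((j, col * x[:, j]))
        blocks.append(np.stack([c for _, c in next_level], axis=1))
        level = next_level
    return np.hstack(blocks)
```

Each new column costs a single element-wise multiplication of an already computed column; the Python loop runs over monomials, not over rows. The resulting column order is the same as `itertools.combinations_with_replacement` (a^2, ab, b^2 for two features, then a^3, a^2 b, a b^2, b^3).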

The final solution will be implemented in PyTorch, but a NumPy implementation would also be useful; I can port it to PyTorch myself.
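One loop-free NumPy sketch (my own suggestion, not taken from the answers): precompute the index tuples for each degree with `itertools.combinations_with_replacement`, gather the corresponding columns with fancy indexing, and take the product along the last axis. The function name `poly_features` is hypothetical, and the input is assumed to be 2-D with shape (n_samples, n_features).

```python
from itertools import combinations_with_replacement

import numpy as np


def poly_features(x: np.ndarray, degree: int = 2) -> np.ndarray:
    """Hypothetical helper: columns for every monomial of degree 1..degree."""
    n_features = x.shape[1]
    blocks = []
    for d in range(1, degree + 1):
        # (n_monomials, d) array of feature indices, e.g. [[0, 0], [0, 1], [1, 1]] for d=2, n=2.
        idx = np.array(list(combinations_with_replacement(range(n_features), d)))
        # x[:, idx] has shape (n_samples, n_monomials, d); multiply along the last axis.
        blocks.append(x[:, idx].prod(axis=-1))
    return np.hstack(blocks)
```

The Python loop runs once per degree, not per sample or per monomial. The same indexing should carry over to PyTorch by converting `idx` with `torch.as_tensor` and calling `.prod(dim=-1)`. If scikit-learn is available, `sklearn.preprocessing.PolynomialFeatures(degree, include_bias=False)` produces the same column order, since it also enumerates combinations with replacement.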

Edit:

As requested, here is my attempt, which I’m not very proud of:

import math

import torch


def polynomial(t: torch.Tensor) -> torch.Tensor:
    # Output width: C(n, 2) + n pairwise products x_i * x_j with i <= j.
    n = t.shape[1]
    r_c = torch.empty((t.shape[0], math.comb(n, 2) + n))
    i = 0
    for idx in range(n):
        for jdx in range(idx, n):
            # Product of columns idx and jdx (a square when idx == jdx).
            r_c[:, i] = t[:, idx] * t[:, jdx]
            i += 1
    # Original columns first, then the second-order terms.
    return torch.hstack([t, r_c])

For

t = torch.tensor([
        [1, 2, 3],
        [3, 4, 5],
        [5, 6, 7]
    ])
polynomial(t)

results in

tensor([[ 1.,  2.,  3.,  1.,  2.,  3.,  4.,  6.,  9.],
        [ 3.,  4.,  5.,  9., 12., 15., 16., 20., 25.],
        [ 5.,  6.,  7., 25., 30., 35., 36., 42., 49.]])
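For the degree-2 case specifically, the double loop in this attempt can be replaced by `np.triu_indices`, which returns the (i, j) pairs with i <= j in the same row-major order as the nested loops. A minimal NumPy sketch; `quadratic_features` is a hypothetical name:

```python
import numpy as np


def quadratic_features(x: np.ndarray) -> np.ndarray:
    """Hypothetical helper: original columns followed by all products x_i * x_j with i <= j."""
    # Upper-triangle (diagonal included) index pairs: (0,0), (0,1), ..., (n-1,n-1).
    i, j = np.triu_indices(x.shape[1])
    return np.hstack([x, x[:, i] * x[:, j]])
```

In PyTorch, `torch.triu_indices(n, n)` returns the same pairs as a 2 x N tensor, so `i, j = torch.triu_indices(n, n)` followed by `t[:, i] * t[:, j]` should be a direct port.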
Asked By: Jacek Karolczak


Answers:

For anyone who runs into the same problem:

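from itertools import combinations, combinations_with_replacement

import torch


def polynomial(t: torch.Tensor, degree: int = 2, interaction_only: bool = False) -> torch.Tensor:
    # Degree 1 is just the input features (this also ends the recursion).
    if degree <= 1:
        return t
    # Split the (n_samples, n_features) tensor into n_features columns of shape (n_samples, 1).
    cols = t.hsplit(t.shape[1])
    # interaction_only: products of distinct features only (ab, but not a^2 or b^2).
    pick = combinations if interaction_only else combinations_with_replacement
    # One column per monomial of exactly `degree`.
    prods = [torch.prod(torch.hstack(comb), -1, keepdim=True) for comb in pick(cols, degree)]
    # Prepend all lower-degree terms, computed recursively.
    return torch.hstack((polynomial(t, degree - 1, interaction_only), *prods))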
Answered By: Jacek Karolczak