How does the code '-input[range(target.shape[0]), target]' work?

Question:

I'm learning PyTorch. While reading the official tutorial, I came across some perplexing code.
Both input and target are tensors.

def nll(input, target):
    return -input[range(target.shape[0]), target].mean()
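A minimal sketch of what nll computes. It assumes that input holds log-probabilities (for example from log_softmax) with shape (batch, classes) and that target holds one integer class index per sample. The random data and shapes are illustrative assumptions, not the tutorial's values. The result is compared with torch.nn.functional.nll_loss, which, with its default mean reduction and no class weights, computes the same negative log-likelihood.

```python
import torch
import torch.nn.functional as F

def nll(input, target):
    # For each row i, take input[i, target[i]], average over the batch, then negate.
    return -input[range(target.shape[0]), target].mean()

torch.manual_seed(0)
logits = torch.randn(4, 3)                # assumed: 4 samples, 3 classes
log_probs = F.log_softmax(logits, dim=1)  # log-probabilities for each class
target = torch.tensor([0, 2, 1, 2])       # assumed: one class index per sample

loss = nll(log_probs, target)
reference = F.nll_loss(log_probs, target)  # built-in negative log-likelihood loss
print(loss, reference)
```

Both values should agree up to floating-point rounding.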

And pred is:
[image: printed value of pred]

target is:
[image: printed value of target]

And -input[range(target.shape[0]), target] is:
[image: printed result of input[range(target.shape[0]), target]]

The output shows that this is neither subtracting target from input nor merging the two tensors.

Asked By: Yan W


Answers:

The code input[range(target.shape[0]), target] simply picks, from each row i of input, the element in the column indicated by the corresponding element of target, that is, target[i].
In other words, if out = input[range(target.shape[0]), target] then out[i] = input[i, target[i]].
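A small worked example of this rule; the 3x3 values are made up for illustration. It checks the indexing against an explicit loop and contrasts it with input[:, target], which is a different operation: it selects whole columns for every row instead of one element per row.

```python
import torch

input = torch.tensor([[10., 11., 12.],
                      [20., 21., 22.],
                      [30., 31., 32.]])
target = torch.tensor([2, 0, 1])

# Row indices 0, 1, 2 are paired element-wise with column indices target[0], target[1], target[2].
out = input[range(target.shape[0]), target]
print(out)  # tensor([12., 20., 31.])

# The same selection as an explicit loop: out[i] = input[i, target[i]].
loop = torch.stack([input[i, target[i]] for i in range(target.shape[0])])

# By contrast, a slice in the first position keeps every row and picks the
# listed columns from each one, giving a 3x3 result.
cols = input[:, target]
print(cols.shape)  # torch.Size([3, 3])
```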

This is very similar to torch.gather.
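A sketch of the torch.gather equivalent, using random data with assumed shapes. torch.gather requires the index tensor to have the same number of dimensions as input, so target is reshaped from (N,) to (N, 1), and the (N, 1) result is squeezed back to (N,).

```python
import torch

torch.manual_seed(0)
input = torch.randn(5, 4)             # assumed shape: (batch, classes)
target = torch.randint(0, 4, (5,))    # one column index per row (int64)

# Advanced indexing, as in the question.
picked = input[range(target.shape[0]), target]

# gather along dim=1: gathered[i] = input[i, target[i]].
gathered = torch.gather(input, 1, target.unsqueeze(1)).squeeze(1)

print(torch.equal(picked, gathered))  # True
```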

Answered By: Shai