I have a source tensor of size (3, 2) and an index tensor of size (3, 3) containing integer values 0, 1, or 2. In PyTorch I can use tensor indexing, source[index], to get a tensor of size (3, 3, 2), where each entry of index selects a row of source. Example:
source:
tensor([[1, 6],
        [2, 3],
        [8, 0]])
index:
tensor([[2, 1, 2],
        [1, 1, 2],
        [2, 0, 0]])
source[index]:
tensor([[[8, 0],
         [2, 3],
         [8, 0]],

        [[2, 3],
         [2, 3],
         [8, 0]],

        [[8, 0],
         [1, 6],
         [1, 6]]])
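A minimal sketch that reproduces the unbatched example above, using the same values for source and index. Indexing a tensor with an integer tensor picks whole rows, so out[i, j] = source[index[i, j]] and the index shape (3, 3) is followed by the row width 2.

```python
import torch

# Tensors copied from the example above.
source = torch.tensor([[1, 6],
                       [2, 3],
                       [8, 0]])
index = torch.tensor([[2, 1, 2],
                      [1, 1, 2],
                      [2, 0, 0]])

# Integer-tensor indexing on dim 0: out[i, j] = source[index[i, j]].
out = source[index]
print(out.shape)  # torch.Size([3, 3, 2])
```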
I want to do the same operation, but batched.
For example, with a batch size of 2:
source shape -> (2, 3, 2)
index shape -> (2, 3, 3)
batched source[index] shape -> (2, 3, 3, 2), where batched[b] = source[b][index[b]] for each b
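To make the batched shapes concrete, here is a sketch of the straightforward per-sample loop. The random values are illustrative only (assumed, not from the question); only the shapes follow the example.

```python
import torch

torch.manual_seed(0)

B, N, D = 2, 3, 2  # batch size, rows per source sample, row width
# Illustrative random inputs with the shapes from the question.
source = torch.randint(0, 10, (B, N, D))   # (2, 3, 2)
index = torch.randint(0, N, (B, 3, 3))     # (2, 3, 3), values in {0, 1, 2}

# Index each sample on its own, then stack along a new batch dimension.
batched = torch.stack([source[b][index[b]] for b in range(B)])
print(batched.shape)  # torch.Size([2, 3, 3, 2])
```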
I can easily do this with a loop, but can it be done efficiently with torch.gather or some other built-in function?
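A sketch of two loop-free alternatives, checked against a loop reference. The helper names (batched_index_loop, batched_index_advanced, batched_index_gather) are hypothetical, chosen for illustration. The first relies on advanced indexing: a batch index of shape (B, 1, 1) broadcasts against index of shape (B, I, J). The second uses torch.gather, which needs an index with the same number of dimensions as source, so the index is flattened and expanded across the last dimension. Which is faster likely depends on tensor sizes and device; this sketch does not measure it.

```python
import torch

def batched_index_loop(source, index):
    # Reference: index each sample separately, then stack.
    return torch.stack([s[i] for s, i in zip(source, index)])

def batched_index_advanced(source, index):
    # batch has shape (B, 1, 1) and broadcasts with index (B, I, J),
    # so out[b, i, j] = source[b, index[b, i, j]] -> shape (B, I, J, D).
    batch = torch.arange(source.shape[0], device=source.device)[:, None, None]
    return source[batch, index]

def batched_index_gather(source, index):
    # Flatten (B, I, J) -> (B, I*J, 1) and repeat across D, since gather
    # requires index to have as many dims as source.
    B, _, D = source.shape
    flat = index.reshape(B, -1, 1).expand(-1, -1, D)  # (B, I*J, D)
    # gather on dim 1: out[b, k, d] = source[b, flat[b, k, d], d]
    out = torch.gather(source, 1, flat)
    return out.reshape(*index.shape, D)               # (B, I, J, D)

torch.manual_seed(0)
source = torch.randint(0, 10, (2, 3, 2))  # (B, N, D)
index = torch.randint(0, 3, (2, 3, 3))    # (B, I, J), values in {0, 1, 2}

ref = batched_index_loop(source, index)
print(ref.shape)  # torch.Size([2, 3, 3, 2])
```

Usage: both vectorized helpers should return the same tensor as the loop, e.g. torch.equal(batched_index_gather(source, index), ref). The advanced-indexing version is shorter; the gather version makes the per-dimension indexing rule explicit.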