In NumPy/PyTorch, given a vector v and another vector of indices IX, we can reindex v with:
v[IX]
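For reference, a tiny NumPy illustration of the 1-D case (the values are made up for illustration): indices may repeat and appear in any order, and the result has the shape of IX.

```python
import numpy as np

v = np.array([10, 20, 30, 40])
IX = np.array([3, 0, 0, 2])

# Integer-array ("fancy") indexing: result[j] = v[IX[j]]
result = v[IX]
print(result)  # [40 10 10 30]
```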
How can I do the same when I have a batch of vectors v and a batch of index vectors?
That is, v is a 2-D array whose i-th row v[i, :] is the i-th vector, and that row should be reindexed by IX[i, :].
The slow Python way is simply:
new_v = np.empty(IX.shape, dtype=v.dtype)
for i in range(v.shape[0]):
    new_v[i, :] = v[i, :][IX[i, :]]
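To have a concrete reference result that any vectorized version can be checked against, here is the loop run on small example data (the arrays below are made up for illustration). Element-wise, it computes new_v[i, j] = v[i, IX[i, j]].

```python
import numpy as np

v = np.array([[10, 20, 30],
              [40, 50, 60]])
IX = np.array([[2, 0, 1],
               [1, 1, 0]])

# Preallocate the output: one row per batch item, one column per index
new_v = np.empty(IX.shape, dtype=v.dtype)
for i in range(v.shape[0]):
    # Reindex row i of v with row i of IX
    new_v[i, :] = v[i, :][IX[i, :]]

print(new_v.tolist())  # [[30, 10, 20], [50, 50, 40]]
```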
But the question is how to do it the NumPy/PyTorch way, without slow Python loops.
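One idea that comes to mind is to flatten v and shift each row's indices by that row's starting offset in the flat array, i * v.shape[1]:
v.ravel()[(IX + np.arange(v.shape[0])[:, None] * v.shape[1]).ravel()].reshape(v.shape[0], -1)
but maybe there is a more canonical/readable way?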
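A minimal NumPy sketch comparing the flat-index idea with two loop-free alternatives: broadcast advanced indexing (a column of row numbers paired with IX) and `np.take_along_axis`, which pairs each row of v with the same row of IX. The function names and example arrays are hypothetical, chosen only for illustration; all three should agree with the loop.

```python
import numpy as np

def reindex_flat(v, IX):
    """Flat-index idea: shift row i's indices by i * row_length, then index the raveled array."""
    N, M = v.shape
    offsets = np.arange(N)[:, None] * M          # shape (N, 1): 0, M, 2M, ...
    return v.ravel()[(IX + offsets).ravel()].reshape(N, -1)

def reindex_fancy(v, IX):
    """Advanced indexing: the (N, 1) row index broadcasts against IX of shape (N, K)."""
    rows = np.arange(v.shape[0])[:, None]
    return v[rows, IX]

def reindex_take(v, IX):
    """np.take_along_axis: result[i, j] = v[i, IX[i, j]] along axis 1."""
    return np.take_along_axis(v, IX, axis=1)

v = np.array([[10, 20, 30],
              [40, 50, 60]])
IX = np.array([[2, 0, 1],
               [1, 1, 0]])

for f in (reindex_flat, reindex_fancy, reindex_take):
    print(f.__name__, f(v, IX).tolist())  # each: [[30, 10, 20], [50, 50, 40]]
```

Of the three, `np.take_along_axis` arguably states the intent most directly. In PyTorch the analogous call is `torch.gather(v, 1, IX)`, where IX must be an int64 (`long`) tensor with the same number of dimensions as v.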