I have a non-perfect, non-complete tree which I would like to process bottom-up. That is, level t outputs some values which are used in the computations for level t+1, but level t+1 also includes some fixed values. The computations involve statically shaped arrays, but the arrays have different sizes.
I know how to process each level. I also know how to pre-compute all the indexes I will need for the whole computation. However, I don't understand how to organise writing values from the output of one level into the next level in JAX. These operations are local, so I don't want to re-create the whole array each time. If the outputs could be stored in a single array instead of a list, I would use in-place overwriting, but since the arrays have different shapes, I cannot.
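For the case where everything does fit in one array, the JAX counterpart of in-place overwriting is the functional `.at[...].set(...)` update. Below is a minimal sketch, assuming scalar values as in the summation example (so one flat array suffices); `run_levels`, `LEVEL_IDX` and `LEVEL_P_IDX` are hypothetical names. Outside `jit` each update returns a new array; inside `jit`, XLA can usually perform updates on intermediate values in place.

```python
import jax
import jax.numpy as jnp

# Hypothetical static index structure, mirroring the pure-Python example.
LEVEL_IDX = [[6, 7, 8, 9], [2, 3, 4, 5]]
LEVEL_P_IDX = [[3, 5], [0, 1]]

@jax.jit
def run_levels(seq):
    # The indices are Python ints, so this loop is unrolled at trace time.
    for l, p in zip(LEVEL_IDX, LEVEL_P_IDX):
        left = seq[jnp.array(l[::2])]    # left operand of each pair
        right = seq[jnp.array(l[1::2])]  # right operand of each pair
        out = left + right
        # Functional version of `seq[p] = out`; returns an updated array.
        seq = seq.at[jnp.array(p)].set(out)
    return seq

seq = jnp.array([0, 0, 4, 0, 5, 0, 0, 1, 2, 3])  # 0 marks slots to be filled
result = run_levels(seq)
```

This only works because all values share one shape and dtype; with differently shaped outputs a flat array is not an option.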
Below I describe in more detail what is going on. I use plain integer summation only to emphasise the structure of the computation; what I am actually interested in is general non-commutative operations on tensors of different shapes, so some simplifications will not work.
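As a sanity check on the "different shapes, non-commutative" requirement: a Python list of arrays with different static shapes is a valid pytree argument to `jax.jit`. A minimal sketch, assuming matrix multiplication as the stand-in non-commutative operation and a hypothetical helper `combine_pairs`:

```python
import jax
import jax.numpy as jnp

@jax.jit
def combine_pairs(arrays):
    # `arrays` is a list of arrays with different (but static) shapes.
    # jnp.matmul is non-commutative, so the left/right order in each pair matters.
    return [jnp.matmul(arrays[i], arrays[i + 1]) for i in range(0, len(arrays), 2)]

leaves = [jnp.ones((2, 3)), jnp.ones((3, 4)), jnp.ones((4, 5)), jnp.ones((5, 2))]
level1 = combine_pairs(leaves)   # shapes (2, 4) and (4, 2)
root = combine_pairs(level1)[0]  # shape (2, 2)
```

Each distinct list structure (number of entries and their shapes) triggers a separate trace, which is fine when the tree structure is fixed in advance.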
Plain Python version:
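```python
def bin_(a, b):
    # Stand-in for a general (possibly non-commutative) binary operation.
    return a + b

def process_level(do, pairs):
    return [do(*pair) for pair in pairs]

def write_level_(what_write, where_write, seq):
    # Writes the outputs of one level into their slots in `seq` (in place).
    for k, v in zip(where_write, what_write):
        seq[k] = v

a, b, c, d, e, f = tuple(range(6))
# None marks slots that are filled by the outputs of a lower level.
seq = [None, None, e, None, f, None, a, b, c, d]
level_idx = [[6, 7, 8, 9], [2, 3, 4, 5], [0, 1]]
level_p_idx = [[3, 5], [0, 1]]

# zip stops after len(level_p_idx) levels, so the final pair [0, 1] is not combined here.
for l, p in zip(level_idx, level_p_idx):
    # Pair up the *values* stored at consecutive indices of this level.
    pairs = [(seq[i], seq[j]) for i, j in zip(l[::2], l[1::2])]
    l_out = process_level(bin_, pairs)
    write_level_(l_out, p, seq)
```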
So `write_level_` changes things in place, which is not acceptable in JAX. However, I cannot figure out what to do instead. One can note that the problem is similar to `jax.lax.scan`. This is true, but not quite: scan assumes a perfect tree. Any suggestions are appreciated. I guess I just don't know something basic about how to modify things in JAX, so I cannot figure it out by myself.
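One possible approach (a sketch, not the only option): since the tree structure is static, keep the values in a Python list, which JAX treats as a pytree, and rebuild the list functionally at each level. Under `jax.jit` the list manipulation happens only at trace time, so "re-creating" the list costs nothing at run time and XLA sees only the per-level operations. `run_tree`, `write_level`, `LEVEL_IDX` and `LEVEL_P_IDX` are hypothetical names.

```python
import jax
import jax.numpy as jnp

# Static tree structure from the pure-Python example; plain Python ints,
# so the loops below are unrolled once at trace time.
LEVEL_IDX = [[6, 7, 8, 9], [2, 3, 4, 5]]
LEVEL_P_IDX = [[3, 5], [0, 1]]

def bin_(x, y):
    # Stand-in for any binary op; operand order is preserved.
    return x + y

def write_level(seq, where, what):
    # Functional replacement for write_level_: builds a new list instead of
    # mutating `seq`. Only references are copied, never array data.
    new_seq = list(seq)
    for k, v in zip(where, what):
        new_seq[k] = v
    return new_seq

@jax.jit
def run_tree(seq):
    # `seq` is a list (a pytree); entries may have different shapes, and
    # None placeholders are treated by JAX as empty subtrees.
    for l, p in zip(LEVEL_IDX, LEVEL_P_IDX):
        outs = [bin_(seq[i], seq[j]) for i, j in zip(l[::2], l[1::2])]
        seq = write_level(seq, p, outs)
    return seq

a, b, c, d, e, f = (jnp.asarray(x) for x in range(6))
seq = [None, None, e, None, f, None, a, b, c, d]
result = run_tree(seq)
```

What JAX forbids is in-place mutation of the arrays themselves; reassigning entries of a local Python list while tracing is harmless as long as the caller's list is not modified, which is why the sketch copies it first.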