So I have m different vectors (say x), each of shape (1, n), stacked as the rows of an (m, n) matrix we call B, and a matrix A of shape (n, n).
I want to compute x @ A^T @ A @ x^T @ x @ A @ A^T @ x^T for every vector x; the output should be (m, 1).
How can I write an einsum query for this, given B and A?
Here is a sample without einsum:
import torch
m = 30
n = 4
B = torch.randn(m, n)
A = torch.randn(n, n)
result = torch.zeros(m, 1)
for i in range(m):
    x = B[i].unsqueeze(0)  # (1, n) row vector
    result[i] = torch.matmul(x, torch.matmul(A.T, torch.matmul(A, torch.matmul(x.T,
                torch.matmul(x, torch.matmul(A, torch.matmul(A.T, x.T)))))))
I could write a query for x@A@x^T but not for x@A^T@A@x^T@x@A@A^T@x^T. Here is the one for x@A@x^T:
torch.einsum('bi,ij,bj -> b', B, A, B)
Let's write it down using Einstein notation; the matrix element (a,i) of the result is
(x @ A^T @ A @ x^T @ x @ A @ A^T @ x^T)_(a,i) =
x_(a,b) (A^T)_(b,c) A_(c,d) (x^T)_(d,e) x_(e,f) A_(f,g) (A^T)_(g,h) (x^T)_(h,i) =
x_(a,b) A_(c,b) A_(c,d) x_(e,d) x_(e,f) A_(f,g) A_(h,g) x_(i,h) =
x_(1,b) A_(c,b) A_(c,d) x_(1,d) x_(1,f) A_(f,g) A_(h,g) x_(1,h),
which is actually a scalar, since x has a single row (so a = e = i = 1). If you want it for all the x, using matrix B:
B_(a,b) A_(c,b) A_(c,d) B_(a,d) B_(a,f) A_(f,g) A_(h,g) B_(a,h)
so the einsum arguments will be
('ab,cb,cd,ad,af,fg,hg,ah->a', B, A, A, B, B, A, A, B)
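As a quick sanity check (my addition, reusing the shapes from the question's snippet), this call agrees with the explicit loop:
import torch

m, n = 30, 4
B = torch.randn(m, n)
A = torch.randn(n, n)

# one shared row index 'a'; every other index is summed away
result = torch.einsum('ab,cb,cd,ad,af,fg,hg,ah->a', B, A, A, B, B, A, A, B)

# compare against the question's explicit loop
for i in range(m):
    x = B[i].unsqueeze(0)
    assert torch.isclose(result[i], (x @ A.T @ A @ x.T @ x @ A @ A.T @ x.T).squeeze())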
Your problem, with modest dimensions:
In [16]: import numpy as np
    ...: m = 30
    ...: n = 4
    ...: B = np.random.randn(m, n)
    ...: A = np.random.randn(n, n)
In [17]: result = np.zeros((m,1))
    ...: for i in range(m):
    ...:     x = np.atleast_2d(B[i])
    ...:     result[i] = x @ A.T @ A @ x.T @ x @ A @ A.T @ x.T
In [18]: result[:5,0]
Out[18]: array([20.84203476, 86.48468007, 16.52952006, 14.27289909, 67.28749281])
And a timing (1ms for this size doesn’t look too bad):
In [19]: %%timeit
...: result = np.zeros((m,1))
...: for i in range(m):
    ...:     x = np.atleast_2d(B[i])
    ...:     result[i] = x @ A.T @ A @ x.T @ x @ A @ A.T @ x.T
...:
1.33 ms ± 2.78 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Is the (m,1) shape important? Why not (m,)?
np.einsum
The proposed einsum solutions (which just differ in how the transpose is applied):
In [20]: result_2 = np.einsum('ab,bc,cd,da,ae,ef,fg,ga->a', B, A.T, A, B.T, B, A, A.T, B.T)
In [21]: result_2.shape
Out[21]: (30,)
In [22]: result_2[:5]
Out[22]: array([20.84203476, 86.48468007, 16.52952006, 14.27289909, 67.28749281])
In [23]: timeit result_2 = np.einsum('ab,bc,cd,da,ae,ef,fg,ga->a', B, A.T, A, B.T, B, A, A.T, B.T)
5.77 ms ± 13.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
This is slower than the OP’s iteration. Iterating 30 times on a complex expression is no big deal. Often that kind of iteration is faster than some whole-array form.
We could explore the einsum optimize options. For example:
In [24]: timeit result_2 = np.einsum('ab,bc,cd,da,ae,ef,fg,ga->a', B, A.T, A, B.T, B, A, A.T, B.T, optimize=True)
1.03 ms ± 3.34 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
einsum breaks the complex calculation into a sequence of simple matmul/dot calculations.
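If the contraction is run many times, the path can also be computed once with np.einsum_path and passed back in through the optimize argument (a sketch under the same setup):
import numpy as np

m, n = 30, 4
B = np.random.randn(m, n)
A = np.random.randn(n, n)

# compute the contraction path once (the returned 'info' string summarizes it)...
path, info = np.einsum_path('ab,bc,cd,da,ae,ef,fg,ga->a',
                            B, A.T, A, B.T, B, A, A.T, B.T, optimize='optimal')
# ...then reuse it on every subsequent call
result_2 = np.einsum('ab,bc,cd,da,ae,ef,fg,ga->a',
                     B, A.T, A, B.T, B, A, A.T, B.T, optimize=path)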
pure matmul
An alternative is to use matmul with the m dimension as a 'batch'. Let's see if that works.
I explored this recently for a simpler calculation, x@A@x^T: /a/78910239/901925
The common use of torch.unsqueeze and the (m,1) shape suggests related, or overlapping, question authorship.
There I used:
B[:,None,:] @ A @ B[:,:,None]   # (m,1,1)
to make B (and B.T) 3d, with m as the first (batch) dimension. Note the (m,1,1) shape; think of that as m (1,1) results. The extra dimensions are often squeezed out. matmul uses fast BLAS code to handle the 2d matrix product, while iterating in compiled code over the leading 'batch' dimension. That's similar to what the OP does, but without the slower Python-level iteration.
So, defining 3d B and B.T as:
In [25]: B1 = B[:,None,:]; B2 = B[:,:,None]
we can perform the matmul as before:
In [26]: result_3 = B1 @ A.T @ A @ B2 @ B1 @ A @ A.T @ B2
In [27]: result_3.shape
Out[27]: (30, 1, 1)
In [28]: result_3[:5,0,0]
Out[28]: array([20.84203476, 86.48468007, 16.52952006, 14.27289909, 67.28749281])
In [29]: np.allclose(result[:,0],result_3[:,0,0])
Out[29]: True
This is quite a bit faster:
In [31]: %%timeit
...: B1 = B[:,None,:]; B2 = B[:,:,None]
...: result_3 = B1 @ A.T @ A @ B2 @ B1 @ A @ A.T @ B2
...: result_3[:,0]
49.6 µs ± 110 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
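Since the question itself uses torch, the same batched trick should carry over directly; here is a sketch, with the setup re-created as an assumption:
import torch

m, n = 30, 4
B = torch.randn(m, n)
A = torch.randn(n, n)

B1 = B[:, None, :]   # (m, 1, n): each row as a (1, n) matrix
B2 = B[:, :, None]   # (m, n, 1): each row transposed
# the (n, n) operands broadcast across the leading batch dimension
result_3 = (B1 @ A.T @ A @ B2 @ B1 @ A @ A.T @ B2).reshape(m, 1)  # the OP's (m, 1) shape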
einsum_path
einsum_path can be used to see how einsum can speed up complicated calculations like this. I won't go into the details, but this gives some idea of what it can do:
In [40]: print(np.einsum_path('ab,bc,cd,da,ae,ef,fg,ga->a', B, A.T, A, B.T, B, A, A.T, B.T)[1])
Complete contraction: ab,bc,cd,da,ae,ef,fg,ga->a
Naive scaling: 7
Optimized scaling: 3
Naive FLOP count: 9.830e+05
Optimized FLOP count: 2.777e+03
Theoretical speedup: 353.994
Largest intermediate: 1.200e+02 elements
--------------------------------------------------------------------------
scaling current remaining
--------------------------------------------------------------------------
3 cd,bc->bd ab,da,ae,ef,fg,ga,bd->a
3 fg,ef->eg ab,da,ae,ga,bd,eg->a
3 bd,ab->da da,ae,ga,eg,da->a
2 da,da->a ae,ga,eg,a->a
2 a,ae->ea ga,eg,ea->a
3 eg,ga->ea ea,ea->a
2 ea,ea->a a->a
Examining this path suggests that we can improve the long matmul sequence by precalculating terms like AA = A.T @ A, and even B1 @ AA.
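In fact, the path hints at a full factorization: x @ A.T @ A @ x.T is a (1,1) scalar, so the whole expression is just (x @ A.T @ A @ x.T) * (x @ A @ A.T @ x.T). A sketch of that shortcut (my addition):
import numpy as np

m, n = 30, 4
B = np.random.randn(m, n)
A = np.random.randn(n, n)

M1 = A.T @ A                               # inner matrix of the first quadratic form
M2 = A @ A.T                               # inner matrix of the second
s1 = np.einsum('ab,bc,ac->a', B, M1, B)    # row-wise x @ M1 @ x.T
s2 = np.einsum('ab,bc,ac->a', B, M2, B)    # row-wise x @ M2 @ x.T
result_4 = s1 * s2                         # (m,): product of the two scalars per row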
You can use the following call:
result_2 = torch.einsum('ab,bc,cd,da,ae,ef,fg,ga->a', B, A.T, A, B.T, B, A, A.T, B.T)
The key thing is the index a. That index indicates the row of B. From what you describe, you do not want plain matrix multiplication between B and the A factors: for example, what you denote by "B @ A.T @ A @ B.T" is a vector of length m in your problem, whereas if "@" were ordinary matrix multiplication, the result would be a matrix of size (m,m). By using the same index a every time we refer to the rows of B, each row is just 'einsumed' with itself, rather than getting mixed with the other rows.
For example, consider the alternative
result_3 = torch.einsum('ab,bc,cd,de,ef,fg,gh,ha->a', B, A.T, A, B.T, B, A, A.T, B.T)
This alternative will just do the full matrix multiplication, and return the diagonal elements of the resulting (m,m) matrix.
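A quick way to see this (my sketch, with a small m so the (m,m) intermediate is cheap to form):
import torch

m, n = 5, 4
B = torch.randn(m, n)
A = torch.randn(n, n)

# plain chained matmul mixes the rows of B, producing an (m, m) matrix
full = B @ A.T @ A @ B.T @ B @ A @ A.T @ B.T

# the 'chained-index' einsum returns exactly its diagonal
result_3 = torch.einsum('ab,bc,cd,de,ef,fg,gh,ha->a',
                        B, A.T, A, B.T, B, A, A.T, B.T)
print(torch.allclose(result_3, torch.diagonal(full)))   # True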