Multiplying matrices is fun:
import torch
# dimensions and stuff
batch_size = 1
input_size = 8
layer_1_emb_size = 3
layer_2_emb_size = 4
layer_3_emb_size = 5
layer_4_emb_size = 6
# weight matrices
torch.manual_seed(42) # Reproducibility = ♡
l1_weights = torch.rand(input_size, layer_1_emb_size)
l2_weights = torch.rand(layer_1_emb_size, layer_2_emb_size)
l3_weights = torch.rand(layer_2_emb_size, layer_3_emb_size)
l4_weights = torch.rand(layer_3_emb_size, layer_4_emb_size)
# forward pass
some_input = torch.rand(batch_size, input_size)
out = torch.mm(some_input, l1_weights)
out_again = torch.mm(out, l2_weights)
very_out_again = torch.mm(out_again, l3_weights)
last_very_out_again = torch.mm(very_out_again, l4_weights)
In my use-case, I have a bunch of masking tensors (tensors with values of 0 or 1, where 0 means "mask") that I need to use to decide which elements of the input matrices to multiply with the next layer's weight matrix:
# forward pass
some_input = torch.rand(batch_size, input_size)
random_masking_tensor = (torch.rand((batch_size, input_size)) < 0.5).float()
some_input = some_input * random_masking_tensor # element-wise masking
out = torch.mm(some_input, l1_weights)
print(random_masking_tensor)
print(some_input)
print(l1_weights)
print(out)
>>>
tensor([[1., 1., 0., 1., 1., 1., 1., 0.]])
tensor([[0.7860, 0.1115, 0.0000, 0.6524, 0.6057, 0.3725, 0.7980, 0.0000]])
tensor([[0.8823, 0.9150, 0.3829],
[0.9593, 0.3904, 0.6009],
[0.2566, 0.7936, 0.9408],
[0.1332, 0.9346, 0.5936],
[0.8694, 0.5677, 0.7411],
[0.4294, 0.8854, 0.5739],
[0.2666, 0.6274, 0.2696],
[0.4414, 0.2969, 0.8317]])
tensor([[1.7867, 2.5470, 1.6331]])
However, I do not wish to just multiply all values (i.e., the 0s as well). Instead, I want to
- Exclude the corresponding masked columns/rows based on the masked indices.
- Do this intelligently (e.g., no for-loops, please).
- Most importantly, maintain the gradient, so that when I call loss.backward() and then optimizer.step(), the relevant elements are updated (i.e., not the masked rows/columns). And yes, I know that if I multiply them by 0 their gradient will already be 0 and hence they will not be updated, but I need to exclude them entirely.
So, for example, in the above example the masked indices of random_masking_tensor are [2, 7], which means that
- I want to exclude columns [2, 7] from some_input (i.e., the ones that are 0 after the element-wise multiplication, though ideally without having to multiply by the mask tensor at all), and
- I want to exclude rows [2, 7] of the weight matrix from the computation (i.e., [0.2566, 0.7936, 0.9408] and [0.4414, 0.2969, 0.8317]).
So the matrix multiplication should instead of being
print(torch.mm(
torch.tensor([[0.7860, 0.1115, 0.0000, 0.6524, 0.6057, 0.3725, 0.7980, 0.0000]]),
torch.tensor([[0.8823, 0.9150, 0.3829],
[0.9593, 0.3904, 0.6009],
[0.2566, 0.7936, 0.9408],
[0.1332, 0.9346, 0.5936],
[0.8694, 0.5677, 0.7411],
[0.4294, 0.8854, 0.5739],
[0.2666, 0.6274, 0.2696],
[0.4414, 0.2969, 0.8317]])
))
>>> tensor([[1.7866, 2.5468, 1.6330]])
be
print(torch.mm(
torch.tensor([[0.7860, 0.1115, 0.6524, 0.6057, 0.3725, 0.7980]]),
torch.tensor([[0.8823, 0.9150, 0.3829],
[0.9593, 0.3904, 0.6009],
[0.1332, 0.9346, 0.5936],
[0.8694, 0.5677, 0.7411],
[0.4294, 0.8854, 0.5739],
[0.2666, 0.6274, 0.2696]])
))
>>> tensor([[1.7866, 2.5468, 1.6330]])
Note that my matrices are quite large (e.g., emb_size=50000) and the batch size is not 1 (I only used these values for the example), which is why I'm not sure how to do it (with a batch size of 1 I could probably just use torch.index_select or something similar).
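For reference, the batch_size=1 variant with torch.index_select looks like the sketch below. It reproduces the desired smaller multiplication from above; it does not generalize to a real batch, where each row can carry a different mask:

```python
import torch

torch.manual_seed(42)
l1_weights = torch.rand(8, 3)
some_input = torch.rand(1, 8)
mask = torch.tensor([[1., 1., 0., 1., 1., 1., 1., 0.]])

# indices of the unmasked positions: tensor([0, 1, 3, 4, 5, 6])
keep = mask[0].nonzero(as_tuple=True)[0]

# drop the masked columns of the input and the masked rows of the weights
small_input = torch.index_select(some_input, 1, keep)
small_weights = torch.index_select(l1_weights, 0, keep)
out = torch.mm(small_input, small_weights)

# matches the full multiply with zeroed-out entries
assert torch.allclose(out, torch.mm(some_input * mask, l1_weights))
```

index_select is differentiable, so gradients flow back only to the selected rows; the problem is that with batch_size > 1 the kept column set differs per row, so a single index_select on the input no longer applies.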