I’m trying to implement a Zero-Layer Transformer as described in this article / video. I’ve come up with the following implementation:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ZeroLayerTransformer(nn.Module):
    def __init__(self, vocab_size, embedding_size):
        super().__init__()
        self.vocab_size = vocab_size
        # Embedding W_e and unembedding W_U, each stored as a (vocab_size, embedding_size) matrix
        self.W_e = nn.Embedding(vocab_size, embedding_size)
        self.W_U = nn.Embedding(vocab_size, embedding_size)

    def forward(self, tokens, targets=None):
        x = self.W_e(tokens)                # (batch, time, embedding_size)
        logits = x @ self.W_U.weight.T      # (batch, time, vocab_size)
        if targets is None:
            loss = None
        else:
            batch_size, time, _ = logits.shape
            loss = F.cross_entropy(
                logits.view(batch_size * time, self.vocab_size),
                targets.view(batch_size * time),
            )
        return logits, loss
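For reference, here is roughly how I exercise it (the sizes below are arbitrary placeholders, not my actual config):

# Quick smoke test with placeholder sizes
vocab_size, embedding_size = 100, 32
model = ZeroLayerTransformer(vocab_size, embedding_size)

tokens = torch.randint(0, vocab_size, (4, 8))     # (batch=4, time=8)
targets = torch.randint(0, vocab_size, (4, 8))

logits, loss = model(tokens, targets)
print(logits.shape)   # torch.Size([4, 8, 100])
print(loss)           # scalar cross-entropy loss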
My question: is this a correct implementation of a Zero-Layer Transformer? Any insights or suggestions would be greatly appreciated!
I implemented the ZeroLayerTransformer as shown and trained it on a small text corpus. I expected:
- The product W_e @ W_U.T to approximate the bi-gram log-probability table.
- Visualizations of W_e @ W_U.T to show clear word-association patterns.
- Performance similar to a simple n-gram model on next-word prediction.
The model trains (loss decreases), but I’m unsure if it’s truly capturing bi-gram statistics as intended. I’m also uncertain if my implementation accurately represents the concept described in the article.
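This is the kind of check I have in mind for the bi-gram question, as a rough sketch only (model here stands for a trained ZeroLayerTransformer and train_ids for a 1-D LongTensor of my training token ids; neither is shown above):

with torch.no_grad():
    # Implied bi-gram logits: row i scores the next token given current token i.
    table = model.W_e.weight @ model.W_U.weight.T            # (vocab, vocab)
    model_logprobs = F.log_softmax(table, dim=-1)

# Empirical bi-gram log-probabilities from raw counts (add-1 smoothing).
V = model_logprobs.shape[0]
counts = torch.zeros(V, V)
counts.index_put_((train_ids[:-1], train_ids[1:]),
                  torch.ones(len(train_ids) - 1), accumulate=True)
empirical_logprobs = torch.log((counts + 1) / (counts + 1).sum(dim=-1, keepdim=True))

# If the model really captures bi-gram statistics, these should correlate strongly.
print(torch.corrcoef(torch.stack([model_logprobs.flatten(),
                                  empirical_logprobs.flatten()]))[0, 1])

For the visualization point I was planning to simply imshow a small slice of model_logprobs.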