I’m trying to use torch.autograd to train a simple recurrent neural network that predicts the next character in a sequence of characters representing songs in ABC notation.
The model looks like this:
model = keras.Sequential([
    keras.layers.Input(shape=(SEQ_LENGTH,), batch_size=batch_size),
    keras.layers.Embedding(len(vocabulary), 256),
    keras.layers.LSTM(1024, return_sequences=True, stateful=False),
    keras.layers.Dense(len(vocabulary))
])
The training process looks like this:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)
for i in range(1000):
    inputs, targets = random_inputs_and_targets(vectorized_songs, seq_length=SEQ_LENGTH, batch_size=BATCH_SIZE)
    predictions = model(inputs)
    loss = loss_fn(predictions.permute((0, 2, 1)), torch.from_numpy(targets).long())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
I then save the model parameters and load them into a similar model, but with batch size 1:
torch.save(model.state_dict(), os.path.join(cwd, "model.pt"))
print("The model has been saved")
trained_model = build_model(1)
trained_model.load_state_dict(torch.load(os.path.join(cwd, "model.pt")))
trained_model.eval()
Then I use the loaded model to predict a string of characters that I expect to look like a song in ABC notation:
input_eval = [char_to_index[s] for s in start_string]
input_eval = torch.unsqueeze(torch.tensor(input_eval), 0)
text_generated = []
for i in range(generation_length):
    predictions = torch.squeeze(model(input_eval), 0)
    predicted_index = torch.multinomial(softmax(predictions, dim=0), 1, replacement=True)[-1, 0]
    input_eval = torch.unsqueeze(torch.unsqueeze(predicted_index, 0), 0)
    text_generated.append(index_to_char[predicted_index.item()])
return start_string + ''.join(text_generated)
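One thing that can be checked in isolation is the sampling step in the loop above. This is only a sketch in plain Python (no model, no torch), and it assumes the model output has shape (batch, time, vocabulary), which is what the permute in the training loop suggests. Under that assumption, predictions after the squeeze has shape (time, vocabulary), and from the second iteration on the time dimension is 1. The helper names and the logits below are made up for illustration.

```python
import math

def softmax(values):
    """Numerically stable softmax over a flat list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_over_time(logits):
    """Normalize each vocabulary column across time steps
    (the plain-Python analogue of dim=0 on a (time, vocab) array)."""
    columns = list(zip(*logits))
    normalized = [softmax(list(col)) for col in columns]
    return [list(row) for row in zip(*normalized)]

def softmax_over_vocab(logits):
    """Normalize each time step across the vocabulary
    (the analogue of dim=-1 on a (time, vocab) array)."""
    return [softmax(row) for row in logits]

# Hypothetical logits: one time step, four-character vocabulary.
logits = [[2.0, 0.5, -1.0, 0.1]]

weights_time = softmax_over_time(logits)
weights_vocab = softmax_over_vocab(logits)

print("normalized over time: ", weights_time)
print("normalized over vocab:", weights_vocab)
```

With a single time step, normalizing over the time axis turns every weight into 1.0, so the logits no longer influence which character is drawn. If the shapes in my code really are as assumed, that would be worth ruling out first, together with checking that the loop calls trained_model rather than model.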
The full code is here.
During the 1000 training iterations, the loss function value goes down from around 4.42 to 0.78, as expected.
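As a rough sanity check on those two numbers, the cross-entropy can be converted to perplexity. torch.nn.CrossEntropyLoss uses the natural logarithm, so exp(loss) is the effective number of equally likely characters the model is choosing between. The sketch below only uses the two loss values quoted above; the variable names are mine.

```python
import math

initial_loss = 4.42  # loss at the start of training, from the post
final_loss = 0.78    # loss after training, from the post

# Perplexity = exp(cross-entropy in nats).
initial_perplexity = math.exp(initial_loss)
final_perplexity = math.exp(final_loss)

print(f"initial perplexity: {initial_perplexity:.1f}")
print(f"final perplexity:   {final_perplexity:.2f}")
```

If the vocabulary has somewhere around 83 characters, the starting value is what a uniform guess would give, and the final value means the model is choosing between roughly two characters per step on the training batches. That seems hard to reconcile with output that looks uniformly random, which makes me suspect the generation code more than the training.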
But when I then try to use the “trained” model to generate a song, the result looks like a random string:
XwQ5>ab>6q6S(z']!<hxaG4..M= (=ERp/xJmS|qIh_CzbM0D-N 6Yc=Ei[tcodBsEKfW<WZ5Jb("u1rrGLcFIk"PVk.'FEII:(qu7.nFbw^3/RY2LyrW
An example of the full result can be seen here.
How do I even start debugging what is going wrong? Previously I built a simple non-recurrent classifier using torch.autograd. Its outputs were only 90% accurate, but that was still much better than what I get from the RNN. Can it be that the hidden state that the RNN needs to predict the next character is lost somewhere during training or actual prediction?
Any suggestions are welcome, since I’m getting stuck.