My data consist of 6 features coming from sensors. I am training an LSTM network on this data to predict three values.
During training, the training loss decreased consistently with each epoch, but the test loss stopped decreasing after a couple of epochs.
This was the case when there was no overlap between the training and test data, so I tried using a subset of the training data as the test data.
Still, the same behavior: the test loss did not decrease.
Below is the code for the LSTM model and the trainer.
import torch
import torch.nn as nn
from torch.optim import AdamW
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"  # used by the DataLoader generators below

class LSTMModel(nn.Module):
    def __init__(self, in_dim=6, hidden_size=200, num_layers=1, output_size=3):
        super(LSTMModel, self).__init__()
        # four stacked LSTM layers; batch_first=True, so input is (batch, seq_len, features)
        self.lstm_1 = nn.LSTM(in_dim, hidden_size, num_layers, batch_first=True)
        self.lstm_2 = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)
        self.lstm_3 = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)
        self.lstm_4 = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x, _ = self.lstm_1(x)
        x, _ = self.lstm_2(x)
        x, _ = self.lstm_3(x)
        x, _ = self.lstm_4(x)
        output = self.fc(x[:, -1, :])  # regress from the last time step only
        return output
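For reference, each input window has shape (batch, 200, 6), i.e. 200 time steps of 6 sensor features, and the model outputs 3 values per window. A quick shape sanity check (a minimal sketch using the definitions above, not part of the notebook):

# Sanity check of input/output shapes (sketch; batch size 32 chosen just for illustration)
model = LSTMModel(in_dim=6, hidden_size=200, num_layers=1, output_size=3).to(device)
dummy_batch = torch.randn(32, 200, 6, device=device)  # (batch, window, features)
print(model(dummy_batch).shape)                        # torch.Size([32, 3])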
class SimpleModelTrainer:
    def __init__(self, model, train_dataset, test_dataset, batch_size=1024, epochs=100, lr=0.005):
        # window_size=200, do_windowing=True, patience=5, pad_testing_data=False
        self.model = model
        self.optimizer = AdamW(params=self.model.parameters(), lr=lr)
        self.lr = lr
        self.epochs = epochs
        # self.patience = patience
        self.batch_size = batch_size
        # self.do_windowing = do_windowing
        # self.window_size = window_size
        self.loss_fn = nn.L1Loss()  # mean absolute error
        # self.stop_early = False
        self.train_data = train_dataset
        self.test_data = test_dataset

    def train(self):
        self.train_dataloader = torch.utils.data.DataLoader(
            self.train_data, batch_size=self.batch_size, shuffle=True,
            generator=torch.Generator(device=device))
        self.test_dataloader = torch.utils.data.DataLoader(
            self.test_data, batch_size=self.batch_size, shuffle=True,
            generator=torch.Generator(device=device))
        total_samples = 0
        for epoch in tqdm(range(self.epochs), desc="epoch"):
            self.model.train()
            total_loss = 0
            for train_data in tqdm(self.train_dataloader, desc="train"):
                X = train_data[0]
                Y = train_data[1]
                # skip incomplete batches to avoid
                # RuntimeError: shape '[16, 1, 256]' is invalid for input of size 3328
                if X.shape[0] != self.batch_size:
                    continue
                total_samples += self.batch_size
                y_hat = self.model(X)
                loss = self.loss_fn(y_hat, Y)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
            avg_train_loss = total_loss / total_samples
            val_loss = self.test(self.test_dataloader)
            print(f"Epoch {epoch} - Train loss:{avg_train_loss:.10f}, Val loss:{val_loss:.10f}")

    def test(self, dataloader):
        self.model.eval()
        with torch.no_grad():
            total_loss = 0
            total_samples = 0
            for test_data in tqdm(dataloader, desc="test"):
                X = test_data[0]
                Y = test_data[1]
                # skip incomplete batches to avoid
                # RuntimeError: shape '[Y, 200, 6]' is invalid for input of size Z
                if X.shape[0] != self.batch_size:
                    continue
                total_samples += self.batch_size
                y_hat = self.model(X)
                loss = self.loss_fn(y_hat, Y)
                total_loss += loss.item()
            val_loss = total_loss / total_samples
            return val_loss
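The trainer is invoked roughly like this (a sketch; the exact cell is in the notebook):

# Sketch of how I run training (train_dataset/test_dataset are defined further below)
model = LSTMModel()
trainer = SimpleModelTrainer(model, train_dataset, test_dataset, batch_size=1024, epochs=100, lr=0.005)
trainer.train()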
I tried this with a randomly generated dummy dataset. It showed exactly the same behavior as above!
You can check it in this Colab notebook.
As you can see in the notebook, the validation loss has been stuck at 0.00048 since the first epoch, while the training loss decreases consistently with each epoch, from 0.00048 down to 0.000016 at the 28th epoch.
(It's still training while I am writing this question.) The test dataset is a subset of the training dataset:
train_dataset = CustomDataset(windowed_input_data, windowed_target_data)
test_dataset = CustomDataset(windowed_input_data[:20000], windowed_target_data[:20000])
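For completeness, CustomDataset is just a thin wrapper around the pre-windowed tensors; a minimal sketch of what I mean (the exact class is in the notebook):

class CustomDataset(torch.utils.data.Dataset):
    # Thin wrapper over pre-windowed inputs/targets (sketch, assumed to match the notebook)
    def __init__(self, inputs, targets):
        self.inputs = inputs      # (num_windows, 200, 6)
        self.targets = targets    # (num_windows, 3)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]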
Hence, I believe the validation loss should show similar behavior and also reach approximately 0.00001. I guess I have made some stupid mistake in the code (a wrong PyTorch API call?) and my eyes are simply not ready to help me out. Can someone please help? Did I miss something conceptually?