I have been using this repo for inference for a while as part of a research project, but now I need to finetune the model on our own data. My annotations are in a csv file with the columns ["index", "audio_path", "start_time", "end_time", "label"]; a couple of example rows are shown below.
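Just to illustrate the layout (the paths, times, and labels here are placeholders, not our real data):

index,audio_path,start_time,end_time,label
0,./dummy_data/session_01.wav,1.25,2.80,1
1,./dummy_data/session_02.wav,0.00,3.10,0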
My current script to load the data and the model is as follows:
import librosa
import pandas as pd
import numpy as np
import os
import dataset_utils, audio_utils, data_loaders, torch_utils
import torch
from functools import partial
from torch.utils.data import DataLoader
from models import ResNetBigger
df = pd.read_csv("./dummy_annotations/dummy.csv")
dataset = data_loaders.SwitchBoardLaughterDataset(
    df, './dummy_data', partial(audio_utils.featurize_melspec, hop_length=186), 8000, 1)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
# Load the checkpoint
checkpoint = torch.load("/home/brooklyn.sheppard1/laughter-detection/checkpoints/comparisons/resnet_with_augmentation_trained_on_audioset/best.pth.tar")
# Initialize the model
model = ResNetBigger()
print(checkpoint.keys())
# Load the state dict
model.load_state_dict(checkpoint['state_dict'])
# Load optimizer state
optimizer = torch.optim.Adam(model.parameters())
optimizer.load_state_dict(checkpoint['optim_dict'])
# Set the model to training mode
model.train()
# Define loss function
criterion = torch.nn.CrossEntropyLoss()
num_epochs = 1

# Training loop
for epoch in range(num_epochs):
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'path/to/save/finetuned_model.pth')
I am having a hard time figuring out the correct data format that SwitchBoardLaughterDataset expects. My current script gives the following error:
Traceback (most recent call last):
  File "train_from_checkpoint.py", line 88, in <module>
    for inputs, labels in dataloader:
  File "/home/brooklyn.sheppard1/.conda/envs/laughter/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 346, in __next__
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/brooklyn.sheppard1/.conda/envs/laughter/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/brooklyn.sheppard1/.conda/envs/laughter/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/brooklyn.sheppard1/laughter-detection/data_loaders.py", line 168, in __getitem__
    audio_file = self.audios_hash[self.df.audio_path[index]]
TypeError: string indices must be integers
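My best guess is that one of the positional arguments to SwitchBoardLaughterDataset is landing in the wrong slot, so that self.df or self.audios_hash ends up holding a string instead of the object the class expects. A quick type check along these lines (attribute names copied from the traceback above; they may differ in other versions of data_loaders.py) might narrow it down:

print(type(dataset.df))           # expecting a pandas DataFrame here
print(type(dataset.audios_hash))  # presumably a dict keyed by audio_path; a str here would explain the TypeError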
This is my first time setting up a model for finetuning from a checkpoint like this, so any advice/ideas are greatly appreciated!