I have created some keys that correspond to particular values inside the dataset, and I have implemented some functions in order to utilize errors or non-existent values.
class SingleImageDataset(torch.utils.data.Dataset):
'''
Returns single images from the dataset (both pre- and post-images).
'''
def __init__(self, mode, configs):
self.mode = mode
self.configs = configs
#self.only_positives = configs['datasets']['only_positives']
self.augmentation = configs['datasets']['augmentation']
tmp = configs['dataset_type'].split('_')
# format: "sen2_xx_mod_yy"
source_types = [tmp[0], tmp[2]]
sgd = {tmp[0]: tmp[1], tmp[2]: tmp[3]}
self.source_types = source_types
self.sgd = sgd
candidate_paths = [i for i in Path(configs['paths']['dataset']).glob('*') if i.name == configs['dataset_type']]
self.ds_path = candidate_paths[0]
# Read the pickle files containing information on the splits
patches = pickle.load(open(self.ds_path / configs['datasets'][mode], 'rb'))
#print("Contents of 'patches':", patches)
#print("Length of patches dictionary:", len(patches))
#print("Keys in patches dictionary:", patches.keys())
event_id = 0
self.events = {}
# Keep the positive indices in a separate list (useful for under/oversampling)
self.positives_idx = []
# Load the data paths into a dictionary
for k in sorted(list(patches.keys())):
# Load a MODIS and a Sentinel image both for pre and post
try:
if 'LR_image' in patches[k].keys():
self.events[event_id] = {'MODIS': patches[k]['LR_image'], 'SEN2': patches[k]['HR_image'], 'Event_ID': patches[k]['EVENT_ID']}
self.events[event_id]['key'] = k
event_id += 1
if patches[k]['positive_flag']:
self.positives_idx.append(event_id)
event_id += 1
except:
print(k)
exit(1)
#print(self.events)
self.selected_bands = {}
self.means = {}
self.stds = {}
for k, v in sgd.items():
self.selected_bands[k] = configs['datasets']['selected_bands'][k].values()
self.means[k] = [m for i, m in enumerate(configs['datasets'][f'{k}_mean'][v])]
self.stds[k] = [m for i, m in enumerate(configs['datasets'][f'{k}_std'][v])]
Instead I got a KeyError regarding non-existent values, and each time I run it I get different values that indeed don’t exist. From the error it seems to me that it may have to do with the values of event_id
File "/tmp/ipykernel_3545567/1738999726.py", line 268, in __getitem__ batch = self.events[event_id] KeyError: 14903
I created some functions in order to deal with such issues but they don’t seem to be functioning since I keep getting KeyErrors
def load_img(self, sample):
'''
Loads an image.
'''
if len(self.source_types) > 1:
loaded_sample = {}
loaded_sample['MODIS'] = torch.load(sample['MODIS']).to(torch.float32)
loaded_sample['SEN2'] = torch.load(sample['SEN2']).to(torch.float32)
else:
loaded_sample['img'] = torch.load(sample['img']).to(torch.float32)
loaded_sample['key'] = f'{sample["key"]}_{sample["type"]}'
if sample['type'] == 'before':
# For the pre-fire image label, we keep only the 0 and 2 labels
before_lbl = torch.load(sample['label']).to(torch.long)
pos_idx = before_lbl == 1
before_lbl[pos_idx] = 0
loaded_sample['label'] = before_lbl
else:
loaded_sample['label'] = torch.load(sample['label']).to(torch.long)
return loaded_sample
def fillna(self, sample):
'''
Fills NaN values in the sample with the constant specified in the config.
'''
filled_sample = sample.copy()
for sample_name, s in sample.items():
if ('label' in sample_name) or ('key' in sample_name): continue
filled_sample[sample_name] = torch.nan_to_num(s, nan=self.configs['datasets']['nan_value'])
return filled_sample
def augment(self, sample):
'''
Applies the following augmentations:
- Random horizontal flipping (possibility = 0.5)
- Random vertical flipping (possibility = 0.5)
- Random Gaussian blurring (kernel size = 3) [only in train mode]
'''
aug_sample = sample.copy()
# Horizontal flip
if random.random() > 0.5:
for sample_name, sample in aug_sample.items():
if 'key' in sample_name: continue
aug_sample[sample_name] = TF.hflip(sample)
# Vertical flip
if random.random() > 0.5:
for sample_name, sample in aug_sample.items():
if 'key' in sample_name: continue
aug_sample[sample_name] = TF.vflip(sample)
# Gaussian blur
if (self.mode == 'train') and (random.random() > 0.5):
for sample_name, sample in aug_sample.items():
if ('label' in sample_name) or ('key' in sample_name): continue
aug_sample[sample_name] = TF.gaussian_blur(sample, 3)
return aug_sample
def __len__(self):
return len(self.events)
def __getitem__(self, event_id):
batch = self.events[event_id]
# Load images
batch = self.load_img(batch)
# Replace NaN values with constant
batch = self.fillna(batch)
# Normalize images
if (self.configs['datasets']['scale_input_sen2'] is not None) or (self.configs['datasets']['scale_input_mod'] is not None):
batch = self.scale_sample(batch)
# Some channels contain a single value (invalid) so scaling returns all NaN
batch = self.fillna(batch)
# Augment images
if self.augmentation:
batch = self.augment(batch)
# Downsample MODIS images if needed
if ('mod' in self.source_types) and self.configs['datasets']['original_modis_size']:
batch = self.downsample_modis(batch)
return batch
dry_martini is a new contributor to this site. Take care in asking for clarification, commenting, and answering.
Check out our Code of Conduct.