I am struggling to work out why I'm getting this error when trying to apply the lottery ticket hypothesis to my model. The error is clearly raised during the pruning callback: it seems to be trying to register the previous weights as a parameter, but the value it finds is a plain FloatTensor rather than an `nn.Parameter`. I am at a loss as to what to do.
File "/uufs/chpc.utah.edu/common/home/u0977428/.micromamba/lib/python3.9/site-packages/pytorch_lightning/callbacks/pruning.py", line 340, in apply_pruning
self._apply_local_pruning(amount)
File "/uufs/chpc.utah.edu/common/home/u0977428/.micromamba/lib/python3.9/site-packages/pytorch_lightning/callbacks/pruning.py", line 311, in _apply_local_pruning
self.pruning_fn(module, name=name, amount=amount)
File "/uufs/chpc.utah.edu/common/home/u0977428/.micromamba/lib/python3.9/site-packages/torch/nn/utils/prune.py", line 909, in l1_unstructured
L1Unstructured.apply(
File "/uufs/chpc.utah.edu/common/home/u0977428/.micromamba/lib/python3.9/site-packages/torch/nn/utils/prune.py", line 545, in apply
return super().apply(
File "/uufs/chpc.utah.edu/common/home/u0977428/.micromamba/lib/python3.9/site-packages/torch/nn/utils/prune.py", line 163, in apply
module.register_parameter(name + "_orig", orig)
File "/uufs/chpc.utah.edu/common/home/u0977428/.micromamba/lib/python3.9/site-packages/torch/nn/modules/module.py", line 583, in register_parameter
raise TypeError(f"cannot assign '{torch.typename(param)}' object to parameter '{name}' "
TypeError: cannot assign 'torch.cuda.FloatTensor' object to parameter 'weight_orig' (torch.nn.Parameter or None required)
def main(config, ...):
model = UNetLightning(
...
)
trainer_args = {
"max_epochs": config.epochs,
"precision": config.precision,
"accelerator": config.accelerator,
"callbacks": [],
"gradient_clip_val": 1.0,
"accumulate_grad_batches": 3,
}
initial_state = model.state_dict() # Save initial state for pruning
pruning_passes = 3 # Number of pruning passes to make
for i in range(pruning_passes):
print(f"Pruning iteration {i + 1}/{pruning_passes}")
# Initialize the trainer with the pruning callback
pruning_callback = ModelPruning(
"l1_unstructured",
amount=0.2,
verbose=True,
use_global_unstructured=False
)
# Apply pruning before moving the model to the GPU
model.cpu()
pruning_callback.on_fit_start(trainer=None, pl_module=model)
trainer_args["callbacks"].append(pruning_callback)
trainer = Trainer(**trainer_args)
# Train the model
trainer.fit(model, data_module)
# Remove the pruning callback for the next iteration
trainer_args["callbacks"].remove(pruning_callback)
# Reset the model to initial weights
model.load_state_dict(initial_state)
# Final training pass without pruning to fine-tune the model
trainer_args["callbacks"].append(checkpoint_cb) # Save the best model after fine-tuning
trainer = Trainer(**trainer_args)
trainer.fit(model, data_module)
# Save the final model
torch.save(model.state_dict(), 'final_pruned_model.pth')