I’m training a latent diffusion model with an audio encoding of shape (batch=16, channels=256, n_frames=501, n_frequencies=6).
Traceback (most recent call last):
File "/home/szding/.pycharm_helpers/pydev/pydevd.py", line 2236, in <module>
main()
File "/home/szding/.pycharm_helpers/pydev/pydevd.py", line 2218, in main
globals = debugger.run(setup['file'], None, None, is_module)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/.pycharm_helpers/pydev/pydevd.py", line 1528, in run
return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/.pycharm_helpers/pydev/pydevd.py", line 1535, in _exec
pydev_imports.execfile(file, globals, locals) # execute the script
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/home/szding/v2sa/v2sa/trainer/trainer_ldm.py", line 232, in <module>
main(config_yaml, arguments.exp_group_name, arguments.exp_name)
File "/home/szding/v2sa/v2sa/trainer/trainer_ldm.py", line 211, in main
trainer.fit(model, datamodule, ckpt_path=last_ckpt)
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
call._call_and_handle_interrupt(
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 987, in _run
results = self._run_stage()
^^^^^^^^^^^^^^^^^
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 1033, in _run_stage
self.fit_loop.run()
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
self.advance()
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
self.epoch_loop.run(self._data_fetcher)
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 140, in run
self.advance(data_fetcher)
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 250, in advance
batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 190, in run
self._optimizer_step(batch_idx, closure)
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 268, in _optimizer_step
call._call_lightning_module_hook(
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py", line 157, in _call_lightning_module_hook
output = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/core/module.py", line 1303, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/core/optimizer.py", line 152, in step
step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py", line 239, in optimizer_step
return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py", line 122, in optimizer_step
return optimizer.step(closure=closure, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/torch/optim/optimizer.py", line 391, in wrapper
out = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/torch/optim/optimizer.py", line 76, in _use_grad
ret = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/torch/optim/adamw.py", line 165, in step
loss = closure()
^^^^^^^^^
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py", line 108, in _wrap_closure
closure_result = closure()
^^^^^^^^^
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 144, in __call__
self._result = self.closure(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 129, in closure
step_output = self._step_fn()
^^^^^^^^^^^^^^^
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 318, in _training_step
training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook
output = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py", line 391, in training_step
return self.lightning_module.training_step(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/v2sa/v2sa/models/ldm/ldm.py", line 588, in training_step
loss, loss_dict = self.shared_step(batch)
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/v2sa/v2sa/models/ldm/ldm.py", line 1262, in shared_step
loss, loss_dict = self(x, codes, self.filter_useful_cond_dict(c))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/v2sa/v2sa/models/ldm/ldm.py", line 1299, in forward
loss, loss_dict = self.p_losses(x, codes, c, t, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/v2sa/v2sa/models/ldm/ldm.py", line 1322, in p_losses
model_output = self.apply_model(x_noisy, t, cond)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/v2sa/v2sa/models/ldm/ldm.py", line 1312, in apply_model
x_recon = self.model(x_noisy, t, cond_dict=cond)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/v2sa/v2sa/models/ldm/ldm.py", line 1971, in forward
out = self.diffusion_model(
^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/szding/v2sa/v2sa/models/ldm/modules/openai_unetmodel.py", line 880, in forward
h = th.cat([h, concate_tensor], dim=1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 252 but got size 251 for tensor number 1 in the list.
I checked the shapes of `h` and `concate_tensor` at each skip connection before the failing `torch.cat`:
The UNet input shape is torch.Size([16, 256, 501, 6]), and the (h, concate_tensor) pairs were:
torch.Size([16, 640, 63, 1]) torch.Size([16, 640, 63, 1])
torch.Size([16, 640, 63, 1]) torch.Size([16, 640, 63, 1])
torch.Size([16, 384, 63, 1]) torch.Size([16, 640, 63, 1])
torch.Size([16, 384, 126, 2]) torch.Size([16, 640, 126, 2])
torch.Size([16, 384, 126, 2]) torch.Size([16, 384, 126, 2])
torch.Size([16, 256, 126, 2]) torch.Size([16, 384, 126, 2])
torch.Size([16, 256, 251, 3]) torch.Size([16, 384, 252, 4])