I am writing a physics-informed neural network (PINN) that learns the velocity, temperature, and pressure fields from the evolution of a bubble in a pool-boiling scenario. I tried to implement the gradient-norm weighting scheme outlined in https://arxiv.org/abs/2308.08468 to help the model converge, but instead of helping convergence it prevents it.
Here is the implementation; it is a method of a larger class I wrote in a Jupyter notebook.
def global_weight_update(self, w_pde, w_data, w_bc, w_ic, C=0.9):
    bc_batch_size = self.North.shape[0] // self.num_batches
    ic_batch_size = self.X0.shape[0] // self.num_batches
    data_batch_size = self.Xdata.shape[0] // self.num_batches
    bc_indices, ic_indices, data_indices = self.shuffle_data()
    # Accumulate the gradient norm of each loss term over all batches.
    w1, w2, w3, w4 = 0, 0, 0, 0
    for batch in range(self.num_batches):
        self.optimizer_adam.zero_grad()
        bc_batch = self.get_batches(batch, bc_batch_size, bc_indices)
        ic_batch = self.get_batches(batch, ic_batch_size, ic_indices)
        data_batch = self.get_batches(batch, data_batch_size, data_indices)
        bc_loss = self.boundary(self.North[bc_batch], self.South[bc_batch],
                                self.East[bc_batch], self.West[bc_batch])
        ic_loss = self.get_initial(self.X0[ic_batch], self.U0[ic_batch])
        pde_loss, data_loss = self.physics(self.Xdata[data_batch], self.adata[data_batch])
        w1 += self.get_norm_gradLoss(pde_loss)
        w2 += self.get_norm_gradLoss(data_loss)
        w3 += self.get_norm_gradLoss(bc_loss)
        w4 += self.get_norm_gradLoss(ic_loss)
    # Average the gradient norms over the batches.
    w1 /= self.num_batches
    w2 /= self.num_batches
    w3 /= self.num_batches
    w4 /= self.num_batches
    # lambda_hat_i = (sum_j ||grad L_j||) / ||grad L_i||
    total = w1 + w2 + w3 + w4
    w1 = total / w1
    w2 = total / w2
    w3 = total / w3
    w4 = total / w4
    # Exponential moving average of the old weights and the new estimates.
    w_pde_new = C * w_pde + (1 - C) * w1
    w_data_new = C * w_data + (1 - C) * w2
    w_bc_new = C * w_bc + (1 - C) * w3
    w_ic_new = C * w_ic + (1 - C) * w4
    return w_pde_new, w_data_new, w_bc_new, w_ic_new
The final loss function I want to optimize is:
self.loss = w_pde * pde_loss + w_data * data_loss + w_bc * bc_loss + w_ic * ic_loss
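Since `get_norm_gradLoss` isn't shown in the post, here is a sketch of how such a helper is commonly written; the name and signature are taken from the call sites above, but the body is an assumption, not the actual code. If the real helper builds the norms into the autograd graph, or interferes with the `.grad` buffers the optimizer uses, that could plausibly contribute to the behavior described:

```python
import torch

def get_norm_gradLoss(model, loss):
    # Hypothetical helper: L2 norm of d(loss)/d(parameters), computed with
    # torch.autograd.grad so the .grad buffers are left untouched.
    # retain_graph=True lets several loss terms reuse one forward pass;
    # allow_unused=True covers losses that don't touch every parameter.
    grads = torch.autograd.grad(loss, [p for p in model.parameters()],
                                retain_graph=True, allow_unused=True)
    squared = sum(g.pow(2).sum() for g in grads if g is not None)
    return torch.sqrt(squared)

# Tiny usage example on a toy model.
torch.manual_seed(0)
model = torch.nn.Linear(2, 1)
x = torch.randn(8, 2)
loss = model(x).pow(2).mean()
norm = get_norm_gradLoss(model, loss)
```

The returned value is a detached-from-optimization scalar used only to set the weights, which matches how the method above treats `w1`–`w4` as plain numbers.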
All weights are initialized to 1.0 at epoch 0. Does anyone have any suggestions besides changing the initial values of the loss weights? Thanks in advance for any and all advice.