I’m training a model to predict an angle in the [0, 2*pi) interval, and I want a loss function that accounts for the fact that we’re working on a circle. I defined the loss function cos_loss below; the idea of the penalty term is to encourage the model to keep its predictions inside the desired interval.
import numpy as np
import tensorflow as tf

def cos_loss(y_true, y_pred):
    # Circular distance: 2*(1 - cos(diff)) is 0 for a perfect match and
    # maximal (4) at a half-turn offset, regardless of wrap-around.
    loss = 2 * (1 - tf.math.cos(y_true - y_pred))
    # One-sided penalty for predictions above 2*pi.
    penalty = tf.math.maximum(0., y_pred - 2 * np.pi)
    return tf.reduce_mean(loss + penalty, axis=-1)
optimizer = optimizers.Adam(learning_rate=0.004, decay=0.001)
metrics = ["mae", "mse", root_mean_squared_error]
self.model.compile(
    loss=cos_loss,
    optimizer=optimizer,
    metrics=metrics,
)
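To verify the math independently of TensorFlow, here is the same formula re-implemented in NumPy as a sanity check (not the training code); it confirms the loss really is periodic:

```python
import numpy as np

def cos_loss_np(y_true, y_pred):
    # Same formula as cos_loss above, on NumPy arrays.
    loss = 2.0 * (1.0 - np.cos(y_true - y_pred))
    penalty = np.maximum(0.0, y_pred - 2.0 * np.pi)
    return np.mean(loss + penalty, axis=-1)

# A prediction off by a full turn costs nothing from the cosine term,
# only the out-of-range penalty.
print(cos_loss_np(np.array([0.1]), np.array([0.1])))            # ~0.0
print(cos_loss_np(np.array([0.1]), np.array([0.1 + 2*np.pi])))  # ~0.1 (penalty only)
```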
But training fails: both the training and validation loss stay flat (see the log below). If I instead use a standard loss function such as MSE, the model does train. I also tried several other “circular” loss functions, and none of them was able to converge.
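For what it’s worth, one property of the cosine term I noticed (my own observation, not from the log): its analytic gradient vanishes not only at the optimum but also at a half-turn offset, so predictions sitting near pi away from the target receive almost no gradient. A quick check:

```python
import numpy as np

def cos_loss_grad(y_true, y_pred):
    # d/dy_pred of 2*(1 - cos(y_true - y_pred)) is -2*sin(y_true - y_pred).
    # It is zero at every offset that is a multiple of pi: at the minimum
    # (offset 0) but also at the maximum (offset pi).
    return -2.0 * np.sin(y_true - y_pred)

print(cos_loss_grad(0.0, np.pi))      # ~0.0: no gradient at the worst prediction
print(cos_loss_grad(0.0, np.pi / 2))  # 2.0: healthy gradient mid-way
```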
Epoch 1/100
2024-08-02 14:12:12,975 1814 [callbacks.py:61] : INFO: Epoch: 0 - loss: 1.99 val_loss: 2.39
17/17 [==============================] - 12s 398ms/step - loss: 1.9908 - mae: 2.9680 - mse: 12.1402 - root_mean_squared_error: 2.9680 - val_loss: 2.3948 - val_mae: 3.3679 - val_mse: 13.8058 - val_root_mean_squared_error: 3.3679
Epoch 2/100
2024-08-02 14:12:18,107 1814 [callbacks.py:61] : INFO: Epoch: 1 - loss: 1.98 val_loss: 2.39
17/17 [==============================] - 5s 300ms/step - loss: 1.9792 - mae: 2.9828 - mse: 12.2488 - root_mean_squared_error: 2.9828 - val_loss: 2.3948 - val_mae: 3.3679 - val_mse: 13.8058 - val_root_mean_squared_error: 3.3679
Epoch 3/100
2024-08-02 14:12:23,498 1814 [callbacks.py:61] : INFO: Epoch: 2 - loss: 1.98 val_loss: 2.39
17/17 [==============================] - 5s 325ms/step - loss: 1.9792 - mae: 2.9828 - mse: 12.2488 - root_mean_squared_error: 2.9828 - val_loss: 2.3948 - val_mae: 3.3679 - val_mse: 13.8058 - val_root_mean_squared_error: 3.3679
Epoch 4/100
2024-08-02 14:12:28,983 1814 [callbacks.py:61] : INFO: Epoch: 3 - loss: 1.98 val_loss: 2.39
17/17 [==============================] - 5s 329ms/step - loss: 1.9792 - mae: 2.9828 - mse: 12.2488 - root_mean_squared_error: 2.9828 - val_loss: 2.3948 - val_mae: 3.3679 - val_mse: 13.8058 - val_root_mean_squared_error: 3.3679
Epoch 5/100
2024-08-02 14:12:34,369 1814 [callbacks.py:61] : INFO: Epoch: 4 - loss: 1.98 val_loss: 2.39
17/17 [==============================] - 10s 585ms/step - loss: 1.9792 - mae: 2.9828 - mse: 12.2488 - root_mean_squared_error: 2.9828 - val_loss: 2.3948 - val_mae: 3.3679 - val_mse: 13.8058 - val_root_mean_squared_error: 3.3679
Epoch 6/100
2024-08-02 14:12:44,881 1814 [callbacks.py:61] : INFO: Epoch: 5 - loss: 1.98 val_loss: 2.39
17/17 [==============================] - 6s 372ms/step - loss: 1.9792 - mae: 2.9828 - mse: 12.2488 - root_mean_squared_error: 2.9828 - val_loss: 2.3948 - val_mae: 3.3679 - val_mse: 13.8058 - val_root_mean_squared_error: 3.3679
Epoch 7/100
2024-08-02 14:12:50,794 1814 [callbacks.py:61] : INFO: Epoch: 6 - loss: 1.98 val_loss: 2.39
17/17 [==============================] - 6s 357ms/step - loss: 1.9792 - mae: 2.9828 - mse: 12.2488 - root_mean_squared_error: 2.9828 - val_loss: 2.3948 - val_mae: 3.3679 - val_mse: 13.8058 - val_root_mean_squared_error: 3.3679
Epoch 8/100
2024-08-02 14:12:56,460 1814 [callbacks.py:61] : INFO: Epoch: 7 - loss: 1.98 val_loss: 2.39
17/17 [==============================] - 6s 337ms/step - loss: 1.9792 - mae: 2.9828 - mse: 12.2488 - root_mean_squared_error: 2.9828 - val_loss: 2.3948 - val_mae: 3.3679 - val_mse: 13.8058 - val_root_mean_squared_error: 3.3679
Epoch 9/100
2024-08-02 14:13:01,771 1814 [callbacks.py:61] : INFO: Epoch: 8 - loss: 1.98 val_loss: 2.39
17/17 [==============================] - 10s 577ms/step - loss: 1.9792 - mae: 2.9828 - mse: 12.2488 - root_mean_squared_error: 2.9828 - val_loss: 2.3948 - val_mae: 3.3679 - val_mse: 13.8058 - val_root_mean_squared_error: 3.3679
Epoch 10/100
2024-08-02 14:13:11,118 1814 [callbacks.py:61] : INFO: Epoch: 9 - loss: 1.98 val_loss: 2.39
17/17 [==============================] - 5s 301ms/step - loss: 1.9792 - mae: 2.9828 - mse: 12.2488 - root_mean_squared_error: 2.9828 - val_loss: 2.3948 - val_mae: 3.3679 - val_mse: 13.8058 - val_root_mean_squared_error: 3.3679
Epoch 11/100
2024-08-02 14:13:16,100 1814 [callbacks.py:61] : INFO: Epoch: 10 - loss: 1.98 val_loss: 2.39
17/17 [==============================] - 5s 293ms/step - loss: 1.9792 - mae: 2.9828 - mse: 12.2488 - root_mean_squared_error: 2.9828 - val_loss: 2.3948 - val_mae: 3.3679 - val_mse: 13.8058 - val_root_mean_squared_error: 3.3679
Epoch 12/100
17/17 [==============================] - ETA: 0s - loss: 1.9792 - mae: 2.9828 - mse: 12.2488 - root_mean_squared_error: 2.9828
The model is a convolutional network followed by fully connected layers.
>>> model.summary()
_________________________________________________________________
 Layer (type)                  Output Shape          Param #
=================================================================
 layer_0 (Reshape)             (None, 64, 100, 4)    0
 layer_1 (Conv2D)              (None, 32, 34, 2)     51200
 layer_2 (BatchNormalization)  (None, 32, 34, 2)     8
 layer_3 (Activation)          (None, 32, 34, 2)     0
 layer_4 (Conv2D)              (None, 64, 17, 1)     18432
 layer_5 (BatchNormalization)  (None, 64, 17, 1)     4
 layer_6 (Activation)          (None, 64, 17, 1)     0
 layer_7 (Conv2D)              (None, 96, 9, 1)      55296
 layer_8 (BatchNormalization)  (None, 96, 9, 1)      4
 layer_9 (Activation)          (None, 96, 9, 1)      0
 layer_10 (Conv2D)             (None, 128, 5, 1)     110592
 layer_11 (BatchNormalization) (None, 128, 5, 1)     4
 layer_12 (Activation)         (None, 128, 5, 1)     0
 layer_13 (Flatten)            (None, 640)           0
 layer_14 (Dense)              (None, 1024)          655360
 layer_15 (Dense)              (None, 512)           524288
 layer_16 (Dense)              (None, 256)           131072
 layer_17 (Dense)              (None, 128)           32768
 layer_18 (Dense)              (None, 1)             128