I tried to train a ViT-GAN model (from this repo) on my dataset, where both the input and the output are images. The input image is a PNG map of a path-planning problem: the red channel is the obstacle map, and the green and blue channels each mark a single pixel for the start and goal. The output image contains the planned path on the red channel, and this is what I am trying to teach the model. The model from the repo works nicely on the included examples, which are a much more complex problem than mine, so I naturally assumed this network would solve my problem without much further tuning.
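To make the data format concrete, each training pair looks roughly like this (a simplified NumPy sketch with an assumed 64x64 map size and made-up coordinates, not my actual preprocessing code):

```python
import numpy as np

H = W = 64  # assumed map size, for illustration only

# Input: red = obstacle map, green = start pixel, blue = goal pixel
inp = np.zeros((H, W, 3), dtype=np.uint8)
inp[:, :10, 0] = 255           # some obstacle region on the red channel
inp[5, 5, 1] = 255             # start marked on the green channel
inp[60, 60, 2] = 255           # goal marked on the blue channel

# Target: the planned path drawn on the red channel only
out = np.zeros((H, W, 3), dtype=np.uint8)
rr = np.arange(5, 61)
out[rr, rr, 0] = 255           # e.g. a diagonal path from start to goal
```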
My issue is that even without touching the network, just replacing the training data with mine completely messes up the output. I have tweaked the network in many ways, but the result always stays the same: either the same structured garbage output for every input (a white patch on black), or some noisy RGB patches on a black background. I just can't get it to learn the expected output.
On the left is the given input, on the right is the expected output, and in the middle is the output of the network. I have read about training GANs and I get that they are volatile and that finding the right architecture/parameter set is hard, but I just can't figure out what to try next. What I have already tried:
- Making the network smaller by removing some transformer/convolutional layers
- Lowering the number of filters, since my problem is not that complex
- Adjusting the learning rates of the generator and discriminator, both together and separately
- Changing the activation function of the discriminator's output layer to sigmoid, as suggested in another similar post
- Replacing all ReLU activation functions with LeakyReLU to address a possible dying-ReLU/vanishing-gradient problem
- Adding a Wasserstein loss to the discriminator loss function to avoid mode collapse
- Changing the mean pixel difference in the generator loss function to a sum difference (to penalize an all-black output more)
- Using bias terms to avoid vanishing gradients
- Changing the random initializer of the filters from N(0.0, 0.02) to N(0.0, 1.0) to make the initial filters more varied
- Changing the lambda, batch size, ff_dim, number of heads, patch size, embed dim, and projection dim parameters
- Training for 200 epochs and also for just 20; all runs produce the same output structure
- Training on a smaller portion of the dataset (400 images) and on larger portions (1000 and 3000 images)
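For the LeakyReLU change in the list above, I mean the standard formulation, which keeps a small slope on the negative side so those units still receive gradient (sketched in NumPy; alpha = 0.2 is just a typical slope value, not necessarily what the repo uses):

```python
import numpy as np

def leaky_relu(x, alpha=0.2):
    # Unlike ReLU, negative inputs are scaled by alpha instead of zeroed,
    # so units with negative pre-activations still get a gradient
    return np.where(x > 0, x, alpha * x)

x = np.array([-2.0, 0.5])
y = leaky_relu(x)  # [-0.4, 0.5]
```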
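For the Wasserstein loss item, what I added follows the usual WGAN formulation (sketched here in plain NumPy on raw critic scores; the names and values are illustrative, the real code works on batched tensors):

```python
import numpy as np

def wasserstein_d_loss(real_scores, fake_scores):
    # Critic tries to maximize real - fake, i.e. minimize fake - real
    return np.mean(fake_scores) - np.mean(real_scores)

def wasserstein_g_loss(fake_scores):
    # Generator tries to maximize the critic's score on fakes
    return -np.mean(fake_scores)

real = np.array([0.9, 0.8, 0.7])
fake = np.array([0.1, 0.2, 0.3])
d_loss = wasserstein_d_loss(real, fake)  # 0.2 - 0.8 = -0.6
```

If I understand correctly, a proper WGAN critic also needs a linear output and a Lipschitz constraint (weight clipping or gradient penalty), so this change may conflict with the sigmoid output change above.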
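To illustrate the mean-vs-sum change: switching from mean to sum scales the L1 term by the number of pixels, so the effective lambda changes by the same factor (simplified sketch, assuming a 64x64 single-channel target):

```python
import numpy as np

target = np.ones((64, 64))     # expected path pixels (simplified)
pred   = np.zeros((64, 64))    # an all-black generator output

mean_l1 = np.mean(np.abs(target - pred))   # 1.0
sum_l1  = np.sum(np.abs(target - pred))    # 4096.0

# The sum-based loss is num_pixels times larger, which can drown out
# the adversarial term unless lambda is rescaled accordingly.
```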
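For context on the initializer change: the stddev interacts with the layer's fan-in, so going from 0.02 to 1.0 blows up the pre-activation scale by a large factor (rough NumPy estimate; the fan_in of 576 is an assumed example, not the repo's actual value):

```python
import numpy as np

rng = np.random.default_rng(0)
fan_in = 576                              # assumed: a 3x3 conv with 64 input channels
x = rng.normal(size=fan_in)               # unit-variance input activations

# 1000 output units per initializer, to estimate the pre-activation std
w_small = rng.normal(0.0, 0.02, size=(1000, fan_in))
w_big   = rng.normal(0.0, 1.0,  size=(1000, fan_in))

std_small = np.std(w_small @ x)           # roughly 0.02 * sqrt(fan_in), ~0.5
std_big   = np.std(w_big @ x)             # roughly 1.0  * sqrt(fan_in), ~24
```

With stddev 1.0 the pre-activations come out about 50x larger, which can saturate tanh/sigmoid outputs from the very first step; 0.02 is the common DCGAN-style default, so this change might have hurt rather than helped.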
I am aware that this problem could probably be solved with other (maybe simpler) networks as well, but I would like to make this GAN work and understand why it doesn't work in the first place. I have pushed my current state of changes to this repo. I have no idea where to look next, so if someone has suggestions please feel free to share. If you suggest adjusting some parameters, please give a concrete value. If I forgot anything in the description, let me know.