Understanding Failure Modes of GAN Training
The idea of two competing neural networks is undoubtedly interesting: at each step, one network tries to defeat the other, and in the process both networks keep getting better at their jobs. But such a dynamic training setup is not always easy to make work. Generative Adversarial Networks, or GANs, are notoriously difficult to train in practice, and they commonly suffer from two types of failure during training. It is important to understand these failure modes before training a GAN-based model; otherwise, we may be unable to develop and train a stable one. This article focuses on explaining these two common failure modes in detail, along with Python examples. Once we understand them, it will be easier to develop and train our own stable GAN-based models that generate the data of our choosing.
This article covers the following topics:
- Why is GAN training unstable?
- What are GAN failure modes?
- What is mode collapse in GAN training?
- What is convergence failure in GAN training?
- Conclusion
Let’s get started.
1. Why is GAN training unstable?
Generative Adversarial Networks, or GANs for short, are quite difficult to train in practice. GAN training can be viewed as a two-player zero-sum game between the generator and the discriminator. The game is non-cooperative, which means that an improvement for one model comes at the expense of the other. Such a setup is only meaningful (and stable) when both players are balanced in capacity. If one player is much stronger, it dominates the game and the weaker player gains nothing, so the weaker model may not learn anything useful. This is why the balance between the two competing models is so important.
Finding this kind of balance between the generator and the discriminator can be framed as finding a Nash equilibrium, a game-theory concept describing a state of a non-cooperative game in which no player can lower its cost by changing only its own strategy. In other words, at a Nash equilibrium each player's cost is at a minimum with respect to its own parameters, given the other player's parameters. Both players in a GAN are neural networks: their parameters are continuous and high-dimensional, and their cost functions are non-convex. Finding a Nash equilibrium in such a setting is far from trivial, and there is no known algorithm that reliably finds one for GANs.
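To make the definition concrete, here it is written out in notation commonly used for GANs (the notation is an addition, not the article's own): J^{(D)} and J^{(G)} are the discriminator's and generator's costs, and θ^{(D)}, θ^{(G)} are their parameters. The second part is the standard minimax GAN objective, in which the two costs are exact negatives of each other.

```latex
% A Nash equilibrium is a parameter pair (\theta^{(D)*}, \theta^{(G)*}) such that
% neither player can lower its own cost by changing only its own parameters:
\begin{aligned}
J^{(D)}\big(\theta^{(D)*}, \theta^{(G)*}\big) &\le J^{(D)}\big(\theta^{(D)}, \theta^{(G)*}\big) \quad \forall\, \theta^{(D)},\\
J^{(G)}\big(\theta^{(D)*}, \theta^{(G)*}\big) &\le J^{(G)}\big(\theta^{(D)*}, \theta^{(G)}\big) \quad \forall\, \theta^{(G)}.
\end{aligned}

% Zero-sum (minimax) GAN: J^{(G)} = -J^{(D)} = V(D, G), i.e.
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big].
```

The discriminator maximises V (equivalently, minimises J^{(D)} = -V) while the generator minimises V, so any change that helps one player directly hurts the other.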
The goal of bringing each player to its minimum cost motivates the use of traditional gradient-descent-based learning methods, which can be used to minimise the costs of both players simultaneously. In a GAN, however, a gradient step that reduces one player's cost can increase the other player's cost, so the two updates can undo each other's progress instead of settling at an equilibrium. This is the key reason why GAN training so frequently fails to converge.
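A classic toy game (not taken from the article) shows how simultaneous gradient steps can fail even when an equilibrium exists. In the sketch below, player 1 picks x to minimise V(x, y) = x * y and player 2 picks y to minimise -x * y; the only Nash equilibrium is (0, 0). Because each player follows its own gradient while the other is also moving, the pair circles the equilibrium and slowly spirals away from it instead of converging. The learning rate and starting point are arbitrary choices.

```python
import math


def simultaneous_gd(x, y, lr=0.1, steps=100):
    """Simultaneous gradient descent on the toy zero-sum game V(x, y) = x * y.

    Player 1 controls x and minimises V; player 2 controls y and minimises -V.
    The game's only Nash equilibrium is (0, 0).
    """
    trajectory = [(x, y)]
    for _ in range(steps):
        grad_x = y    # dV/dx: gradient of player 1's cost
        grad_y = -x   # d(-V)/dy: gradient of player 2's cost
        # Both players step at once, each using the other's current value.
        x, y = x - lr * grad_x, y - lr * grad_y
        trajectory.append((x, y))
    return trajectory


trajectory = simultaneous_gd(1.0, 1.0)
distances = [math.hypot(x, y) for x, y in trajectory]
print(f"distance from equilibrium: start {distances[0]:.3f}, end {distances[-1]:.3f}")
```

Each simultaneous step multiplies the squared distance from (0, 0) by exactly (1 + lr²), so with lr = 0.1 the distance grows by a factor of about 1.64 (1.01^50) over 100 steps. Real GANs are far more complex, but the same tug-of-war between the players' gradients is what can keep them from settling.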
As we have seen, GANs are unstable and difficult to train, so it is important to study their common failure modes before developing them. Let's learn about these failure modes.
2. GAN failure modes
The best way to check whether GAN training is succeeding is to monitor the generated samples carefully and frequently, along with the model losses. After inspecting the generated content over a few training iterations, an experienced practitioner can usually tell whether training is heading in the right direction. If it is not, stop the training, make the required changes to the model, and start training again.
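A minimal monitoring sketch in plain Python, assuming a hypothetical `TrainingMonitor` helper and a `generator` callable that maps a noise vector (a list of floats) to a sample. The key idea is to reuse the same fixed noise vectors for every snapshot, so that differences between snapshots reflect changes in the generator rather than in the input noise. Losses are recorded alongside; in a real training loop the snapshots would typically be saved as image grids for visual inspection.

```python
import random


class TrainingMonitor:
    """Hypothetical helper: records losses every step and, every `every` steps,
    stores generator outputs for a fixed batch of noise vectors."""

    def __init__(self, latent_dim, n_fixed=16, every=500, seed=0):
        rng = random.Random(seed)
        # Reusing the same noise vectors makes successive snapshots comparable.
        self.fixed_z = [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
                        for _ in range(n_fixed)]
        self.every = every
        self.d_losses = []
        self.g_losses = []
        self.snapshots = {}

    def log(self, step, d_loss, g_loss, generator):
        self.d_losses.append(d_loss)
        self.g_losses.append(g_loss)
        if step % self.every == 0:
            self.snapshots[step] = [generator(z) for z in self.fixed_z]


def toy_generator(z):
    # Stand-in for a real generator network: maps a noise vector to a 2-D "sample".
    return [sum(z), z[0]]


monitor = TrainingMonitor(latent_dim=8, n_fixed=4, every=2)
for step in range(5):
    # In a real loop these losses would come from the discriminator/generator updates.
    monitor.log(step, d_loss=0.69, g_loss=0.70, generator=toy_generator)

print(sorted(monitor.snapshots))  # [0, 2, 4]
```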
Before developing or training GANs, it is important to understand the common failure modes of GAN training. There can be many reasons why a GAN-like setup fails to train, and sometimes the cause lies in the training dataset itself, so it is difficult to pinpoint the exact reason every time. However, researchers have observed the following two failure modes to occur very frequently in GANs:
- Mode Collapse
- Convergence Failure
Let’s learn more about these failure modes.
3. Mode Collapse in GANs
Whenever we train a GAN-based generative model, we expect it to generate a wide variety of realistic samples, but this is not always the case. Sometimes, due to an unbalanced training setup, the generator fails to learn the correct representations and therefore fails to generate a wide variety of content. This problem is known as mode collapse: the generator collapses to producing only a limited variety of samples. Mode collapse is quite easy to detect but very difficult to solve. It can be detected by inspecting the generated samples during training: if there is very little diversity among them and some are identical or nearly identical, mode collapse is likely.
Sometimes all of the generated images are exactly the same; this is known as complete collapse. When the generator collapses to a small variety of samples, it is termed partial collapse. Partial collapse is quite common, whereas complete collapse is seen less often. Let's look at the reasons behind mode collapse.
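For low-dimensional toy data, this diversity check can be made quantitative by counting how many modes of the target distribution the generated samples land on. The sketch below assumes a synthetic target of eight Gaussian modes on a circle (a common toy benchmark, not from the article) and uses simulated "generator output" rather than a trained GAN; `modes_covered`, `fake_samples`, and the 0.5 distance threshold are illustrative choices.

```python
import math
import random

# Assumed toy target distribution: 8 modes evenly spaced on a circle of radius 2.
MODES = [(2 * math.cos(2 * math.pi * k / 8), 2 * math.sin(2 * math.pi * k / 8))
         for k in range(8)]


def modes_covered(samples, modes=MODES, max_dist=0.5):
    """Number of target modes that have at least one sample within `max_dist`."""
    hit = set()
    for x, y in samples:
        dists = [math.hypot(x - mx, y - my) for mx, my in modes]
        nearest = min(range(len(modes)), key=dists.__getitem__)
        if dists[nearest] <= max_dist:
            hit.add(nearest)
    return len(hit)


def fake_samples(mode_ids, n=400, std=0.05, seed=0):
    """Simulated generator output concentrated on the given mode indices."""
    rng = random.Random(seed)
    samples = []
    for _ in range(n):
        mx, my = MODES[rng.choice(mode_ids)]
        samples.append((mx + rng.gauss(0.0, std), my + rng.gauss(0.0, std)))
    return samples


print(modes_covered(fake_samples(range(8))))  # 8: diverse output
print(modes_covered(fake_samples([0, 3])))    # 2: partial collapse
print(modes_covered(fake_samples([5])))       # 1: complete collapse
```

Covering all eight modes corresponds to diverse output, covering only a few signals partial collapse, and a single covered mode signals complete collapse. For real images the modes are not known in advance, so this exact check does not carry over directly.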
3.1 Reasons behind Mode Collapse
Mode collapse happens when the generator fails to generate a wide variety of content. If the generator is collapsing and producing similar samples again and again, an ideal discriminator should learn to reject those repeated samples as fake, which would force the generator to diversify its output. This does not always happen, though. Sometimes the discriminator gets stuck in a local minimum and fails to escape it in subsequent training steps. The generator takes advantage of this and learns to generate only a few variations of images that are good enough to fool the discriminator. In other words, both players keep optimising to exploit the opponent's short-term weaknesses.
If the generator collapses to a small variety of samples even though the input noise vector z is always different, its output has effectively become independent of z; in simple terms, the gradient of the generator's output with respect to z approaches zero. Now that we have built an intuition for the most probable causes of mode collapse, let's see how we can solve it.
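One way to check for this symptom is to estimate how much the generator's output moves when z is nudged. The framework-agnostic sketch below uses finite differences on a hypothetical `generator` callable (lists in, lists out); in a deep learning framework you would normally compute the gradient with automatic differentiation instead. The two toy generators are stand-ins invented for illustration, and the step size `eps` is an arbitrary choice.

```python
import math
import random


def z_sensitivity(generator, latent_dim, n_probes=32, eps=1e-3, seed=0):
    """Average finite-difference estimate of ||G(z + eps*u) - G(z)|| / eps
    over random noise vectors z and random unit directions u."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_probes):
        z = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
        u = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
        norm = math.sqrt(sum(ui * ui for ui in u))
        z_shifted = [zi + eps * ui / norm for zi, ui in zip(z, u)]
        total += math.dist(generator(z), generator(z_shifted)) / eps
    return total / n_probes


def healthy_generator(z):
    # Toy stand-in whose output clearly depends on z.
    return [2 * z[0] + z[1], z[1] - z[0]]


def collapsed_generator(z):
    # Toy stand-in that (almost) ignores z: every input maps to nearly the same point.
    return [0.7 + 1e-6 * z[0], -0.3]


print(z_sensitivity(healthy_generator, latent_dim=2))
print(z_sensitivity(collapsed_generator, latent_dim=2))
```

A sensitivity close to zero means that very different noise vectors produce nearly the same output, which is the signature of a generator that has stopped using z.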
3.2 Solving Mode Collapse
Mode collapse is such a common cause of GAN failure that researchers have proposed many different ways to address it. However, we cannot confidently claim that a particular hack will work in a given situation. The following are some common hacks (or best practices) for addressing mode collapse and making a GAN output a wide variety of samples:
- Wasserstein GANs address this problem with a new loss function that makes it possible to train the discriminator (called the critic) to optimality without running into the vanishing gradients problem.
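- Unrolled GANs use a different loss function for the generator that prevents it from over-optimising for a particular discriminator state. When updating the generator parameters, this setup takes into account how the discriminator would respond over several future update steps (the discriminator's training is "unrolled"), rather than only its current state.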
- Increasing the dimensionality of the input noise vector z can introduce more variety into the generated content.
- Making the generator more complex (or deeper), so that it is capable of learning complex representations/features.
- Impairing the discriminator by randomly giving some false labels.
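Two of these hacks fit in a few lines. The sketch below shows (a) the Wasserstein loss terms, where the discriminator is called a critic and outputs an unbounded score rather than a probability, together with the weight clipping the original WGAN paper uses to constrain the critic (its default clipping value is 0.01), and (b) random label flipping to impair the discriminator. Scores and weights are plain Python lists here for illustration; in practice they would be framework tensors, and the flip probability would need tuning.

```python
import random
from statistics import fmean


def wgan_critic_loss(scores_real, scores_fake):
    """WGAN critic loss: the critic minimises mean(f(fake)) - mean(f(real)).
    Scores are unbounded real numbers (no sigmoid), unlike a standard discriminator."""
    return fmean(scores_fake) - fmean(scores_real)


def wgan_generator_loss(scores_fake):
    """The generator tries to raise the critic's scores on its samples."""
    return -fmean(scores_fake)


def clip_weights(weights, c=0.01):
    """Weight clipping, as in the original WGAN, keeps critic weights in [-c, c]."""
    return [max(-c, min(c, w)) for w in weights]


def flip_labels(labels, flip_prob, rng):
    """Label noise: each 0/1 label is flipped with probability `flip_prob`."""
    return [1 - y if rng.random() < flip_prob else y for y in labels]


rng = random.Random(0)
print(wgan_critic_loss([1.0, 3.0], [0.0, 1.0]))  # -1.5
print(clip_weights([0.5, -0.2, 0.005]))          # [0.01, -0.01, 0.005]
print(flip_labels([1] * 10, flip_prob=0.1, rng=rng))
```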
The hacks listed above have been reported to help in many cases of mode collapse. We may not be able to tell in advance which one will fix a particular case, but with these hacks we should be able to resolve most mode collapse issues. Next, let's learn about another common failure mode of GANs: convergence failure.
4. Convergence Failure in GANs
Convergence failure is a scenario in which GAN training fails to converge and produces no meaningful content. GAN training is a two-player zero-sum game in which each player learns at the expense of the other at every iteration, so both players must be chosen carefully to establish an equilibrium. If one of the players dominates, the whole GAN setup fails to converge and produces no meaningful output.
Convergence failure is observed quite frequently during GAN training. It can be identified by the following symptoms:
- The loss of the discriminator network rapidly decreases to a value very close to zero and stays there during the remaining training iterations.
- The samples generated by the generator are of very poor quality, and it is trivial for the discriminator to flag them as fake.
- The loss of the generator either drops to zero or keeps increasing throughout training.
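The loss-based symptoms above can be turned into a rough automatic check. This is a heuristic sketch, not an established detector: `convergence_warnings` and all of its thresholds (window size, what counts as "near zero", how much the loss must rise) are assumptions to tune for a given model and loss function. The second symptom, poor sample quality, cannot be read from the losses and still requires looking at the samples.

```python
from statistics import fmean


def _chunk_means(values, n_chunks):
    """Mean of each of `n_chunks` equal consecutive chunks (any remainder is dropped)."""
    size = len(values) // n_chunks
    return [fmean(values[i * size:(i + 1) * size]) for i in range(n_chunks)]


def convergence_warnings(d_losses, g_losses, window=100, near_zero=1e-3,
                         n_chunks=4, rel_increase=0.05):
    """Flag loss-based convergence-failure symptoms. Thresholds are illustrative."""
    warnings = []
    # First symptom: discriminator loss collapses to ~0 and stays there.
    recent_d = d_losses[-window:]
    if recent_d and max(recent_d) < near_zero:
        warnings.append("discriminator loss stuck near zero")
    # Third symptom (a): generator loss reaches ~0.
    recent_g = g_losses[-window:]
    if recent_g and fmean(recent_g) < near_zero:
        warnings.append("generator loss near zero")
    # Third symptom (b): generator loss rises steadily across the whole run.
    if len(g_losses) >= n_chunks:
        means = _chunk_means(g_losses, n_chunks)
        if all(b - a > rel_increase * abs(a) for a, b in zip(means, means[1:])):
            warnings.append("generator loss keeps increasing")
    return warnings


# Simulated loss histories (not from a real run): the discriminator wins early.
d_hist = [0.7] * 50 + [1e-5] * 150
g_hist = [1 + 0.05 * i for i in range(200)]
print(convergence_warnings(d_hist, g_hist))
```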
If our GAN training fails and shows any of the behaviours described above, the model is most likely suffering from convergence failure. Next, let's look at the main reasons behind it.
4.1 Reasons behind Convergence Failure
Convergence failure happens when the generator and the discriminator are not balanced. There are two possible scenarios behind this imbalance:
- In the first scenario, the discriminator is weak and the generator dominates. The discriminator cannot distinguish real samples from fake ones, so it cannot provide useful feedback to the generator. The generator loss approaches zero despite the poor quality of its samples.
- In the second scenario, the discriminator is too good at its job. The generator's samples are always rejected, and the generator never manages to fool the discriminator during the entire training process. In this case, the generator loss keeps increasing during training and the generator produces garbage output.
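A common explanation for the second scenario concerns the generator's loss. With the original minimax objective, the generator minimises log(1 - D(G(z))). When the discriminator confidently rejects a fake sample, D(G(z)) is close to zero, and the gradient of this loss with respect to the discriminator's logit (which equals -D(G(z))) nearly vanishes, so the generator receives almost no learning signal. The sketch below compares it with the widely used non-saturating alternative, -log D(G(z)), whose gradient stays close to -1 in the same situation. Only the gradient with respect to the logit is shown; the full generator gradient multiplies this by the logit's gradient with respect to the generator parameters.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def saturating_grad(logit):
    """d/d(logit) of the minimax generator loss log(1 - D), with D = sigmoid(logit).
    This simplifies to -D."""
    return -sigmoid(logit)


def non_saturating_grad(logit):
    """d/d(logit) of the non-saturating generator loss -log(D). Simplifies to -(1 - D)."""
    return -(1.0 - sigmoid(logit))


# A very negative logit means the discriminator confidently rejects the fake sample.
for logit in (0.0, -5.0, -10.0):
    print(f"D(G(z)) = {sigmoid(logit):.2e}  "
          f"saturating grad = {saturating_grad(logit):.2e}  "
          f"non-saturating grad = {non_saturating_grad(logit):.2e}")
```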
We now understand the key reasons behind the convergence failure of a GAN training setup. Let’s learn about some ways of solving it.
4.2 Solving Convergence Failure
There are quite a few tricks for solving the convergence failure problem in GANs. The goal is always the same: balance the generator and discriminator so that neither of them dominates during training. Here are some ways to balance the two networks.
Solving the first scenario, where the generator dominates:
- Making the discriminator deeper, so that it is better at rejecting fake samples and forces the generator to work harder.
- Impairing the generator by adding dropout layers, so that it learns more slowly and stays in sync with the discriminator.
- Making the generator weaker by removing some layers.
Solving the second scenario, where the discriminator dominates:
- Impairing the discriminator by randomly giving it some false labels for real images.
- Impairing the discriminator by adding dropout (or other regularization techniques).
- Making the generator deeper, so that it is better at learning representations/features and producing realistic samples.
- Making the discriminator weaker by removing a few layers, so that it learns more slowly and stays in sync with the generator.
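The remedies above amount to turning a handful of knobs on the two networks. As a purely illustrative sketch, the hypothetical `GANConfig` and `rebalance` below encode those knobs and apply one round of the listed tricks for whichever network dominates. The step sizes (one layer, 0.1 dropout, 0.05 flip probability) and the caps are arbitrary assumptions, and after each change the model would be retrained and its samples inspected again.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class GANConfig:
    """Hypothetical knobs corresponding to the balancing tricks listed above."""
    g_layers: int = 4
    d_layers: int = 4
    g_dropout: float = 0.0
    d_dropout: float = 0.0
    real_label_flip_prob: float = 0.0


def rebalance(cfg: GANConfig, dominant: str) -> GANConfig:
    """Return a new config that weakens the dominant network and/or strengthens the other."""
    if dominant == "generator":
        return replace(cfg,
                       d_layers=cfg.d_layers + 1,                # deeper discriminator
                       g_dropout=min(cfg.g_dropout + 0.1, 0.5),  # dropout impairs generator
                       g_layers=max(cfg.g_layers - 1, 1))        # remove a generator layer
    if dominant == "discriminator":
        return replace(cfg,
                       real_label_flip_prob=min(cfg.real_label_flip_prob + 0.05, 0.2),
                       d_dropout=min(cfg.d_dropout + 0.1, 0.5),  # dropout impairs discriminator
                       g_layers=cfg.g_layers + 1,                # deeper generator
                       d_layers=max(cfg.d_layers - 1, 1))        # remove a discriminator layer
    raise ValueError(f"unknown dominant network: {dominant!r}")


print(rebalance(GANConfig(), "discriminator"))
```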
We have now covered several tricks for balancing the two models in each convergence failure scenario. In every case, the aim is to balance the learning process of both models so that they learn in sync and keep improving each other.
Conclusion
This article highlights the fact that GANs are quite difficult to train. GANs commonly suffer from two training failure modes: mode collapse and convergence failure. It is important to learn to identify and overcome these failure modes before starting to work on GANs. After reading this article, readers should be able to answer the following questions about GAN failure modes:
- Why are GANs hard to train, and what common problems do they suffer from during training?
- How can mode collapse be identified during GAN training, and what tricks can help solve it?
- What are the common reasons behind convergence failure in a GAN? How can we identify it during training, and what tips and tricks can help avoid or solve it?