Image Synthesis using Pixel CNN based Autoregressive Generative Models
Recent advances in deep learning have led to the development of complex generative models capable of producing high-quality content in the form of text, audio, images, videos and so on. Generative models that use deep learning architectures to learn data distributions are known as deep generative models. Thanks to the flexibility and scalability of neural networks, deep generative modeling has become one of the most exciting and fastest-evolving areas of ML and AI. Deep generative modeling techniques power modern AI systems that constantly generate and process vast amounts of data.
Generative learning with deep neural networks is producing impressive results today, and several popular deep generative learning frameworks are an active area of research. In this article we will learn about one such framework, autoregressive generative models, and we will also develop a Pixel CNN based autoregressive generative model from scratch.
Let’s get started.
Autoregressive Generative Models
The term 'autoregressive' comes from the field of time-series forecasting, where a model considers past observations, in temporal order, to make predictions about the future. Autoregressive generative models are quite similar in nature: they condition each subsequent prediction on all of their past predictions.
For example, an autoregressive model may generate an image by predicting one pixel at a time. Each new pixel value is estimated from the previously predicted or observed pixel values. In statistical terms, an autoregressive model is trained to learn the conditional distribution of each individual pixel given the pixels that are already known (or have already been generated by the same model in earlier steps).
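To make this concrete, here is a minimal, hypothetical sketch of autoregressive sampling over a flattened binary image. The function model_prob is a stand-in used purely for illustration; a real model such as Pixel CNN would compute this conditional probability with a neural network:

import numpy as np

def model_prob(pixels_so_far):
    # Hypothetical stand-in for a learned conditional distribution:
    # returns P(next pixel = 1 | pixels generated so far).
    return 0.5 if len(pixels_so_far) == 0 else float(np.mean(pixels_so_far))

generated = []
for _ in range(28 * 28):  # one pixel at a time, in raster-scan order
    p = model_prob(generated)
    generated.append(int(np.random.rand() < p))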
Pixel CNN, Pixel RNN, Character CNN, Character RNN and WaveNet are some popular examples of deep autoregressive generative models. Some of these models have also been applied successfully to anomaly detection and to the detection of adversarial attacks.
Pros and Cons of Autoregressive Generative Models
Following is a list of some pros of autoregressive generative models:
- Easy to understand, with tractable and easy-to-compute likelihoods
- Training process is supervised and very straightforward
To get a more intuitive understanding of autoregressive generative models, check out my article: What are Autoregressive Generative Models?
Let’s now look at some of the cons of autoregressive approach:
- Requires a fixed ordering of the random variables
- Generation is sequential (one variable at a time), hence very slow; for example, sampling a single 28 x 28 image requires 784 sequential forward passes
- High likelihood does not guarantee better-looking samples in practice
- Does not learn unsupervised features (or representations)
Now that we know the basic idea behind autoregressive generative models, let’s develop a Pixel CNN model from scratch for image synthesis.
Image Synthesis using Pixel CNN
In this example, we will implement and train an autoregressive generative model called 'Pixel CNN' on the MNIST handwritten digits dataset and verify its generative capabilities.
The code related to this experiment can be found here on GitHub.
This example has the following steps:
- Importing Libraries
- Download and show data
- Data Preparation
- Define Masked CNN layers
- Define Pixel-CNN
- Training Pixel-CNN Model
- Results
Let’s get started.
Step 1: Importing Libraries
As a first step, we will import some useful Python libraries in a Jupyter Notebook cell. See the following snippet.
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline
import tensorflow
print(tensorflow.__version__)
We will be using the TensorFlow library, version 2.4.1, for our experiments.
Out[1]: 2.4.1
Now, let’s download the dataset.
Step 2: Download and show data
In this step, we will download the MNIST handwritten digits dataset. This is a very common dataset and is available directly through the Keras datasets module. The following Python code downloads the data and plots some handwritten digit images to verify the samples.
from tensorflow.keras.datasets import mnist
(trainX, trainy), (testX, testy) = mnist.load_data()
print('Training data shapes: X=%s, y=%s' % (trainX.shape, trainy.shape))
print('Testing data shapes: X=%s, y=%s' % (testX.shape, testy.shape))
for k in range(9):
    plt.figure(figsize=(9, 6))
    for j in range(9):
        # Pick a random training image (out of all 60k samples).
        i = np.random.randint(0, trainX.shape[0])
        plt.subplot(990 + 1 + j)
        plt.imshow(trainX[i], cmap='gray_r')
        plt.axis('off')
        # plt.title(trainy[i])
    plt.show()
Here is the output:
Figure 1 gives a basic idea of the MNIST handwritten digits dataset. It contains 60k training images and 10k test images, each of size 28 x 28 pixels.
Now, let’s prepare the dataset for the model.
Step 3: Data Preparation
In this step, we will prepare our dataset for model training. First, we will binarize the images (as expected by our basic Pixel CNN). Second, we will add a channel dimension to all the images, as expected by the convolutional layers in our Pixel CNN.
# Binarize: pixels below ~33% of the maximum intensity become 0, the rest become 1.
trainX = np.where(trainX < (0.33 * 256), 0, 1)
train_data = trainX.astype(np.float32)
testX = np.where(testX < (0.33 * 256), 0, 1)
test_data = testX.astype(np.float32)

# Add a trailing channel dimension: (N, 28, 28) -> (N, 28, 28, 1).
train_data = np.reshape(train_data, (60000, 28, 28, 1))
test_data = np.reshape(test_data, (10000, 28, 28, 1))
print(train_data.shape, test_data.shape)
Following is the final shape of our dataset after pre-processing:
Out[3]: (60000, 28, 28, 1) (10000, 28, 28, 1)
Our data is now ready. We can now define the Pixel CNN architecture.
Step 4: Define Masked CNN layers
In this step, we will define the building blocks of our Pixel CNN framework. First, we will define a PixelConvLayer class. A Pixel CNN layer is simply a 2D convolutional layer with masking added, to accommodate the autoregressive training of the network. Since autoregressive models generate output one pixel at a time, masking is a clever way to ensure that each output position only 'sees' the pixels above it and to its left. A type 'A' mask also hides the center pixel itself (used in the first layer, so a pixel never conditions on itself), while a type 'B' mask allows the center pixel (used in the deeper layers).
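Before implementing the layer, it may help to see what these masks actually look like. The following standalone snippet (an illustration only, not part of the model code) builds a 5 x 5 single-channel mask of each type:

import numpy as np

def make_mask(kernel_size, mask_type):
    # Allow all rows above the center pixel, plus the pixels to its left in the same row.
    mask = np.zeros((kernel_size, kernel_size))
    mask[: kernel_size // 2, :] = 1.0
    mask[kernel_size // 2, : kernel_size // 2] = 1.0
    if mask_type == "B":
        mask[kernel_size // 2, kernel_size // 2] = 1.0  # type "B" also keeps the center
    return mask

print(make_mask(5, "A"))
print(make_mask(5, "B"))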
See the python code below for Pixel CNN layer implementation:
class PixelConvLayer(tensorflow.keras.layers.Layer):
    def __init__(self, mask_type, **kwargs):
        super(PixelConvLayer, self).__init__()
        self.mask_type = mask_type
        self.conv = tensorflow.keras.layers.Conv2D(**kwargs)

    def build(self, input_shape):
        # Build the conv2d layer to initialize kernel variables
        self.conv.build(input_shape)
        # Use the initialized kernel to create the mask
        kernel_shape = self.conv.kernel.get_shape()
        self.mask = np.zeros(shape=kernel_shape)
        self.mask[: kernel_shape[0] // 2, ...] = 1.0
        self.mask[kernel_shape[0] // 2, : kernel_shape[1] // 2, ...] = 1.0
        if self.mask_type == "B":
            self.mask[kernel_shape[0] // 2, kernel_shape[1] // 2, ...] = 1.0

    def call(self, inputs):
        # Zero out the "future" weights before every forward pass
        self.conv.kernel.assign(self.conv.kernel * self.mask)
        return self.conv(inputs)
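As a quick sanity check (a hypothetical snippet, not part of the original notebook), we can apply the masked layer to a dummy batch and confirm that it behaves like a regular convolution shape-wise:

layer = PixelConvLayer(mask_type="A", filters=8, kernel_size=7, padding="same")
dummy = tensorflow.zeros((1, 28, 28, 1))
print(layer(dummy).shape)  # expected: (1, 28, 28, 8)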
Second, we will define a residual block layer. It is a standard residual block built around the masked Pixel CNN layer. See the following code:
class ResidualBlock(tensorflow.keras.layers.Layer):
    def __init__(self, filters, **kwargs):
        super(ResidualBlock, self).__init__(**kwargs)
        self.conv1 = tensorflow.keras.layers.Conv2D(
            filters=filters, kernel_size=1, activation="relu"
        )
        self.pixel_conv = PixelConvLayer(
            mask_type="B",
            filters=filters // 2,
            kernel_size=3,
            activation="relu",
            padding="same",
        )
        self.conv2 = tensorflow.keras.layers.Conv2D(
            filters=filters, kernel_size=1, activation="relu"
        )

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.pixel_conv(x)
        x = self.conv2(x)
        return tensorflow.keras.layers.add([inputs, x])
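Note that the residual addition requires the block's output to have the same number of channels as its input, which is why both 1 x 1 convolutions use the full filters count. A quick (assumed, illustrative) shape check:

block = ResidualBlock(filters=128)
print(block(tensorflow.zeros((1, 28, 28, 128))).shape)  # expected: (1, 28, 28, 128)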
We now have the building blocks ready for our Pixel CNN. We can now go ahead and define the final Pixel CNN architecture.
Step 5: Define Pixel CNN
In this step, we will define the final Pixel CNN model architecture. It is very similar to the one described in the original Pixel CNN paper, combining masked Pixel CNN layers with residual blocks, as shown in the following Python code:
inputs = tensorflow.keras.Input(shape=(28, 28, 1))
x = PixelConvLayer(
    mask_type="A", filters=128, kernel_size=7, activation="relu", padding="same"
)(inputs)

for _ in range(5):
    x = ResidualBlock(filters=128)(x)

for _ in range(2):
    x = PixelConvLayer(
        mask_type="B",
        filters=128,
        kernel_size=1,
        strides=1,
        activation="relu",
        padding="valid",
    )(x)

out = tensorflow.keras.layers.Conv2D(
    filters=1, kernel_size=1, strides=1, activation="sigmoid", padding="valid"
)(x)

pixel_cnn = tensorflow.keras.Model(inputs, out)
pixel_cnn.summary()
Following is the final summary of our Pixel CNN model. It has about 533k trainable parameters.
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 28, 28, 1)]       0
_________________________________________________________________
pixel_conv_layer (PixelConvL (None, 28, 28, 128)       6400
_________________________________________________________________
residual_block (ResidualBloc (None, 28, 28, 128)       98624
_________________________________________________________________
residual_block_1 (ResidualBl (None, 28, 28, 128)       98624
_________________________________________________________________
residual_block_2 (ResidualBl (None, 28, 28, 128)       98624
_________________________________________________________________
residual_block_3 (ResidualBl (None, 28, 28, 128)       98624
_________________________________________________________________
residual_block_4 (ResidualBl (None, 28, 28, 128)       98624
_________________________________________________________________
pixel_conv_layer_6 (PixelCon (None, 28, 28, 128)       16512
_________________________________________________________________
pixel_conv_layer_7 (PixelCon (None, 28, 28, 128)       16512
_________________________________________________________________
conv2d_18 (Conv2D)           (None, 28, 28, 1)         129
=================================================================
Total params: 532,673
Trainable params: 532,673
Non-trainable params: 0
_________________________________________________________________
Our Pixel CNN architecture is now ready. We will use the Adam optimizer with a learning rate of 0.0005 to update the weights of our model during training, and 'binary_crossentropy' as the loss function.
Remember, Pixel CNN generates an image one pixel at a time, and our setup uses binary images with values 0 and 1, so each model output is the probability of a pixel being 1. This makes 'binary_crossentropy' the natural loss here.
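As a quick illustration (not from the original notebook), binary cross-entropy for a single pixel is just the negative log-likelihood of a Bernoulli distribution, which we can compute by hand:

x, p = 1.0, 0.9  # true binary pixel value and the predicted probability of it being 1
bce = -(x * np.log(p) + (1 - x) * np.log(1 - p))
print(bce)  # ~0.105 nats; the training loss averages this quantity over all pixels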
Compiling Pixel CNN:
adam = tensorflow.keras.optimizers.Adam(learning_rate=0.0005)
pixel_cnn.compile(optimizer=adam, loss="binary_crossentropy")
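Optionally (an assumption on my part, not part of the original setup), you can save the best weights by validation loss during training with a standard Keras callback; the filename below is hypothetical:

checkpoint = tensorflow.keras.callbacks.ModelCheckpoint(
    "pixel_cnn_best.h5",  # hypothetical output path
    monitor="val_loss",
    save_best_only=True,
)
# Pass callbacks=[checkpoint] to pixel_cnn.fit(...) in the next step to enable it.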
Our Pixel CNN is now ready for training.
Step 6: Training Pixel-CNN Model
In this step, we will train the Pixel CNN model on the MNIST handwritten digits dataset for 50 epochs with a batch size of 128.
pixel_cnn.fit(
    x=train_data,
    y=train_data,
    batch_size=128,
    epochs=50,
    validation_data=(test_data, test_data),
    verbose=1,
)
Following are the training logs. The validation loss decreases consistently, which indicates that the model is learning:
Epoch 1/50
469/469 [==============================] - 57s 116ms/step - loss: 0.1748 - val_loss: 0.0923
Epoch 2/50
469/469 [==============================] - 53s 113ms/step - loss: 0.0918 - val_loss: 0.0892
Epoch 3/50
469/469 [==============================] - 53s 113ms/step - loss: 0.0894 - val_loss: 0.0882
Epoch 4/50
469/469 [==============================] - 53s 113ms/step - loss: 0.0881 - val_loss: 0.0873
. . . . .
. . . . .
. . . . .
. . . . .
Epoch 45/50
469/469 [==============================] - 53s 113ms/step - loss: 0.0815 - val_loss: 0.0829
Epoch 46/50
469/469 [==============================] - 53s 113ms/step - loss: 0.0815 - val_loss: 0.0827
Epoch 47/50
469/469 [==============================] - 53s 113ms/step - loss: 0.0813 - val_loss: 0.0829
Epoch 48/50
469/469 [==============================] - 53s 113ms/step - loss: 0.0812 - val_loss: 0.0829
Epoch 49/50
469/469 [==============================] - 53s 113ms/step - loss: 0.0813 - val_loss: 0.0830
Epoch 50/50
469/469 [==============================] - 53s 113ms/step - loss: 0.0812 - val_loss: 0.0828
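Because autoregressive models define an explicit likelihood (one of the pros listed earlier), we can also report the test-set negative log-likelihood directly. A minimal sketch, assuming the binarized test_data from Step 3:

# evaluate() returns the average per-pixel Bernoulli NLL in nats;
# dividing by ln(2) converts it to the commonly reported bits per pixel.
nll_nats = pixel_cnn.evaluate(test_data, test_data, verbose=0)
print("bits per pixel: %.4f" % (nll_nats / np.log(2)))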
Now that training is complete, we can check the results.
Step 7: Results (Image Synthesis using Pixel CNN)
The following Python code uses our trained Pixel CNN model to generate handwritten digit images, one pixel at a time, and plots some of the generated results.
from IPython.display import Image, display
from tqdm import tqdm_notebook

# Create an empty array of pixels for a batch of 81 images.
batch = 81
pixels = np.zeros(shape=(batch,) + (pixel_cnn.input_shape)[1:])
batch, rows, cols, channels = pixels.shape

# Iterate over the pixels because generation has to be done sequentially, pixel by pixel.
for row in tqdm_notebook(range(rows)):
    for col in range(cols):
        for channel in range(channels):
            # Feed the whole array and retrieve the probability of the next pixel being 1.
            probs = pixel_cnn.predict(pixels)[:, row, col, channel]
            # Sample a binary value from a Bernoulli distribution with that probability
            # and write it into the image frame.
            pixels[:, row, col, channel] = tensorflow.math.ceil(
                probs - tensorflow.random.uniform(probs.shape)
            )

# Plot the generated images in 9 figures of 9 digits each.
counter = 0
for i in range(9):
    plt.figure(figsize=(9, 6))
    for j in range(9):
        plt.subplot(990 + 1 + j)
        plt.imshow(pixels[counter, :, :, 0], cmap='gray_r')
        counter += 1
        plt.axis('off')
    plt.show()
Figure 2 shows the output of our Pixel CNN model. The generated images look quite close to real handwritten digits. They are not perfect, though, given the limitations of the framework, the model capacity and the dataset.
Yay! We have now successfully implemented, trained and validated the Pixel CNN model.
Conclusion
In this article, we discussed a popular generative learning framework: autoregressive generative models. We also implemented a Pixel CNN based autoregressive generative model for image synthesis from scratch in Python. Our results show that the Pixel CNN framework is capable of generating images one pixel at a time, although the results are not perfect. They can be improved by increasing the capacity of the model, increasing the size of the dataset, or by using a different generative modeling framework such as Generative Adversarial Networks (GANs), Variational AutoEncoders (VAEs) or diffusion-based models.
If you are interested in learning more about generative learning and Generative Adversarial Networks, do check out my book.
I hope this article was helpful. Please let me know your feedback/suggestions by commenting below.
See you in the next article!
Read Next>>>
- What are Autoregressive Generative Models?
- Building Blocks of Deep Generative Models
- Generative Learning and its Differences from the Discriminative Learning
- How Does a Generative Learning Model Work?
- Deep Learning with PyTorch: Introduction
- Deep Learning with PyTorch: First Neural Network
- AutoEncoders in Keras and Deep Learning (Introduction)
- Optimizers explained for training Neural Networks
- Optimizing TensorFlow models with Quantization Techniques