
Implementation of a Variational Autoencoder (VAE) from Scratch on MNIST Dataset

Written by Rojan

Deep learning has revolutionized the field of artificial intelligence, providing powerful tools to solve complex problems. Among these tools, Variational Autoencoders (VAEs) stand out for their ability to generate new data that's similar to the training data. In this post, we'll dive into the basics of VAEs and demonstrate how to implement one from scratch using TensorFlow, specifically focusing on the MNIST dataset of handwritten digits.


Understanding Variational Autoencoders


Introduction to Variational Autoencoders

Variational Autoencoders (VAEs) stand out as a fascinating intersection of deep learning and Bayesian inference. Unlike traditional autoencoders, VAEs aim not only to encode and decode data but also to learn the underlying probability distribution of the data. This capability makes VAEs powerful tools for generative tasks, such as creating new images or reconstructing missing data.


The Architecture of a VAE

A VAE comprises three main components: the encoder, the latent space, and the decoder. Let's dive into each component and see how they contribute to the functioning of a VAE.



1. Encoder

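The encoder processes the input data (denoted as x) and maps it to the parameters of a latent distribution: a mean (μ) and a standard deviation (σ). In practice the encoder usually outputs the log variance log σ² instead of σ, because σ must be positive while log σ² can take any real value. Together, these parameters describe where each data point lies in the latent space and how uncertain that placement is.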


2. Latent Space

The latent space is where the VAE learns to encode each input as a distribution rather than a single fixed point. Using the parameters from the encoder, the VAE samples a point z from this distribution with the reparameterization trick: z = μ + σ ⊙ ε, where ε is drawn from a standard normal distribution and ⊙ denotes element-wise multiplication. Because all of the randomness is moved into ε, z becomes a deterministic, differentiable function of μ and σ, so gradients can flow back through the sampling step into the encoder. This is what makes training with backpropagation feasible.
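The reparameterization step can be sketched in NumPy as follows. This is an illustration only: `reparameterize` here is a hypothetical standalone helper, and, like the VAE later in this post, it assumes the encoder outputs a log variance rather than σ directly. Drawing many samples and checking their mean and standard deviation confirms that z follows N(μ, σ²).

```python
import numpy as np

def reparameterize(mu, log_var, rng):
    """Return z = mu + sigma * eps with eps ~ N(0, I), where sigma = exp(0.5 * log_var)."""
    eps = rng.standard_normal(np.shape(mu))  # all randomness lives in eps
    sigma = np.exp(0.5 * log_var)            # log variance -> standard deviation
    return mu + sigma * eps                  # element-wise product, i.e. sigma ⊙ eps

rng = np.random.default_rng(0)
mu = np.array([1.0, -2.0])
log_var = np.log(np.array([0.25, 4.0]))      # variances 0.25 and 4, so sigma = 0.5 and 2

# Draw 20,000 samples at once by tiling the parameters along a batch axis
n = 20_000
samples = reparameterize(np.tile(mu, (n, 1)), np.tile(log_var, (n, 1)), rng)
print(samples.mean(axis=0))  # close to [1, -2]
print(samples.std(axis=0))   # close to [0.5, 2]
```

Since eps is sampled independently of mu and log_var, z is differentiable with respect to both, which is exactly what lets `tf.GradientTape` push gradients into the encoder in the TensorFlow version below.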


3. Decoder

The decoder takes a sampled latent point z and maps it back to the data space, producing a reconstruction x̂ of the input x. The quality of this reconstruction is critical, as it shows how well the VAE has learned to reproduce the input data.



Understanding the Loss Function

The VAE's effectiveness hinges on a carefully designed loss function, which has two main components:


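  • Reconstruction Loss: This term, $\lVert x - \hat{x} \rVert^2$, measures how closely the reconstruction x̂ matches the original input x. A lower reconstruction loss means the decoder reproduces its inputs more faithfully. For pixel values in [0, 1], binary cross-entropy is a common alternative, and it is the one the VAE code below uses.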


  • KL Divergence Loss: This term, $D_{KL} = \frac{1}{2}\sum_i \left(\sigma_i^2 + \mu_i^2 - 1 - \log \sigma_i^2\right)$, measures how far the distribution produced by the encoder, N(μ, σ²), deviates from a predefined prior, typically the standard normal distribution N(0, I). It acts as a regularizer: keeping the latent distributions close to the prior means that points sampled from N(0, I) decode to plausible data.
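Both loss terms can be written in a few lines of NumPy. This is a minimal sketch for illustration; the helper names are hypothetical, and each function sums over pixels or latent dimensions to give one value per example.

```python
import numpy as np

def squared_error(x, x_hat):
    """Reconstruction loss ||x - x_hat||^2, summed over pixels (one value per example)."""
    return np.sum((x - x_hat) ** 2, axis=-1)

def binary_cross_entropy(x, x_hat, eps=1e-7):
    """Pixel-wise binary cross-entropy summed over pixels; x and x_hat lie in [0, 1]."""
    x_hat = np.clip(x_hat, eps, 1 - eps)  # avoid log(0)
    return -np.sum(x * np.log(x_hat) + (1 - x) * np.log(1 - x_hat), axis=-1)

def kl_to_standard_normal(mu, log_var):
    """Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dimensions."""
    return 0.5 * np.sum(np.exp(log_var) + mu ** 2 - 1.0 - log_var, axis=-1)

# A latent distribution that already matches the prior costs nothing...
print(kl_to_standard_normal(np.zeros(8), np.zeros(8)))  # 0.0
# ...while shifting each of the 8 means by 1 costs 0.5 per dimension.
print(kl_to_standard_normal(np.ones(8), np.zeros(8)))   # 4.0
```

The KL term is zero exactly when μ = 0 and σ = 1 in every dimension, which is why it pulls the encoder's distributions toward the prior.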


The MNIST Dataset

The MNIST dataset is a classic in the machine learning community: 70,000 grayscale 28x28-pixel images of handwritten digits, split into 60,000 training images and 10,000 test images. It's widely used for training and testing machine learning models, and its small size and simple structure make it an ideal candidate for our VAE implementation.



Setting Up the Environment

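Before diving into the code, make sure TensorFlow, NumPy, and Matplotlib are installed in your environment. The plot_model utility used later additionally requires the pydot package and the Graphviz software. TensorFlow offers a comprehensive ecosystem of tools and libraries for building and deploying machine learning models.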


import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape  # Reshape is used by the VAE decoder
from tensorflow.keras.utils import plot_model  # requires pydot and Graphviz
import numpy as np
import gzip
from urllib import request
import matplotlib.pyplot as plt

Loading and Preprocessing the Data

Our first step involves downloading and loading the MNIST dataset, then preprocessing it for our model:

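# Download the files.
# Note: the original host may refuse direct downloads; if it does,
# tf.keras.datasets.mnist.load_data() provides the same data as uint8 arrays of shape (N, 28, 28).
url = "http://yann.lecun.com/exdb/mnist/"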
filenames = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz',
             't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
data = []
for filename in filenames:
    print("Downloading", filename)
    request.urlretrieve(url + filename, filename)
    with gzip.open(filename, 'rb') as f:
        if 'labels' in filename:
            # Load the labels as a one-dimensional array of integers
            data.append(np.frombuffer(f.read(), np.uint8, offset=8))
        else:
            # Load the images as a two-dimensional array of pixels
            data.append(np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28*28))
# Split into training and testing sets
X_train, y_train, X_test, y_test = data

# Normalize the pixel values
X_train = X_train.astype(np.float32) / 255.0
X_test = X_test.astype(np.float32) / 255.0
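The `offset=8` and `offset=16` arguments above skip the IDX file headers. To make that explicit, here is a sketch of a hypothetical `parse_idx` helper that reads the header instead of hard-coding its size. It assumes the unsigned-byte IDX layout used by the MNIST files (a 4-byte magic number whose last byte is the number of dimensions, followed by one big-endian 32-bit size per dimension) and is tested on small synthetic files rather than the real download.

```python
import struct
import numpy as np

def parse_idx(raw):
    """Parse the bytes of a decompressed unsigned-byte IDX file into a NumPy array."""
    zeros, dtype_code, ndim = struct.unpack(">HBB", raw[:4])  # magic number
    if zeros != 0 or dtype_code != 0x08:
        raise ValueError("expected an unsigned-byte IDX file")
    shape = struct.unpack(">" + "I" * ndim, raw[4:4 + 4 * ndim])  # one size per dimension
    header_size = 4 + 4 * ndim  # 8 bytes for labels (ndim=1), 16 for images (ndim=3)
    return np.frombuffer(raw, np.uint8, offset=header_size).reshape(shape)

# A tiny synthetic "images" file: 2 images of 28x28 pixels
pixels = (np.arange(2 * 28 * 28) % 256).astype(np.uint8)
fake_images = struct.pack(">HBB3I", 0, 0x08, 3, 2, 28, 28) + pixels.tobytes()
parsed = parse_idx(fake_images)
print(parsed.shape)  # (2, 28, 28)
```

The loader above flattens each image to 784 values with `reshape(-1, 28*28)`; `parse_idx` keeps the 28x28 shape stored in the header instead.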

We now have our training and test sets. Before building a model, let's look at a few of the images:

# Function to display images and their labels
def show_images(images, labels):
    """
    Display a set of images and their labels using matplotlib.
    """
    fig, axs = plt.subplots(ncols=len(images), nrows=1, figsize=(2 * len(images), 3))  # one row of images
    for i in range(len(images)):
        axs[i].imshow(images[i], cmap="gray")
        axs[i].set_title("Label: {}".format(labels[i]))
        axs[i].set_xticks([])
        axs[i].set_yticks([])
        axs[i].set_xlabel("Index: {}".format(i))
    fig.subplots_adjust(hspace=0.5)
    plt.show()

# Select and display the first 10 images and labels
images = X_train[:10] # First 10 images
labels = y_train[:10] # First 10 labels
show_images(images.reshape(-1, 28, 28), labels)

Building the Autoencoder

We start with a simple autoencoder model to establish a baseline:

# Define the Autoencoder model using TensorFlow
class AutoEncoder(tf.keras.Model):
    def __init__(self):
        super(AutoEncoder, self).__init__()

        # Size of the latent (bottleneck) representation
        self.num_hidden = 8

        # Define the encoder part of the autoencoder
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.Flatten(input_shape=(28, 28)),  # Flatten the input
            tf.keras.layers.Dense(256, activation='relu'),  # Fully connected layer with ReLU activation
            tf.keras.layers.Dense(self.num_hidden, activation='relu')  # Fully connected layer with ReLU activation
        ])

        # Define the decoder part of the autoencoder
        self.decoder = tf.keras.Sequential([
            tf.keras.layers.Dense(256, activation='relu'),  # Fully connected layer with ReLU activation
            tf.keras.layers.Dense(784, activation='sigmoid'),  # Fully connected layer with sigmoid activation
            tf.keras.layers.Reshape((28, 28))  # Reshape back to image dimensions
        ])

    def call(self, x):
        # Pass the input through the encoder
        encoded = self.encoder(x)
        # Pass the encoded representation through the decoder
        decoded = self.decoder(encoded)
        return encoded, decoded

# Initialize the model
autoencoder = AutoEncoder()

# Define optimizer and loss function
optimizer = tf.keras.optimizers.Adam()
mse_loss_fn = tf.keras.losses.MeanSquaredError()

# Training loop
epochs = 5
batch_size = 32
num_batches = len(X_train) // batch_size
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    total_loss = 0.0
    for step in range(num_batches):
        x_batch_train = X_train[step * batch_size : (step + 1) * batch_size]
        x_batch_train = tf.reshape(x_batch_train, [-1, 28, 28])  # Reshape the input batch
        with tf.GradientTape() as tape:
            _, decoded = autoencoder(x_batch_train)
            loss = mse_loss_fn(x_batch_train, decoded)
        grads = tape.gradient(loss, autoencoder.trainable_variables)
        optimizer.apply_gradients(zip(grads, autoencoder.trainable_variables))
        total_loss += float(loss)

    # Report the average loss over all batches, not just the last one
    print("Epoch %d: Average training loss: %.4f" % (epoch, total_loss / num_batches))

# Reconstruction on test set
_, decoded_images = autoencoder(X_test.reshape(-1, 28, 28))  # Reshape test set images
show_images(decoded_images.numpy()[:10], y_test[:10])


[Figure: autoencoder reconstructions of the first ten test images]


Advancing to Variational Autoencoder

Next, we extend our model to a Variational Autoencoder:

# Variational Autoencoder (VAE) definition
class VAE(tf.keras.Model):
    def __init__(self):
        super(VAE, self).__init__()
        self.num_hidden = 8
        self.encoder = Sequential([
            Flatten(input_shape=(28, 28)),
            Dense(256, activation='relu'),
            Dense(self.num_hidden * 2)  # For mean and log variance
        ])
        self.decoder = Sequential([
            Dense(256, activation='relu'),
            Dense(784, activation='sigmoid'),
            Reshape((28, 28))
        ])

    def encode(self, x):
        encoded = self.encoder(x)
        mean, logvar = tf.split(encoded, num_or_size_splits=2, axis=1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        epsilon = tf.random.normal(shape=tf.shape(mean))  # dynamic shape also works when the batch size is unknown
        return epsilon * tf.exp(logvar * 0.5) + mean

    def decode(self, z):
        return self.decoder(z)

    def call(self, x):
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        decoded = self.decode(z)
        return decoded, mean, logvar

# VAE model initialization, optimizer, and loss function
vae = VAE()
optimizer = tf.keras.optimizers.Adam()
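def vae_loss(x, decoded, mean, logvar):
    # Binary cross-entropy averaged over every pixel in the batch
    reconstruction_loss = tf.reduce_mean(
        tf.keras.losses.binary_crossentropy(tf.reshape(x, [-1]),
                                            tf.reshape(decoded, [-1])))
    # Closed-form KL divergence to N(0, I), averaged over latent dimensions and the batch
    kl_loss = -0.5 * tf.reduce_mean(1 + logvar - tf.square(mean) - tf.exp(logvar))
    # Hand-picked weights for the two terms
    return 1.75 * reconstruction_loss + kl_loss * 0.1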
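Because `vae_loss` averages over every pixel and every latent dimension, its scale differs from the usual negative ELBO, which sums over them. The NumPy sketch below uses random stand-in arrays (not real model outputs) to check that relationship, and then computes the effective weight of the KL term: relative to the reconstruction term, the 1.75 and 0.1 factors weight the KL divergence about 5.6 times more heavily than the plain ELBO does, similar in spirit to a β-VAE with β ≈ 5.6.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, pixels, latent = 4, 784, 8

# Random stand-ins for a batch of images, reconstructions and encoder outputs
x = rng.random((batch, pixels))
x_hat = np.clip(rng.random((batch, pixels)), 1e-7, 1 - 1e-7)
mean = rng.normal(size=(batch, latent))
logvar = rng.normal(size=(batch, latent))

bce = -(x * np.log(x_hat) + (1 - x) * np.log(1 - x_hat))  # per pixel
kl = -0.5 * (1 + logvar - mean ** 2 - np.exp(logvar))     # per latent dimension

# Mean-based terms, as in vae_loss
recon_mean, kl_mean = bce.mean(), kl.mean()
# Standard negative ELBO: sum over pixels and latent dims, average over the batch
neg_elbo = bce.sum(axis=1).mean() + kl.sum(axis=1).mean()

# The mean-based terms are the summed terms divided by 784 and 8 respectively
print(np.isclose(recon_mean * pixels + kl_mean * latent, neg_elbo))  # True

# Effective KL weight relative to reconstruction (the plain ELBO uses 1)
effective_beta = (0.1 / latent) / (1.75 / pixels)
print(round(effective_beta, 2))  # 5.6
```

A stronger KL weight usually gives a smoother, more prior-like latent space at the cost of blurrier reconstructions, so these two factors are worth tuning.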

Let's look at the architecture of the VAE encoder:

# Visualize the encoder
plot_model(vae.encoder, to_file='vae_encoder_plot.png', show_shapes=True, show_layer_names=True)

Next, let's look at the VAE decoder architecture. Because the decoder's first layer does not specify an input shape, we build it explicitly before plotting:

# Explicitly build the decoder model with the expected input shape
vae.decoder.build((None, vae.num_hidden))  # 'None' can be used for batch size flexibility

# Now you can plot the decoder model
plot_model(vae.decoder, to_file='vae_decoder_plot.png', show_shapes=True, show_layer_names=True)
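As a quick check on these diagrams, the number of trainable parameters in each half follows directly from the layer sizes: a Dense layer with n_in inputs and n_out units has n_in × n_out weights plus n_out biases. The arithmetic below (with a hypothetical `dense_params` helper) should match the totals that `vae.encoder.summary()` and `vae.decoder.summary()` report for the architecture defined above.

```python
def dense_params(n_in, n_out):
    """Trainable parameters of a Dense layer: a weight matrix plus one bias per unit."""
    return n_in * n_out + n_out

latent_dim = 8  # matches self.num_hidden in the VAE class

# Flatten and Reshape have no trainable parameters, so only the Dense layers count.
encoder_params = dense_params(28 * 28, 256) + dense_params(256, 2 * latent_dim)
decoder_params = dense_params(latent_dim, 256) + dense_params(256, 28 * 28)
print(encoder_params, decoder_params)  # 205072 203792
```

Almost all parameters sit in the two 784 × 256 weight matrices; the 8-dimensional bottleneck itself is cheap.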

Training and Evaluation

With our model defined, we proceed to train it and evaluate its performance:

# Training loop for the VAE
epochs = 10
batch_size = 32
vae_loss_values = []
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    total_loss = 0
    for step in range(len(X_train) // batch_size):
        x_batch_train = X_train[step * batch_size : (step + 1) * batch_size]
        x_batch_train = tf.reshape(x_batch_train, [-1, 28, 28])
        with tf.GradientTape() as tape:
            decoded, mean, logvar = vae(x_batch_train)
            loss = vae_loss(x_batch_train, decoded, mean, logvar)
        grads = tape.gradient(loss, vae.trainable_variables)
        optimizer.apply_gradients(zip(grads, vae.trainable_variables))
        total_loss += loss.numpy()
    avg_loss = total_loss / (len(X_train) // batch_size)
    vae_loss_values.append(avg_loss)
    print("Epoch %d: Average training loss: %.4f" % (epoch, avg_loss))




Visualizing the Results

To judge the quality of our VAE, we visualize its reconstructions of test images, which can be compared with the original test digits. The latent space itself can also be visualized by projecting the encoder's 8-dimensional means to two dimensions with t-SNE, a nonlinear dimensionality-reduction technique. The reconstructions are produced as follows:

# Reconstruction on test set
decoded_images, _, _ = vae(X_test.reshape(-1, 28, 28))  # Reshape test set images
show_images(decoded_images.numpy()[:10], y_test[:10])
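Reconstruction is only half of what a VAE offers: because the KL term keeps codes close to N(0, I), new digits can also be generated by sampling from the prior and decoding. The sketch below is an illustrative addition, not code from the original post. `sample_prior` and `interpolate` are hypothetical NumPy helpers, and the commented lines show how their output would be passed to the `vae.decode` method defined earlier.

```python
import numpy as np

def sample_prior(n, latent_dim, rng):
    """Draw n latent vectors from the standard normal prior N(0, I)."""
    return rng.standard_normal((n, latent_dim)).astype(np.float32)

def interpolate(z_start, z_end, steps):
    """Evenly spaced points on the straight line from z_start to z_end (both included)."""
    t = np.linspace(0.0, 1.0, steps, dtype=np.float32)[:, None]  # shape (steps, 1)
    return (1.0 - t) * z_start + t * z_end

rng = np.random.default_rng(42)
z = sample_prior(10, 8, rng)             # 10 codes in the 8-dimensional latent space
path = interpolate(z[0], z[1], steps=7)
print(z.shape, path.shape)              # (10, 8) (7, 8)

# With the trained model from this post:
#   generated = vae.decode(z).numpy()   # shape (10, 28, 28)
#   show_images(generated, ["sample"] * 10)
#   show_images(vae.decode(path).numpy(), ["step"] * 7)
```

Decoding the interpolation path shows one digit gradually morphing into another, a common way to check that the latent space is smooth.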

Finally, we plot the VAE's average training loss over the epochs:

plt.figure(figsize=(10, 5))
plt.plot(vae_loss_values, label='VAE Loss')
plt.title('Training Losses')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()
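The t-SNE projection mentioned above is not part of the code listings in this post. As a swapped-in, dependency-free alternative, this sketch uses PCA (not t-SNE), computed with NumPy's SVD, to project 8-dimensional latent means to 2-D. `pca_2d` is a hypothetical helper, demonstrated here on synthetic clusters. Unlike t-SNE, PCA is linear: it preserves global structure but may separate digit classes less clearly.

```python
import numpy as np

def pca_2d(latents):
    """Project (n, d) latent vectors onto their top-2 principal components."""
    centered = latents - latents.mean(axis=0)  # PCA operates on centered data
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T                 # shape (n, 2)

# Synthetic stand-in for encoder means: two clusters in an 8-D latent space
rng = np.random.default_rng(0)
cluster_a = rng.normal(loc=-3.0, scale=0.3, size=(50, 8))
cluster_b = rng.normal(loc=3.0, scale=0.3, size=(50, 8))
coords = pca_2d(np.vstack([cluster_a, cluster_b]))
print(coords.shape)  # (100, 2)

# With the trained model from this post:
#   means, _ = vae.encode(X_test.reshape(-1, 28, 28))
#   coords = pca_2d(means.numpy())
#   plt.scatter(coords[:, 0], coords[:, 1], c=y_test, cmap="tab10", s=2)
#   plt.colorbar(); plt.show()
```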

Conclusion

Through this exercise, we've seen how to implement a Variational Autoencoder from scratch with TensorFlow and apply it to the MNIST dataset. VAEs offer a fascinating glimpse into the world of generative models, capable of learning deep representations of data and generating new instances from those representations. By experimenting with different architectures and datasets, you can explore the full potential of VAEs in generating diverse and complex data.




© 2021 by Faezeh Maghsoodifar.