Training an image captioning model with visual attention. An example of combining 2 modalities into a single model.
This post is a summary of this Google tutorial.
NOTE: This post is intended for developers. If you are an aspiring data scientist or AI researcher, this post will not dig deep enough for you.
Overview
Image captioning models take an image as input and generate text describing the image.
The challenge with this type of model is bringing together the visual space and the textual space: we need to map these 2 modalities onto common ground.
One way of doing this is by using an Encoder-Decoder architecture.
- Encoder - Takes an image as input and outputs an embedding, a numeric representation capturing the "essence" of the image.
- Decoder - Generates text based on the embeddings it receives from the encoder. The "joint learning" of the visual features and the text happens in an Attention layer.
Architecture
Encoder
We will use a pre-trained InceptionResNetV2 as our encoder, or feature extractor. InceptionResNetV2 is an image classification model; by taking the output of an intermediate layer (the last convolutional block, not the final classification layer) we get features that represent the image.
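To make the encoder's shapes concrete, here is a minimal NumPy sketch of the pipeline the post builds later in "Build the model": the 8x8x1536 feature grid is flattened into 64 region vectors, then projected to 512 dimensions by a ReLU Dense layer. Random arrays stand in for the real InceptionResNetV2 output and the trained Dense weights; nothing here loads the actual model.

```python
import numpy as np

# Shapes taken from the post: InceptionResNetV2 (include_top=False) on a
# 299x299 image gives an 8x8 grid of 1536-channel feature vectors.
GRID_H, GRID_W, CHANNELS = 8, 8, 1536
ATTENTION_DIM = 512

rng = np.random.default_rng(0)

# Random stand-in for the CNN output of one image (not real features).
features = rng.standard_normal((GRID_H, GRID_W, CHANNELS)).astype(np.float32)

# Flatten the spatial grid: each of the 64 rows now describes one image region.
regions = features.reshape(GRID_H * GRID_W, CHANNELS)

# Dense(ATTENTION_DIM, activation="relu") is a matmul + bias + ReLU.
# These weights are random placeholders for the trained ones.
W = rng.standard_normal((CHANNELS, ATTENTION_DIM)).astype(np.float32) * 0.01
b = np.zeros(ATTENTION_DIM, dtype=np.float32)
encoder_output = np.maximum(regions @ W + b, 0.0)

print(regions.shape)         # (64, 1536)
print(encoder_output.shape)  # (64, 512)
```

Row 9 of `regions` is grid cell (1, 1), because the reshape walks the grid row by row; the decoder's attention later chooses among these 64 rows.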
Decoder
The decoder takes as input the image features from the encoder and the image caption, and processes them through the following layers:
- Embedding - Each word of the caption is embedded into a vector that captures the "meaning" of the word.
- RNN - An RNN layer processes the caption embeddings word by word and keeps a "memory" of the previously processed words.
- Attention - The attention layer takes as input the RNN output and the image features. This is where the image data and the text data are processed together, and where the relation between image features and text is learned.
Then the Attention output and the RNN output are added together, normalized, and run through a final Dense layer to produce the next-word "probabilities": one score for every word in the vocabulary.
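The attention, residual add, normalization and output projection can be sketched in NumPy. This assumes the default behavior of Keras' `Attention` layer when called with `[query, value]`: unscaled dot-product scores, a softmax over the image regions, and the value tensor doubling as the key. The learned scale and shift of `LayerNormalization` are omitted, and all sizes and weights are toy placeholders.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def layer_norm(x, eps=1e-3):
    # Normalize each position's vector to zero mean and unit variance
    # (a trained LayerNormalization also applies a learned scale and shift).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def decoder_head(rnn_output, image_features, W_out, b_out):
    """rnn_output: (T, d), one vector per caption position.
    image_features: (R, d), one vector per image region.
    Returns (T, vocab) logits and the (T, R) attention weights."""
    # Dot-product attention: how well each word state matches each region.
    scores = rnn_output @ image_features.T      # (T, R)
    weights = softmax(scores, axis=-1)          # each row sums to 1
    context = weights @ image_features          # (T, d)
    # Residual add, normalize, then one score per vocabulary word.
    hidden = layer_norm(rnn_output + context)   # (T, d)
    logits = hidden @ W_out + b_out             # (T, vocab)
    return logits, weights


rng = np.random.default_rng(1)
T, R, d, vocab = 5, 64, 16, 100  # toy sizes, not the post's real ones
logits, weights = decoder_head(
    rng.standard_normal((T, d)),
    rng.standard_normal((R, d)),
    rng.standard_normal((d, vocab)),
    np.zeros(vocab),
)
print(logits.shape, weights.shape)  # (5, 100) (5, 64)
```

Each row of `weights` is a distribution over the 64 image regions, so every generated word can "look at" a different part of the image.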
Training & Inference
The training (and later inference) process is very similar to the one described in the Spanish-English translation encoder-decoder post.
The Code
Dependencies
!pip uninstall -y tensorflow
!pip uninstall -y tf-keras
!pip install tensorflow==2.15.1
!pip install tf-keras==2.15.1
Imports
import time
from textwrap import wrap
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub
from tensorflow.keras import Input
from tensorflow.keras.layers import (
GRU,
Add,
AdditiveAttention,
Attention,
Concatenate,
Dense,
Embedding,
LayerNormalization,
Reshape,
StringLookup,
TextVectorization,
)
print(tf.version.VERSION)
Loading Data
IMG_HEIGHT = 299
IMG_WIDTH = 299
BUFFER_SIZE = 1000
def get_image_label(record):
    # Each record is a dict with keys: ['captions', 'image', 'image/filename', 'image/id', 'objects']
    img = record["image"]
    img = tf.image.resize(img, (IMG_HEIGHT, IMG_WIDTH))
    img = img / 255  # convert RGB values to the 0-1 range
    caption = record["captions"]["text"][0]  # only the first caption per image
    # Add the special <start> and <end> tokens for the decoder to use
    start = tf.convert_to_tensor("<start>")
    end = tf.convert_to_tensor("<end>")
    caption = tf.strings.join(
        [start, caption, end], separator=" "
    )
    return {"image_tensor": img, "caption": caption}
# Load the dataset.
# The dataset is huge, so we read it directly from a public Google Cloud Storage bucket.
# The bucket is located in us-central1; if the machine runs in another region,
# working with the data will be very slow.
data_dir = 'gs://asl-public/data/tensorflow_datasets/'
# Another option is to download the data locally.
# ** DID NOT WORK IN COLAB **
# The download size is 56GB!!
# If you want to download locally create the folder /content/data
# data_dir='/content/data'
train_dataset = tfds.load("coco_captions", split="train", shuffle_files=True, data_dir=data_dir)
# get only the image and caption
train_dataset = train_dataset.map(
get_image_label, num_parallel_calls=tf.data.AUTOTUNE
)
# prefetch loads upcoming data while the current data is being processed
train_dataset = train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
Visualize
f, ax = plt.subplots(1, 4, figsize=(20, 5))
for idx, data in enumerate(train_dataset.take(4)):
    ax[idx].imshow(data["image_tensor"].numpy())
    caption = "\n".join(wrap(data["caption"].numpy().decode("utf-8"), 30))
    ax[idx].set_title(caption)
    ax[idx].axis("off")
Create a Tokenizer
MAX_CAPTION_LEN = 64
VOCAB_SIZE = 20000 # use fewer words to speed up convergence
# We override the default standardization of TextVectorization to keep the
# "<" and ">" characters, so the <start> and <end> tokens survive.
def standardize(inputs):
    inputs = tf.strings.lower(inputs)
    return tf.strings.regex_replace(
        inputs, r"[!\"#$%&\(\)\*\+.,-/:;=?@\[\\\]^_`{|}~]?", ""
    )
# Choose the most frequent words from the vocabulary & remove punctuation etc.
tokenizer = TextVectorization(
max_tokens=VOCAB_SIZE,
standardize=standardize,
output_sequence_length=MAX_CAPTION_LEN,
)
tokenizer.adapt(train_dataset.map(lambda x: x["caption"]))
# Lookup table: Word -> Index
word_to_index = StringLookup(
mask_token="", vocabulary=tokenizer.get_vocabulary()
)
# Lookup table: Index -> Word
index_to_word = StringLookup(
mask_token="", vocabulary=tokenizer.get_vocabulary(), invert=True
)
# Tokenize the first caption word by word; note that tokens 3 and 4 are <start> and <end>
for d in train_dataset.take(1):
    for w in d["caption"].numpy().decode("utf-8").split():
        print(word_to_index(w))
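`TextVectorization` and `StringLookup` hide a few details: standardization, a frequency-ranked vocabulary with reserved slots, and fixed-length padding. The pure-Python sketch below approximates that pipeline for illustration only. It assumes the Keras defaults of index 0 for the padding token `""` and index 1 for `"[UNK]"`, with `max_tokens` counting both reserved slots. The alphabetical tie-breaking is an arbitrary choice of this sketch and may differ from Keras, which is one reason the toy ids below differ from the 3 and 4 seen in the post. `standardize_py`, `build_vocab` and `vectorize` are hypothetical helper names.

```python
import re
from collections import Counter

# Same character class as the post's TensorFlow regex. "<" and ">" are absent,
# so "<start>" and "<end>" survive standardization.
STRIP_PUNCT = r"[!\"#$%&\(\)\*\+.,-/:;=?@\[\\\]^_`{|}~]?"


def standardize_py(text):
    """Lowercase and strip punctuation, like the post's `standardize`."""
    return re.sub(STRIP_PUNCT, "", text.lower())


def build_vocab(captions, max_tokens):
    """Index 0 = padding "", index 1 = "[UNK]", then the most frequent words."""
    counts = Counter(w for c in captions for w in standardize_py(c).split())
    # Most frequent first; ties broken alphabetically (an arbitrary choice here).
    ranked = sorted(counts, key=lambda w: (-counts[w], w))
    return ["", "[UNK]"] + ranked[: max_tokens - 2]


def vectorize(caption, vocab, seq_len):
    """Map words to ids (unknown words -> 1), then pad with 0 / truncate to seq_len."""
    lookup = {w: i for i, w in enumerate(vocab)}
    ids = [lookup.get(w, 1) for w in standardize_py(caption).split()]
    return (ids + [0] * seq_len)[:seq_len]


captions = ["<start> A dog on a beach. <end>", "<start> A cat on a sofa! <end>"]
vocab = build_vocab(captions, max_tokens=8)
print(standardize_py(captions[0]))  # <start> a dog on a beach <end>
print(vocab)  # ['', '[UNK]', 'a', '<end>', '<start>', 'on', 'beach', 'cat']
print(vectorize("<start> A zebra on a beach <end>", vocab, seq_len=10))
# [4, 2, 1, 5, 2, 6, 3, 0, 0, 0]
```

"zebra" was never seen, so it maps to 1 (`[UNK]`), and the trailing zeros are the padding that the loss function later has to ignore.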
Prepare training dataset
BATCH_SIZE = 32
def create_ds_fn(record):
    img_tensor = record["image_tensor"]
    caption = tokenizer(record["caption"])  # tokenize the caption
    # Create the "target" training objective, which is the caption without the <start> token.
    target = tf.roll(caption, -1, 0)  # shift the tokens left to remove the <start> token
    zeros = tf.zeros([1], dtype=tf.int64)
    # roll is cyclic, so the <start> token is now the last element; replace it with 0
    target = tf.concat((target[:-1], zeros), axis=-1)
    # The input img_tensor goes to the encoder, the caption to the decoder,
    # and the "target" is our training objective
    return (img_tensor, caption), target
batched_ds = (
train_dataset.map(create_ds_fn)
.batch(BATCH_SIZE, drop_remainder=True)
.prefetch(buffer_size=tf.data.AUTOTUNE)
)
# Print a sample
for (img, caption), label in batched_ds.take(2):
    print(f"Image shape: {img.shape}")
    print(f"Caption shape: {caption.shape}")
    print(f"Label shape: {label.shape}")
    print(caption[0])
    print(label[0])
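The shift in `create_ds_fn` sets up teacher forcing: at position t the decoder sees the true word t of the caption and is trained to predict word t+1. Here is a NumPy sketch of the same roll-then-zero logic on a single toy caption; the ids (3 = `<start>`, 4 = `<end>`, 0 = padding) are illustrative values, and `make_target` is a hypothetical name.

```python
import numpy as np


def make_target(caption_ids):
    """Mirror of the roll-and-zero target logic in create_ds_fn, for one caption."""
    target = np.roll(caption_ids, -1)  # shift every token one step left (cyclic)
    target[-1] = 0  # the wrapped-around <start> token becomes padding
    return target


# Toy ids: 3 = <start>, 4 = <end>, 0 = padding
caption = np.array([3, 2, 17, 9, 4, 0, 0, 0])
target = make_target(caption)
print(target.tolist())  # [2, 17, 9, 4, 0, 0, 0, 0]
```

Reading the two arrays column by column gives the training pairs: input `<start>` should predict 2, input 2 should predict 17, and so on until `<end>`.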
Build the model
# InceptionResNetV2 takes (299, 299, 3) image as inputs
# note we use include_top=False meaning we don't want to include the final layer
# and what we get is the extracted features with shape (8, 8, 1536)
FEATURE_EXTRACTOR = tf.keras.applications.inception_resnet_v2.InceptionResNetV2(
include_top=False, weights="imagenet"
)
FEATURE_EXTRACTOR.trainable = False
FEATURES_SHAPE = (8, 8, 1536)
ATTENTION_DIM = 512 # size of dense layer in Attention
IMG_CHANNELS = 3
# --- ENCODER---
## Input Layer (image)
image_input = Input(shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), name="input_image")
# Feature extractor layer
image_features = FEATURE_EXTRACTOR(image_input)
# reshape the features output to 2D: (64, 1536)
x = Reshape((FEATURES_SHAPE[0] * FEATURES_SHAPE[1], FEATURES_SHAPE[2]))(
image_features
)
# Dense output layer, this will be the input to the decoder's attention
encoder_output = Dense(ATTENTION_DIM, activation="relu")(x)
encoder = tf.keras.Model(inputs=image_input, outputs=encoder_output)
# --- DECODER ---
## Input layer (image caption)
word_input = Input(shape=(MAX_CAPTION_LEN,), name="words")
## Embeddings layer
embed_x = Embedding(VOCAB_SIZE, ATTENTION_DIM)(word_input)
# RNN
decoder_gru = GRU(
ATTENTION_DIM,
return_sequences=True,
return_state=True,
)
rnn_output, rnn_state = decoder_gru(embed_x)
## Attention layer
decoder_attention = Attention()
context_vector = decoder_attention([rnn_output, encoder_output])
## Add rnn + attention
addition = Add()([rnn_output, context_vector])
## Normalization
layer_norm = LayerNormalization(axis=-1)
layer_norm_out = layer_norm(addition)
## Dense output layer
decoder_output_dense = Dense(VOCAB_SIZE)
decoder_output = decoder_output_dense(layer_norm_out)
decoder = tf.keras.Model(
inputs=[word_input, encoder_output], outputs=decoder_output
)
# --- The Model ---
image_caption_train_model = tf.keras.Model(
inputs=[image_input, word_input], outputs=decoder_output
)
image_caption_train_model.summary()
tf.keras.utils.plot_model(image_caption_train_model, show_shapes=True)
Training the model
Loss Function
All caption vectors have the same length, so most captions end with some 0 padding. We don't want to compute the loss on the padding, so our custom loss function computes the loss for every position in the caption, zeroes out the padded positions, and averages only over the positions that hold an actual word.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction="none"
)
def loss_function(real, pred):
    loss_ = loss_object(real, pred)  # per-token loss, shape (batch, MAX_CAPTION_LEN)
    # 1 for word positions and 0 for padding (e.g. [1,1,1,1,1,0,0,0,0,...,0])
    mask = tf.cast(tf.math.not_equal(real, 0), dtype=loss_.dtype)
    # Zero out the padding positions, then average over the real words only
    loss_ = loss_ * mask
    return tf.reduce_sum(loss_) / tf.reduce_sum(mask)
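To check the masking numerically, here is a NumPy sketch of a masked sparse categorical cross-entropy computed from logits. It is an illustration written for this post's setup (0 = padding), not TensorFlow's implementation; `masked_sparse_ce` is a hypothetical name. The padded slots are given logits that make them very wrong, yet they do not change the result.

```python
import numpy as np


def masked_sparse_ce(real, logits):
    """real: (batch, T) int ids with 0 = padding; logits: (batch, T, vocab)."""
    # Per-token cross-entropy via a numerically stable log-softmax.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    token_loss = -np.take_along_axis(log_probs, real[..., None], axis=-1)[..., 0]
    # 1.0 for real words, 0.0 for padding, so padded positions contribute nothing.
    mask = (real != 0).astype(token_loss.dtype)
    return (token_loss * mask).sum() / mask.sum()


# One caption of 4 slots: two real words followed by two padding slots.
real = np.array([[5, 2, 0, 0]])
logits = np.zeros((1, 4, 8))  # uniform logits: each real word costs log(8)
logits[0, 2:, 3] = 10.0  # padding slots confidently predict the wrong id (~10 loss each)
print(round(float(masked_sparse_ce(real, logits)), 4))  # 2.0794, i.e. log(8)
```

Without the mask, the two padded slots would pull the average up to roughly 6, and the model would be rewarded for learning to predict padding.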
Training
%%time
image_caption_train_model.compile(
optimizer="adam",
loss=loss_function,
)
history = image_caption_train_model.fit(batched_ds, epochs=1)
Inference
Inference is a little different from training: during inference we call the decoder in a loop, generating words one by one. For the decoder to keep track of previously generated words, we have to manage its RNN hidden state ("memory"). On each call to the decoder we get the updated hidden state and feed it back to the decoder on the next iteration.
For that we build a new decoder model. Note that we reuse the layers from training, because we want to use the trained weights.
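The state hand-off can be illustrated without TensorFlow. In this pure-Python sketch, `step_fn(token_id, state) -> (next_id, new_state)` is a hypothetical stand-in for one call to `decoder_pred_model` plus sampling, and `toy_step` is a made-up decoder whose only "memory" is how many words it has emitted.

```python
def generate(step_fn, start_id, end_id, initial_state, max_len):
    """Call step_fn in a loop, feeding back both the predicted id and the state."""
    token, state, out = start_id, initial_state, []
    for _ in range(max_len):
        token, state = step_fn(token, state)  # new state replaces the old one
        out.append(token)
        if token == end_id:
            break
    return out


# Toy "decoder": emits a fixed caption and uses its state as a word counter.
CAPTION = [10, 11, 12, 4]  # 4 plays the role of <end>


def toy_step(token, state):
    return CAPTION[state], state + 1


print(generate(toy_step, start_id=3, end_id=4, initial_state=0, max_len=64))
# [10, 11, 12, 4]
```

If the state were not fed back (for example, a step that always returns the initial state), the toy decoder would emit 10 on every step until `max_len`, which is the "no memory" failure the prediction model's extra state input avoids.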
# This is the hidden state we will have to provide after each iteration
rnn_state_input = Input(shape=(ATTENTION_DIM,), name="gru_state_input")
# Reuse trained GRU, but update it so that it can receive states.
rnn_output, rnn_state = decoder_gru(embed_x, initial_state=rnn_state_input)
# Reuse other layers as well
context_vector = decoder_attention([rnn_output, encoder_output])
addition_output = Add()([rnn_output, context_vector])
layer_norm_output = layer_norm(addition_output)
decoder_output = decoder_output_dense(layer_norm_output)
# Define prediction Model with state input and output
decoder_pred_model = tf.keras.Model(
inputs=[word_input, rnn_state_input, encoder_output],
outputs=[decoder_output, rnn_state],
)
Prediction process
- Initialize the GRU state as a zero vector.
- Preprocess an input image, pass it to the encoder, and extract the image features.
- Set the first decoder input to the <start> token to start captioning.
- In the for loop, we:
  - pass the word token (decoder_input), the GRU state (rnn_state) and the image features (features) to the prediction decoder, and get the predictions (predictions) and the updated GRU state.
  - select the top-K words from the logits and choose one of them probabilistically, so that we avoid computing a softmax over the whole VOCAB_SIZE-sized vector.
  - stop predicting when the model predicts the <end> token.
  - replace the input word token with the predicted word token for the next step.
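The top-K sampling step can be sketched in NumPy. The post's code uses `tf.math.top_k` followed by `tf.random.categorical`, which samples directly from unnormalized logits; this sketch makes the softmax over the K surviving logits explicit, which yields the same distribution. `sample_top_k` is a hypothetical name and the logits are made up.

```python
import numpy as np


def sample_top_k(logits, k, rng):
    """Sample one id from the k highest-scoring logits."""
    top_idxs = np.argpartition(logits, -k)[-k:]  # indices of the k largest, unordered
    top_logits = logits[top_idxs]
    # Softmax over only k values instead of the whole vocabulary.
    probs = np.exp(top_logits - top_logits.max())
    probs /= probs.sum()
    return int(top_idxs[rng.choice(k, p=probs)])


rng = np.random.default_rng(42)
logits = np.array([0.1, 3.0, -1.0, 2.5, 0.0, 2.9])
samples = [sample_top_k(logits, k=3, rng=rng) for _ in range(1000)]
print(sorted(set(samples)))  # only ids 1, 3 and 5 can ever be drawn
```

Sampling instead of always taking the argmax is why running the prediction five times gives five different captions; with `k=1` the sketch reduces to greedy decoding.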
MINIMUM_SENTENCE_LENGTH = 5
## Probabilistic prediction using the trained model
def predict_caption(filename):
    rnn_state = tf.zeros((1, ATTENTION_DIM))  # initial RNN state
    # Prepare the image for the encoder
    img = tf.image.decode_jpeg(tf.io.read_file(filename), channels=IMG_CHANNELS)
    img = tf.image.resize(img, (IMG_HEIGHT, IMG_WIDTH))
    img = img / 255
    # Run the encoder
    features = encoder(tf.expand_dims(img, axis=0))
    # Initial decoder input word
    decoder_input = tf.expand_dims([word_to_index("<start>")], 1)
    # Keep track of the generated words
    result = []
    result_ids = []
    for i in range(MAX_CAPTION_LEN):
        # Run the decoder
        predictions, rnn_state = decoder_pred_model(
            [decoder_input, rnn_state, features]
        )
        # Sample the next word from the 10 highest-scoring logits
        top_probs, top_idxs = tf.math.top_k(
            input=predictions[0][0], k=10, sorted=False
        )
        chosen_id = tf.random.categorical([top_probs], 1)[0].numpy()
        predicted_id = top_idxs.numpy()[chosen_id][0]
        # Convert the id back to a Python string so the words can be joined later
        result.append(index_to_word(predicted_id).numpy().decode("utf-8"))
        result_ids.append(predicted_id)
        if predicted_id == word_to_index("<end>"):
            return img, result
        # Use the newly generated id as input for the next decoder cycle
        decoder_input = tf.expand_dims([predicted_id], 1)
    return img, result
Let's caption!
filename = "./sample_data/baseball.jpeg"
# Generate 5 captions
for i in range(5):
    image, caption = predict_caption(filename)
    print(" ".join(caption[:-1]) + ".")
img = tf.image.decode_jpeg(tf.io.read_file(filename), channels=IMG_CHANNELS)
plt.imshow(img)
plt.axis("off");
We get:
a man on a plate with a bat.
a boy is riding a baseball bat on a field.
a man in uniform standing in front of a base.
a young player swinging a bat from a pitch in a crowd of spectators.
a baseball player holds up to swing the bat.
Not bad...