Visual Question Answering (VQA)

The framework is set up so that an image model from the timm library can easily be combined with a language model from Hugging Face. For both models, either pre-trained weights can be used, or the models can be trained together as a composite in an end-to-end fashion. In this example, we use the RSVQAxBEN DataModule, which loads the RSVQAxBEN dataset published by Lobry et al. [2], inside a PyTorch Lightning Trainer. The network is wrapped in a LightningModule.

First, we import the packages from torch and PyTorch Lightning that are needed to set up the LightningModule.

# import packages
try:
    import lightning.pytorch as pl
except ImportError:
    import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import optim

from configilm import ConfigILM

PyTorch Lightning Module

The module that encapsulates the model splits the usual training loop into functions that PyTorch Lightning calls internally. The only required functions are

  1. training_step and

  2. configure_optimizers,

but to get a fully functional script, we also add validation and test steps as well as the evaluation of the validation and test results. All *_step functions operate on a single batch, while the on_*_epoch_end hooks are called after all batches of an epoch have been processed. During the *_step calls, the module collects the outputs in the respective self.*_output_list so that the results can be evaluated at the end of the epoch (this is new in PyTorch Lightning 2.0, where the epoch-end hooks no longer receive the step outputs). For Visual Question Answering (VQA) we need one additional function, because each batch contains three values (image, question, and label) instead of the usual two (input, output). This function (here called _disassemble_batch) splits the batch into input and output, where the input contains both modalities.
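To see why _disassemble_batch transposes the questions, the following is a minimal standalone sketch, assuming that each dataset sample is a tuple (image, list of token ids, label vector) that is batched by PyTorch's default_collate. The sample values, the 8x8 image size, and the 4-element labels are made up for illustration, and the function drops the device handling of the method in the class below.

```python
import torch
from torch.utils.data import default_collate


def disassemble_batch(batch):
    """Standalone version of LitVQAEncoder._disassemble_batch (no device handling)."""
    images, questions, labels = batch
    # default_collate turns a batch of token lists into a list of seq_len tensors,
    # each of shape (batch_size,); building a tensor from them gives
    # (seq_len, batch_size), and .T restores (batch_size, seq_len)
    questions = torch.tensor([x.tolist() for x in questions]).T.int()
    return (images, questions), labels


# Two hypothetical samples: a 12-channel 8x8 image, a question of 5 token ids
# (padded with 0), and a label vector over 4 made-up classes
samples = [
    (torch.zeros(12, 8, 8), [101, 2024, 2045, 102, 0], torch.tensor([1.0, 0.0, 0.0, 0.0])),
    (torch.ones(12, 8, 8), [101, 2003, 102, 0, 0], torch.tensor([0.0, 1.0, 0.0, 0.0])),
]
batch = default_collate(samples)
(images, questions), labels = disassemble_batch(batch)
print(images.shape, questions.shape, labels.shape)
```

Here images has shape (2, 12, 8, 8), questions (2, 5), and labels (2, 4). An equivalent formulation is torch.stack(questions, dim=1).int(), which avoids the round trip through Python lists.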

class LitVQAEncoder(pl.LightningModule):
    """
    Wrapper around a PyTorch module, allowing this module to be used in automatic
    training with PyTorch Lightning.
    Among other things, the wrapper removes the need to manage data on different
    devices (e.g. GPU and CPU).
    """
    def __init__(
        self,
        config: ConfigILM.ILMConfiguration,
        lr: float = 1e-3,
    ):
        super().__init__()
        self.lr = lr
        self.config = config
        self.model = ConfigILM.ConfigILM(config)
        self.val_output_list = []
        self.test_output_list = []

    def _disassemble_batch(self, batch):
        images, questions, labels = batch
        # transposing tensor, needed for Huggingface-Dataloader combination
        questions = torch.tensor(
            [x.tolist() for x in questions], device=self.device
        ).T.int()
        return (images, questions), labels

    def training_step(self, batch, batch_idx):
        x, y = self._disassemble_batch(batch)
        x_hat = self.model(x)
        loss = F.binary_cross_entropy_with_logits(x_hat, y)
        self.log("train/loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.lr, weight_decay=0.01)
        return optimizer

    # ============== NON-MANDATORY-FUNCTION ===============

    def validation_step(self, batch, batch_idx):
        x, y = self._disassemble_batch(batch)
        x_hat = self.model(x)
        loss = F.binary_cross_entropy_with_logits(x_hat, y)
        self.val_output_list += [{"loss": loss, "outputs": x_hat, "labels": y}]

    def on_validation_epoch_start(self):
        super().on_validation_epoch_start()
        self.val_output_list = []

    def on_validation_epoch_end(self):
        avg_loss = torch.stack([x["loss"] for x in self.val_output_list]).mean()
        self.log("val/loss", avg_loss)

    def test_step(self, batch, batch_idx):
        x, y = self._disassemble_batch(batch)
        x_hat = self.model(x)
        loss = F.binary_cross_entropy_with_logits(x_hat, y)
        self.test_output_list += [{"loss": loss, "outputs": x_hat, "labels": y}]

    def on_test_epoch_start(self):
        super().on_test_epoch_start()
        self.test_output_list = []

    def on_test_epoch_end(self):
        avg_loss = torch.stack([x["loss"] for x in self.test_output_list]).mean()
        self.log("test/loss", avg_loss)

    def forward(self, batch):
        # because we are a wrapper, we call the inner function manually
        return self.model(batch)
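The on_validation_epoch_end and on_test_epoch_end hooks above average the per-batch losses, so every batch counts equally. When the last batch is smaller than the others, this can differ from the mean loss over all samples. The sketch below uses made-up tensors (not RSVQAxBEN data) to compare both and shows one way to weight the batches by their size, should a per-sample mean be preferred; storing the batch size next to each loss in the output lists would be one way to apply this in the module, but that change is an assumption and not part of the example above.

```python
import torch
import torch.nn.functional as F

# Made-up logits/targets: a full batch of 2 samples and a smaller last batch of 1,
# each sample with 4 output classes
batches = [
    (torch.zeros(2, 4), torch.zeros(2, 4)),          # loss is log(2) for every element
    (torch.full((1, 4), -10.0), torch.zeros(1, 4)),  # loss is close to 0
]

losses = [F.binary_cross_entropy_with_logits(x_hat, y) for x_hat, y in batches]
sizes = [y.shape[0] for _, y in batches]

# What the epoch-end hooks compute: every batch weighted equally
per_batch_mean = torch.stack(losses).mean()

# Weighting each batch loss by its number of samples recovers the per-sample mean
weighted_mean = sum(loss * n for loss, n in zip(losses, sizes)) / sum(sizes)

# Reference: the loss computed over all samples at once
all_logits = torch.cat([x_hat for x_hat, _ in batches])
all_targets = torch.cat([y for _, y in batches])
full_mean = F.binary_cross_entropy_with_logits(all_logits, all_targets)

print(per_batch_mean.item(), weighted_mean.item(), full_mean.item())
```

With equally sized batches, both averages coincide, so the simpler per-batch mean is usually a reasonable approximation.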

Configuring

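Now that we have our model, we use the PyTorch Lightning Trainer to run our loops. Note that the Trainer below is created with logger=False, so no logs are written; pass a logger (for example, one for TensorBoard) to record the results.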

We start by importing the configuration classes from ConfigILM

from configilm.ConfigILM import ILMConfiguration, ILMType

as well as defining our hyperparameters.

image_model_name = "resnet18"
text_model_name = "prajjwal1/bert-tiny"
seed = 42
number_of_channels = 12
image_size = 120
epochs = 4
lr = 5e-4

Then we create the configuration that is used to build the model later.

# seed for pytorch, numpy, python.random, Dataloader workers, spawned subprocesses
pl.seed_everything(seed, workers=True)

model_config = ILMConfiguration(
    timm_model_name=image_model_name,
    hf_model_name=text_model_name,  # different to pre-training
    classes=1000,  # different to pre-training
    image_size=image_size,
    channels=number_of_channels,
    network_type=ILMType.VQA_CLASSIFICATION  # different to pre-training
)
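With network_type=ILMType.VQA_CLASSIFICATION, the network outputs classes=1000 values, and the example output at the end of this page suggests that each position corresponds to one candidate answer, with the expected answer encoded as a 1.0 at its index. The sketch below illustrates this encoding and decoding under that assumption; the four-entry answer list is hypothetical (only "no" and "yes" reflect the yes/no question shown later), and the helper names encode_answer and decode_logits are illustrative, not part of ConfigILM.

```python
import torch

# Hypothetical answer vocabulary; the configuration above uses 1000 classes
answers = ["no", "yes", "hypothetical_a", "hypothetical_b"]
num_classes = len(answers)


def encode_answer(answer: str) -> torch.Tensor:
    """Return a target vector with 1.0 at the index of the given answer."""
    target = torch.zeros(num_classes)
    target[answers.index(answer)] = 1.0
    return target


def decode_logits(logits: torch.Tensor) -> str:
    """Map the network output (one logit per answer) to the highest-scoring answer."""
    return answers[int(logits.argmax())]


target = encode_answer("no")
logits = torch.tensor([2.3, -0.5, -1.2, 0.1])  # made-up network output
print(target.tolist(), decode_logits(logits))
```

Because the sigmoid applied inside binary_cross_entropy_with_logits is monotonic, taking the argmax of the raw logits selects the same answer as taking it over the sigmoid probabilities.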

Next, we create a PyTorch Lightning Trainer.

trainer = pl.Trainer(
    max_epochs=epochs,
    accelerator="auto",
    log_every_n_steps=1,
    logger=False,
)

Creating Model + Dataset

Finally, we create the model defined above and our datamodule. We use a datamodule from this framework, which is described in the Extra section.

Note

We get a user warning here that ‘img_size’ is an unknown keyword. This is expected, as most Convolutional Neural Networks (CNNs), like the ResNet used here, operate independently of the input image size.
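Such networks typically end in a global pooling layer that reduces any spatial size to a fixed-length feature vector, so the classifier on top never sees the image size. The toy model below, built from torch.nn layers, is an illustrative stand-in and not the timm ResNet used here; its layer sizes and 5 output classes are made up.

```python
import torch
from torch import nn

# Toy CNN: a convolution followed by global average pooling and a linear head
toy_cnn = nn.Sequential(
    nn.Conv2d(12, 16, kernel_size=3, padding=1),  # 12 input channels, as configured above
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),  # collapses any H x W feature map to 1 x 1
    nn.Flatten(),             # (batch, 16, 1, 1) -> (batch, 16)
    nn.Linear(16, 5),         # 5 made-up output classes
)

for size in (120, 64):
    out = toy_cnn(torch.randn(2, 12, size, size))
    print(size, tuple(out.shape))
```

Both input sizes yield an output of shape (2, 5).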

from configilm.extra.DataModules.RSVQAxBEN_DataModule import RSVQAxBENDataModule
model = LitVQAEncoder(config=model_config, lr=lr)
dm = RSVQAxBENDataModule(
    data_dirs=my_data_path,  # path to dataset
    img_size=(number_of_channels, image_size, image_size),
    num_workers_dataloader=4,
    tokenizer=model.model.get_tokenizer(),
)
/home/runner/work/ConfigILM/ConfigILM/configilm/ConfigILM.py:134: UserWarning: Keyword 'img_size' unknown. Trying to ignore and restart creation.
  warnings.warn(f"Keyword '{failed_kw}' unknown. Trying to ignore and restart creation.")
/home/runner/work/ConfigILM/ConfigILM/configilm/ConfigILM.py:108: UserWarning: Tokenizer was initialized pretrained
  warnings.warn("Tokenizer was initialized pretrained")

Running

Now we only have to call fit() and, optionally, test().

Note

These calls generate quite a bit of output depending on the number of batches and epochs. The output is removed for readability.

trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)

Here is an example forward call of the model. Since the input is normalized, the colors are slightly distorted. To display the image anyway, we select only the RGB channels and rescale this image to the range 0 to 1. The dataset already returns the question as tokens, so we decode it again using the tokenizer. The question may be padded so that all inputs have the same length; for readability, only the first few tokens are shown here. Likewise, only the first 10 elements of the output and the expected answer are shown, as the full lists have 1000 elements.

[Figure: RGB channels of the example input image]
    Text: [CLS] are there some water bodies? [SEP] [PAD] [PAD] [PAD] ...
Question: [101, 2024, 2045, 2070, 2300, 4230, 1029, 102, 0, 0, 0] ...
Expected: [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
          no
    Real: [-0.12006126344203949, 0.28388527035713196, -0.23709440231323242, -0.0630030706524849, 0.26145559549331665, 0.08688576519489288, 0.02884058468043804, -0.341861754655838, -0.21518073976039886, 0.12698347866535187]
          no
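The display step described above (select three channels, then rescale them to the range 0 to 1) can be sketched as follows. This is a minimal illustration, not the code used to produce the figure: the function name to_displayable_rgb and the default rgb_indices=(3, 2, 1) are hypothetical, since the actual positions of the red, green, and blue bands depend on the channel order of the datamodule.

```python
import torch


def to_displayable_rgb(image: torch.Tensor, rgb_indices=(3, 2, 1)) -> torch.Tensor:
    """Pick three channels of a normalized (C, H, W) image and min-max scale them to [0, 1].

    rgb_indices is a hypothetical default; adjust it to the band order of your data.
    """
    rgb = image[list(rgb_indices)].float()  # (3, H, W)
    rgb = rgb - rgb.min()                   # shift the minimum to 0
    max_val = rgb.max()
    if max_val > 0:                         # avoid division by zero for constant images
        rgb = rgb / max_val                 # scale the maximum to 1
    return rgb.permute(1, 2, 0)             # (H, W, 3), the layout matplotlib's imshow expects


image = torch.randn(12, 120, 120)  # stand-in for one normalized input image
rgb = to_displayable_rgb(image)
print(tuple(rgb.shape), rgb.min().item(), rgb.max().item())
```

Using a single minimum and maximum for all three channels keeps their relative intensities intact; scaling each channel separately would be an alternative that changes the color balance.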