Throughput testing

During development of new models, it is often useful to test the throughput of the pipeline independently of how fast real data can be loaded. To enable this, ConfigILM provides a ThroughputTest_DataModule and a corresponding ThroughputTestDataset. The dataset does not load any actual data; instead, it generates a single dummy sample during initialization and returns it on every call to __getitem__(). The reported (fake) length of the dataset is set with the num_samples parameter.
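A minimal, framework-free sketch of this idea follows. It is not ConfigILM's implementation: the class name DummySampleDataset, its parameters, and the sample layout (an image as nested lists, a list of fake token ids, and a multi-hot label vector) are hypothetical. It only follows the __len__/__getitem__ protocol that PyTorch map-style datasets use, which is why iterating over such a dataset costs almost nothing.

```python
import random


class DummySampleDataset:
    """Hypothetical stand-in for a throughput-test dataset (not ConfigILM code).

    A single (image, question, label) sample is generated once in __init__
    and the same object is returned for every valid index, so iterating
    costs no disk I/O or decoding.
    """

    def __init__(self, num_samples, img_size=(12, 120, 120), seq_length=32,
                 num_classes=25, seed=0):
        rng = random.Random(seed)
        channels, height, width = img_size
        # image as nested lists with shape channels x height x width
        image = [[[rng.random() for _ in range(width)] for _ in range(height)]
                 for _ in range(channels)]
        question = [rng.randrange(1000) for _ in range(seq_length)]  # fake token ids
        label = [float(rng.random() < 0.5) for _ in range(num_classes)]  # multi-hot
        self.sample = (image, question, label)
        self.num_samples = num_samples

    def __len__(self):
        # the "fake" length: nothing is stored beyond the one sample
        return self.num_samples

    def __getitem__(self, idx):
        if not 0 <= idx < self.num_samples:
            raise IndexError(idx)
        return self.sample


ds = DummySampleDataset(num_samples=512, img_size=(3, 8, 8))
print(len(ds))           # 512
print(ds[0] is ds[511])  # True
```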

Preparing the model

To run the throughput test, we first create the model and then pass the corresponding DataModule to the trainer. For more details on creating the model, see the page on VQA model creation; the code here is nearly identical, with some parts removed.

# import packages
try:
    import lightning.pytorch as pl
except ImportError:
    import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import optim

from configilm import ConfigILM

class LitVQAEncoder(pl.LightningModule):
    """
    Wrapper around a PyTorch module that allows it to be trained automatically
    with PyTorch Lightning.
    Among other things, the wrapper removes the need to manage data on
    different devices (e.g. GPU and CPU).
    """
    def __init__(
        self,
        config: ConfigILM.ILMConfiguration,
        lr: float = 1e-3,
    ):
        super().__init__()
        self.lr = lr
        self.config = config
        self.model = ConfigILM.ConfigILM(config)
        self.val_output_list = []
        self.test_output_list = []

    def _disassemble_batch(self, batch):
        images, questions, labels = batch
        # transposing tensor, needed for Huggingface-Dataloader combination
        questions = torch.tensor(
            [x.tolist() for x in questions], device=self.device
        ).T.int()
        return (images, questions), labels

    def training_step(self, batch, batch_idx):
        x, y = self._disassemble_batch(batch)
        x_hat = self.model(x)
        loss = F.binary_cross_entropy_with_logits(x_hat, y)
        self.log("train/loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.lr, weight_decay=0.01)
        return optimizer

    # ============== NON-MANDATORY FUNCTIONS ===============

    def validation_step(self, batch, batch_idx):
        x, y = self._disassemble_batch(batch)
        x_hat = self.model(x)
        loss = F.binary_cross_entropy_with_logits(x_hat, y)
        self.val_output_list += [{"loss": loss, "outputs": x_hat, "labels": y}]

    def on_validation_epoch_start(self):
        super().on_validation_epoch_start()
        self.val_output_list = []

    def on_validation_epoch_end(self):
        avg_loss = torch.stack([x["loss"] for x in self.val_output_list]).mean()
        self.log("val/loss", avg_loss)

    def test_step(self, batch, batch_idx):
        x, y = self._disassemble_batch(batch)
        x_hat = self.model(x)
        loss = F.binary_cross_entropy_with_logits(x_hat, y)
        self.test_output_list += [{"loss": loss, "outputs": x_hat, "labels": y}]

    def on_test_epoch_start(self):
        super().on_test_epoch_start()
        self.test_output_list = []

    def on_test_epoch_end(self):
        avg_loss = torch.stack([x["loss"] for x in self.test_output_list]).mean()
        self.log("test/loss", avg_loss)

    def forward(self, batch):
        # because we are a wrapper, we call the inner function manually
        return self.model(batch)

trainer = pl.Trainer(
    max_epochs=4,
    accelerator="auto",
    log_every_n_steps=1,
    logger=False,
)
from configilm.ConfigILM import ILMConfiguration, ILMType
image_model_name = "resnet18"
text_model_name = "prajjwal1/bert-tiny"
number_of_channels = 12
image_size = 120
lr = 5e-4
seq_len = 32
classes = 25

model_config = ILMConfiguration(
    timm_model_name=image_model_name,
    hf_model_name=text_model_name,
    classes=classes,
    image_size=image_size,
    channels=number_of_channels,
    network_type=ILMType.VQA_CLASSIFICATION,
    max_sequence_length=seq_len,
)
model = LitVQAEncoder(config=model_config, lr=lr)
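The transpose in _disassemble_batch follows from how batches of token lists are collated. Assuming the questions reach the DataLoader as plain lists of token ids and are batched by PyTorch's default collate function, a batch of B lists of length L arrives as L entries (one per token position), each holding B values. The framework-free sketch below mimics that with plain lists; both helper names are hypothetical, and the real code performs the second step with torch.tensor(...).T.

```python
def default_collate_like(questions):
    """Mimic how default collation regroups a batch of equal-length
    sequences: one entry per token position, each holding the whole batch."""
    return [list(position) for position in zip(*questions)]


def disassemble_questions(collated):
    """Undo the regrouping: (seq_len x batch) -> (batch x seq_len),
    which is what `.T` does in `_disassemble_batch`."""
    return [list(row) for row in zip(*collated)]


batch = [[101, 7, 102], [101, 9, 102]]  # 2 questions, 3 tokens each
collated = default_collate_like(batch)
print(collated)                                  # [[101, 101], [7, 9], [102, 102]]
print(disassemble_questions(collated) == batch)  # True
```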

Running the throughput test

Now the model is evaluated using the VQAThroughputTestDataModule instead of real data. First, we create the DataModule with the desired parameters

from configilm.extra.DataModules import ThroughputTest_DataModule
dm = ThroughputTest_DataModule.VQAThroughputTestDataModule(
    data_dirs={},  # parameter is ignored but required for compatibility with other DataModules in ConfigILM
    img_size=(number_of_channels, image_size, image_size),
    seq_length=seq_len,
    num_samples=32*16,  # number of "samples" in this dataset -> each sample is the same one
    batch_size=32,
    num_classes=classes,
)

and then run the model on this fake DataModule by iterating over its test set. To measure the throughput, we time the complete test run and compute the number of samples processed per second, the time per sample, and the total time.

import time
start = time.time()
trainer.test(model, datamodule=dm)
end = time.time()
print(f"Throughput: {dm.num_samples / (end - start):.3f} samples per second")
print(f"Time per sample: {(end - start) / dm.num_samples * 1000:.1f} milliseconds")
print(f"Total time: {end - start:.3f} seconds")
Throughput: 89.759 samples per second
Time per sample: 11.1 milliseconds
Total time: 5.704 seconds
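All three printed numbers are derived from a single elapsed time. The sketch below factors that arithmetic into reusable helpers; the function names are hypothetical. It uses time.perf_counter(), which the Python documentation recommends for measuring short durations, instead of time.time(). Note that the interval around trainer.test(...) also includes Lightning's own setup work (for example, calling the DataModule's setup and moving the model to the accelerator), so the figures are an end-to-end estimate rather than pure model time.

```python
import time


def throughput_metrics(num_samples, elapsed_s):
    """Derive the three printed metrics from one wall-clock duration."""
    if elapsed_s <= 0:
        raise ValueError("elapsed time must be positive")
    return {
        "samples_per_s": num_samples / elapsed_s,
        "ms_per_sample": elapsed_s / num_samples * 1000,
        "total_s": elapsed_s,
    }


def time_run(run, num_samples):
    """Time a single call of `run`, e.g. lambda: trainer.test(model, datamodule=dm)."""
    start = time.perf_counter()
    run()
    elapsed = time.perf_counter() - start
    return throughput_metrics(num_samples, elapsed)


# Re-deriving the resnet18 figures from the reported total time of 5.704 s:
m = throughput_metrics(32 * 16, 5.704)
print(f"{m['samples_per_s']:.1f} samples/s, {m['ms_per_sample']:.1f} ms/sample")
# -> 89.8 samples/s, 11.1 ms/sample (the 89.759 above used the unrounded time)
```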

We can now compare this with a different model configuration. For example, we can replace the image model resnet18 with the larger resnet34 and see how the throughput changes.

model_config = ILMConfiguration(
    timm_model_name="resnet34",
    hf_model_name=text_model_name,
    classes=classes,
    image_size=image_size,
    channels=number_of_channels,
    network_type=ILMType.VQA_CLASSIFICATION,
    max_sequence_length=seq_len,
)
model = LitVQAEncoder(config=model_config, lr=lr)
start = time.time()
trainer.test(model, datamodule=dm)
end = time.time()
print(f"Throughput: {dm.num_samples / (end - start):.3f} samples per second")
print(f"Time per sample: {(end - start) / dm.num_samples * 1000:.1f} milliseconds")
print(f"Total time: {end - start:.3f} seconds")
Throughput: 55.302 samples per second
Time per sample: 18.1 milliseconds
Total time: 9.258 seconds
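To put the two runs side by side, it can help to express the difference as a ratio. With the numbers reported above, switching from resnet18 to resnet34 lowered throughput from about 89.8 to about 55.3 samples per second on this machine. The helper below is a hypothetical sketch of that comparison; it only does arithmetic on already measured values.

```python
def compare_throughput(baseline_sps, candidate_sps):
    """Compare two throughput measurements (samples per second).

    Returns the slowdown factor of the candidate relative to the baseline
    (> 1 means the candidate is slower) and the relative change in percent.
    """
    slowdown = baseline_sps / candidate_sps
    change_pct = (candidate_sps - baseline_sps) / baseline_sps * 100
    return slowdown, change_pct


# resnet18 vs. resnet34, values taken from the outputs above
slowdown, change = compare_throughput(89.759, 55.302)
print(f"resnet34 is {slowdown:.2f}x slower ({change:+.1f}% throughput)")
# -> resnet34 is 1.62x slower (-38.4% throughput)
```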

Note

For reliable results, run the throughput test on a machine with no other significant load. The results can vary significantly depending on the hardware and the number of workers used by the DataLoader. Even on an otherwise idle machine, the test should be run several times and the results averaged to obtain a reliable estimate of the throughput.
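Repeating the measurement and averaging it can be scripted with the standard library. The sketch below is hypothetical: run would wrap the trainer.test(...) call, and discarding a warm-up run is our own assumption (the first run often carries one-time costs such as CUDA context creation), not something ConfigILM does.

```python
import statistics
import time


def repeated_throughput(run, num_samples, repeats=5, warmup=1):
    """Call `run` warmup + repeats times, discard the warm-up calls, and
    return the mean and standard deviation of the measured samples/s."""
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    results = []
    for i in range(warmup + repeats):
        start = time.perf_counter()
        run()
        elapsed = time.perf_counter() - start
        if i >= warmup:  # keep only the measured (non-warm-up) runs
            results.append(num_samples / elapsed)
    mean = statistics.mean(results)
    stdev = statistics.stdev(results) if len(results) > 1 else 0.0
    return mean, stdev


# Usage with the objects defined above (not executed here):
# mean, std = repeated_throughput(lambda: trainer.test(model, datamodule=dm), dm.num_samples)
# print(f"{mean:.1f} ± {std:.1f} samples per second")
```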