Supervised Image Classification

The framework supports pre-training image models in a supervised fashion on classification datasets. In this example, we use the [BigEarthNet DataModule](extra/DataSets and DataModules/bigearthnet.ipynb) with a PyTorch Lightning Trainer and wrap the network in a LightningModule.

First, we import what we need from torch, PyTorch Lightning and ConfigILM to set up the LightningModule.

# import packages
try:
    import lightning.pytorch as pl
except ImportError:
    import pytorch_lightning as pl
    
import torch
import torch.nn.functional as F
from torch import optim

from configilm import ConfigILM

PyTorch Lightning Module

The module that encapsulates the model splits the usual training loop into hooks that PyTorch Lightning calls internally. Only training_step and configure_optimizers are required, but for a fully functional script we also add validation and test steps and evaluate the validation and test results. All *_step methods operate on a single batch, while the on_*_epoch_end hooks are called after all batches of an epoch have been processed. PyTorch Lightning 2.0 removed the older *_epoch_end hooks that received the collected step outputs, so the module now stores each step's outputs in the respective self.*_output_list, clears that list in on_*_epoch_start, and evaluates it in on_*_epoch_end.
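The hook order described above can be illustrated without Lightning. In this minimal sketch, ToyModule and run_validation_epoch are hypothetical stand-ins, not Lightning APIs, and the loop is heavily simplified (the real Trainer also handles devices, disables gradients, runs sanity checks, and more). It only shows why the output list is cleared at the start of each epoch and aggregated at its end.

```python
# Framework-free sketch of the Lightning 2.x validation hook order.
# ToyModule and run_validation_epoch are hypothetical, not Lightning classes.
from statistics import mean


class ToyModule:
    """Mimics the validation hooks of LitImageEncoder with plain floats."""

    def __init__(self):
        self.val_output_list = []
        self.logged = {}

    def on_validation_epoch_start(self):
        # clear the list, otherwise results from the previous epoch leak in
        self.val_output_list = []

    def validation_step(self, batch, batch_idx):
        # stand-in "loss": the mean of the batch values
        loss = mean(batch)
        self.val_output_list += [{"loss": loss}]

    def on_validation_epoch_end(self):
        # aggregate all per-batch results once the epoch is complete
        avg_loss = mean([x["loss"] for x in self.val_output_list])
        self.logged["val/loss"] = avg_loss


def run_validation_epoch(module, batches):
    """Calls the hooks in the order of a (greatly simplified) validation loop."""
    module.on_validation_epoch_start()
    for batch_idx, batch in enumerate(batches):
        module.validation_step(batch, batch_idx)
    module.on_validation_epoch_end()


module = ToyModule()
batches = [[1.0, 3.0], [5.0, 7.0]]
for epoch in range(2):
    run_validation_epoch(module, batches)
    print(epoch, module.logged["val/loss"], len(module.val_output_list))
```

After two epochs the list still holds only two entries, because on_validation_epoch_start clears it. Averaging per-batch losses, as LitImageEncoder does, weights every batch equally, so a smaller final batch counts as much as a full one.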

class LitImageEncoder(pl.LightningModule):
    """
    Wrapper around a pytorch module, allowing this module to be used in automatic
    training with Pytorch Lightning.
    Among other things, the wrapper allows us to do automatic training and removes the
    need to manage data on different devices (e.g. GPU and CPU).
    """
    def __init__(
        self,
        config: ConfigILM.ILMConfiguration,
        lr: float = 1e-3,
    ):
        super().__init__()
        self.lr = lr
        self.config = config
        self.model = ConfigILM.ConfigILM(config)
        self.val_output_list = []
        self.test_output_list = []


    def training_step(self, batch, batch_idx):
        x, y = batch
        x_hat = self.model(x)
        loss = F.binary_cross_entropy_with_logits(x_hat, y)
        self.log("train/loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.lr, weight_decay=0.01)
        return optimizer

    # ================= OPTIONAL FUNCTIONS =================

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x_hat = self.model(x)
        loss = F.binary_cross_entropy_with_logits(x_hat, y)
        self.val_output_list += [{"loss": loss, "outputs": x_hat, "labels": y}]

    def on_validation_epoch_start(self):
        super().on_validation_epoch_start()
        self.val_output_list = []

    def on_validation_epoch_end(self):
        avg_loss = torch.stack([x["loss"] for x in self.val_output_list]).mean()
        self.log("val/loss", avg_loss)

    def test_step(self, batch, batch_idx):
        x, y = batch
        x_hat = self.model(x)
        loss = F.binary_cross_entropy_with_logits(x_hat, y)
        self.test_output_list += [{"loss": loss, "outputs": x_hat, "labels": y}]

    def on_test_epoch_start(self):
        super().on_test_epoch_start()
        self.test_output_list = []

    def on_test_epoch_end(self):
        avg_loss = torch.stack([x["loss"] for x in self.test_output_list]).mean()
        self.log("test/loss", avg_loss)

    def forward(self, batch):
        # because we are a wrapper, we call the inner function manually
        return self.model(batch)
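All three *_step methods use F.binary_cross_entropy_with_logits, which treats each of the 19 classes as an independent yes/no decision. This suits multi-label data such as BigEarthNet, where one image can carry several labels. The pure-Python sketch below is an illustration, not PyTorch's implementation: it spells out the numerically stable formula max(x, 0) - x*y + log(1 + exp(-|x|)), averaged over all entries as with the default reduction="mean", and compares it to the direct sigmoid-based formula.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over all entries, computed from raw logits.

    Stable form: max(x, 0) - x * y + log(1 + exp(-|x|)), which equals
    -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))].
    """
    assert len(logits) == len(targets)
    total = 0.0
    for x, y in zip(logits, targets):
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def naive_bce(logits, targets):
    """Direct formula via the sigmoid; hits log(0) for large |x|."""
    total = 0.0
    for x, y in zip(logits, targets):
        p = 1.0 / (1.0 + math.exp(-x))
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(logits)


# one sample, four classes, two of them present (multi-hot target)
logits = [2.0, -1.0, 0.5, -3.0]
targets = [1.0, 0.0, 1.0, 0.0]
print(bce_with_logits(logits, targets))
print(naive_bce(logits, targets))

# the stable version still works where naive_bce would raise a math domain error
print(bce_with_logits([1000.0], [0.0]))  # 1000.0
```

Working directly on logits avoids computing log(0) when the sigmoid saturates, which is why the model outputs raw scores and leaves the sigmoid to the loss.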

Configuring

Now that we have our model, we use the PyTorch Lightning Trainer to run the training, validation and test loops. Metrics recorded with self.log can be written to TensorBoard by passing a TensorBoard logger to the Trainer.

We start by importing the configuration classes from ConfigILM.

from configilm.ConfigILM import ILMConfiguration, ILMType

Next, we define our hyperparameters.

model_name = "resnet18"
seed = 42
number_of_channels = 12
image_size = 120
epochs = 4
lr = 5e-4

Then we seed all random number generators for reproducibility and create the configuration that is used to build the model later.

# seed for pytorch, numpy, python.random, Dataloader workers, spawned subprocesses
pl.seed_everything(seed, workers=True)

model_config = ILMConfiguration(
    timm_model_name=model_name,
    hf_model_name=None,
    classes=19,
    image_size=image_size,
    channels=number_of_channels,
    network_type=ILMType.IMAGE_CLASSIFICATION
)
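pl.seed_everything(seed, workers=True) seeds Python's random module, NumPy and PyTorch in one call, and with workers=True Lightning also derives reproducible seeds for the dataloader workers. The sketch below uses a hypothetical, simplified seed_all helper to show the core idea: it seeds whichever of these libraries is importable and checks that re-seeding reproduces the same draws. It does not replicate Lightning's worker seeding.

```python
import random


def seed_all(seed: int) -> None:
    """Hypothetical, simplified stand-in for pl.seed_everything (no worker seeding)."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)  # NumPy's global (legacy) generator
    except ImportError:
        pass  # NumPy not installed: nothing to seed
    try:
        import torch
        torch.manual_seed(seed)  # seeds PyTorch's generators on all devices
    except ImportError:
        pass  # PyTorch not installed: nothing to seed


def draw(n: int = 3):
    """Draw n numbers from Python's global random generator."""
    return [random.random() for _ in range(n)]


seed_all(42)
first = draw()
seed_all(42)
second = draw()
print(first == second)  # True: same seed, same sequence
```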

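Next, we create a PyTorch Lightning Trainer. To keep the example minimal, no logger is attached (logger=False), so neither the hyperparameters nor the logged metrics are recorded.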

trainer = pl.Trainer(
    max_epochs=epochs,
    accelerator="auto",
    log_every_n_steps=1,
    logger=False,
)

Creating Model + Dataset

Finally, we create the model defined above and the DataModule. We use the BigEarthNet (v1) DataModule from this framework, which is described in the Extra section.

from configilm.extra.DataModules.BENv1_DataModule import BENv1DataModule
model = LitImageEncoder(config=model_config, lr=lr)
dm = BENv1DataModule(
    data_dirs=my_data_path,  # path to dataset
    img_size=(number_of_channels, image_size, image_size),
    num_workers_dataloader=4,
)

Running

Now we only need to call trainer.fit() and, optionally, trainer.test().

Note

These calls generate a lot of output, depending on the number of batches and epochs. The output has been removed for readability.

trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)

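Here is an example forward call of the trained model. Because the input is normalized, the colors are slightly distorted. To display the image anyway, we select only the RGB channels and rescale the result to the range 0 to 1. Expected is the multi-hot label vector of the image, and Real is the raw model output (logits, before any sigmoid).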
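The display step described above can be sketched in plain Python: pick three channels of the normalized (C, H, W) input and min-max rescale them to [0, 1]. Which indices hold red, green and blue depends on the band order of the DataModule, so rgb_idx is an assumed parameter, and to_displayable_rgb is a hypothetical helper rather than part of ConfigILM. The tutorial does not say whether the channels are rescaled separately or jointly; this sketch uses one joint minimum and maximum, which keeps the relative brightness of the three bands.

```python
def to_displayable_rgb(image, rgb_idx):
    """Select three channels of a (C, H, W) nested list and rescale them to [0, 1].

    image:   nested list indexed as image[channel][row][col] (normalized values)
    rgb_idx: assumed (red, green, blue) channel indices; depends on the band order
    Returns an (H, W, 3) nested list, the layout image viewers usually expect.
    """
    channels = [image[c] for c in rgb_idx]
    values = [v for ch in channels for row in ch for v in row]
    lo, hi = min(values), max(values)
    scale = (hi - lo) or 1.0  # avoid division by zero for a constant image
    height, width = len(channels[0]), len(channels[0][0])
    return [
        [[(channels[k][r][c] - lo) / scale for k in range(3)] for c in range(width)]
        for r in range(height)
    ]


# toy 4-channel, 2x2 image with normalized (zero-centred) values
toy = [
    [[-1.0, 0.0], [1.0, 2.0]],
    [[0.5, 0.5], [0.5, 0.5]],
    [[-2.0, -1.0], [0.0, 1.0]],
    [[3.0, 3.0], [3.0, 3.0]],
]
rgb = to_displayable_rgb(toy, rgb_idx=(2, 1, 0))
print(rgb[0][0])  # [0.0, 0.625, 0.25]
```

With a torch tensor img of shape (C, H, W), roughly the same is img[list(rgb_idx)], then (x - x.min()) / (x.max() - x.min()), then .permute(1, 2, 0) to move the channels last for display.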

[Figure: RGB channels of the example input image, rescaled to 0 to 1]
Expected: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]
    Real: [-0.3234540820121765, 0.5756053328514099, -0.033697664737701416, -0.18205130100250244, -0.2761593461036682, -0.007394324988126755, -0.32195261120796204, -0.34879133105278015, -0.3748129606246948, -0.30978673696517944, 0.07935646921396255, -0.49131959676742554, -0.2532671093940735, -0.4226391911506653, -0.3649766147136688, -0.1662663221359253, -0.14214220643043518, -0.4286631643772125, -0.6496837139129639]
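The Real values are logits (several are negative, so they cannot be probabilities), because the sigmoid is applied inside binary_cross_entropy_with_logits during training. To compare them with the multi-hot Expected vector, one option is to apply a sigmoid and threshold at 0.5. The 0.5 threshold is an assumption; the tutorial does not specify one. The sketch below applies it to the values printed above and computes precision and recall for this single image.

```python
import math

# values copied from the Expected / Real output above
expected = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
            1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]
logits = [
    -0.3234540820121765, 0.5756053328514099, -0.033697664737701416,
    -0.18205130100250244, -0.2761593461036682, -0.007394324988126755,
    -0.32195261120796204, -0.34879133105278015, -0.3748129606246948,
    -0.30978673696517944, 0.07935646921396255, -0.49131959676742554,
    -0.2532671093940735, -0.4226391911506653, -0.3649766147136688,
    -0.1662663221359253, -0.14214220643043518, -0.4286631643772125,
    -0.6496837139129639,
]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


threshold = 0.5  # assumed decision threshold
probs = [sigmoid(x) for x in logits]
predicted = [i for i, p in enumerate(probs) if p > threshold]
actual = [i for i, y in enumerate(expected) if y == 1.0]

hits = set(predicted) & set(actual)
precision = len(hits) / len(predicted) if predicted else 0.0
recall = len(hits) / len(actual) if actual else 0.0
print("predicted:", predicted)  # [1, 10]
print("actual:   ", actual)     # [6, 10, 12, 17]
print(f"precision={precision:.2f}, recall={recall:.2f}")  # precision=0.50, recall=0.25
```

Since sigmoid(x) > 0.5 exactly when x > 0, thresholding the probabilities at 0.5 is the same as keeping the classes with a positive logit.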