Supervised Image Classification#
The framework allows pre-training image models in a supervised fashion using classification datasets. In this example we will use the [BigEarthNet DataModule](extra/DataSets and DataModules/bigearthnet.ipynb) inside a Pytorch Lightning trainer. The network will be integrated into a LightningModule.
First, we import the basics from torch and Pytorch Lightning that are needed to set up the LightningModule.
# import packages
try:
    import lightning.pytorch as pl
except ImportError:
    import pytorch_lightning as pl

import torch
import torch.nn.functional as F
from torch import optim

from configilm import ConfigILM
Pytorch Lightning Module#
The Module we use to encapsulate the model divides the usual training loop into functions that are called internally by Pytorch Lightning. The only mandatory functions are training_step and configure_optimizers, but to have a fully functional script, we add the validation and test steps as well as the evaluation of the validation and test results. All *_step functions work on a single batch, while the on_*_epoch_end hooks are called after all batches have been evaluated. During the *_step calls, the module collects the outputs in the respective self.*_output_list so that the results can be evaluated at the end of the epoch (this behavior is new as of Pytorch Lightning 2.0).
class LitImageEncoder(pl.LightningModule):
    """
    Wrapper around a pytorch module, allowing this module to be used in automatic
    training with Pytorch Lightning.
    Among other things, the wrapper allows us to do automatic training and removes the
    need to manage data on different devices (e.g. GPU and CPU).
    """

    def __init__(
        self,
        config: ConfigILM.ILMConfiguration,
        lr: float = 1e-3,
    ):
        super().__init__()
        self.lr = lr
        self.config = config
        self.model = ConfigILM.ConfigILM(config)
        self.val_output_list = []
        self.test_output_list = []

    def training_step(self, batch, batch_idx):
        x, y = batch
        x_hat = self.model(x)
        loss = F.binary_cross_entropy_with_logits(x_hat, y)
        self.log("train/loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.lr, weight_decay=0.01)
        return optimizer

    # ============== NON-MANDATORY-FUNCTIONS ===============

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x_hat = self.model(x)
        loss = F.binary_cross_entropy_with_logits(x_hat, y)
        self.val_output_list += [{"loss": loss, "outputs": x_hat, "labels": y}]

    def on_validation_epoch_start(self):
        super().on_validation_epoch_start()
        self.val_output_list = []

    def on_validation_epoch_end(self):
        avg_loss = torch.stack([x["loss"] for x in self.val_output_list]).mean()
        self.log("val/loss", avg_loss)

    def test_step(self, batch, batch_idx):
        x, y = batch
        x_hat = self.model(x)
        loss = F.binary_cross_entropy_with_logits(x_hat, y)
        self.test_output_list += [{"loss": loss, "outputs": x_hat, "labels": y}]

    def on_test_epoch_start(self):
        super().on_test_epoch_start()
        self.test_output_list = []

    def on_test_epoch_end(self):
        avg_loss = torch.stack([x["loss"] for x in self.test_output_list]).mean()
        self.log("test/loss", avg_loss)

    def forward(self, batch):
        # because we are a wrapper, we call the inner function manually
        return self.model(batch)
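The epoch-end hooks above only average the loss. If epoch-level classification metrics are desired, they can be computed from the same collected outputs. The following is a minimal sketch, not part of the framework, that computes a micro-averaged F1 score over the 19 classes in plain torch; the sigmoid-plus-0.5-threshold decision rule is an assumption:
def micro_f1(output_list, threshold=0.5):
    # concatenate the outputs and labels collected during the *_step calls
    logits = torch.cat([x["outputs"] for x in output_list])
    labels = torch.cat([x["labels"] for x in output_list]).bool()
    # multi-label decision rule: sigmoid + fixed threshold (assumption)
    preds = torch.sigmoid(logits) > threshold
    tp = (preds & labels).sum()
    fp = (preds & ~labels).sum()
    fn = (~preds & labels).sum()
    # micro F1 = 2*TP / (2*TP + FP + FN); small eps avoids division by zero
    return 2 * tp / (2 * tp + fp + fn + 1e-8)
Inside on_validation_epoch_end, the result could then be logged with self.log("val/f1", micro_f1(self.val_output_list)).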
Configuring#
Now that we have our model, we will use the Pytorch Lightning Trainer to run our training, validation, and test loops. Results are usually logged to tensorboard; in this example, logging is disabled (logger=False) to keep the output small.
We start by importing the configuration class and network type used to set up the model
from configilm.ConfigILM import ILMConfiguration, ILMType
as well as defining our hyperparameters.
model_name = "resnet18"
seed = 42
number_of_channels = 12
image_size = 120
epochs = 4
lr = 5e-4
Then we create the configuration that will later be used to create the model.
# seed for pytorch, numpy, python.random, Dataloader workers, spawned subprocesses
pl.seed_everything(seed, workers=True)

model_config = ILMConfiguration(
    timm_model_name=model_name,
    hf_model_name=None,
    classes=19,
    image_size=image_size,
    channels=number_of_channels,
    network_type=ILMType.IMAGE_CLASSIFICATION,
)
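As an optional sanity check (an assumption about typical usage, not a required step), the configuration can also be used to instantiate the bare network directly and pass a random tensor through it; for this setup the output should have shape [batch, 19]:
net = ConfigILM.ConfigILM(model_config)
x = torch.rand(2, number_of_channels, image_size, image_size)
print(net(x).shape)  # expected: torch.Size([2, 19])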
Next, we create a Pytorch Lightning Trainer.
trainer = pl.Trainer(
    max_epochs=epochs,
    accelerator="auto",
    log_every_n_steps=1,
    logger=False,
)
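For longer training runs, you may want to re-enable logging and add checkpointing. The sketch below shows one possible setup using the standard TensorBoardLogger and ModelCheckpoint; the save directory, run name, and monitored metric are assumptions:
try:
    from lightning.pytorch.loggers import TensorBoardLogger
    from lightning.pytorch.callbacks import ModelCheckpoint
except ImportError:
    from pytorch_lightning.loggers import TensorBoardLogger
    from pytorch_lightning.callbacks import ModelCheckpoint

logger = TensorBoardLogger(save_dir="logs", name="bigearthnet_resnet18")  # placeholder path/name
checkpoint_cb = ModelCheckpoint(monitor="val/loss", mode="min", save_top_k=1)
trainer = pl.Trainer(
    max_epochs=epochs,
    accelerator="auto",
    logger=logger,
    callbacks=[checkpoint_cb],
)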
Creating Model + Dataset#
Finally, we create the model defined above and our datamodule. We will be using a datamodule from this framework described in the Extra section.
from configilm.extra.DataModules.BENv1_DataModule import BENv1DataModule
model = LitImageEncoder(config=model_config, lr=lr)
dm = BENv1DataModule(
    data_dirs=my_data_path,  # path to dataset
    img_size=(number_of_channels, image_size, image_size),
    num_workers_dataloader=4,
)
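Before starting a full run, it can help to pull a single batch and verify the tensor shapes. This sketch assumes the datamodule follows the standard LightningDataModule interface (setup() and train_dataloader()):
dm.setup(stage="fit")
x, y = next(iter(dm.train_dataloader()))
print(x.shape)  # expected: torch.Size([batch_size, 12, 120, 120])
print(y.shape)  # expected: torch.Size([batch_size, 19])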
Running#
Now we just have to call the fit() and optionally the test() functions.
Note
These calls generate quite a bit of output depending on the number of batches and epochs. The output is removed for readability.
trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)
Here is an example forward call for the model on a single image. [Figure: the input image, shown via its RGB channels rescaled to the range 0 to 1, since the normalized input would otherwise distort the colors.] Note that the model returns raw logits rather than probabilities, because the loss is computed with binary_cross_entropy_with_logits. The expected labels and the actual outputs for this image are:
Expected: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]
Real: [-0.3234540820121765, 0.5756053328514099, -0.033697664737701416, -0.18205130100250244, -0.2761593461036682, -0.007394324988126755, -0.32195261120796204, -0.34879133105278015, -0.3748129606246948, -0.30978673696517944, 0.07935646921396255, -0.49131959676742554, -0.2532671093940735, -0.4226391911506653, -0.3649766147136688, -0.1662663221359253, -0.14214220643043518, -0.4286631643772125, -0.6496837139129639]
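To turn such raw logits into multi-label predictions, a sigmoid followed by a threshold can be applied. A minimal sketch, assuming an input batch x as above and a 0.5 threshold (an assumption, not prescribed by the framework):
model.eval()
with torch.no_grad():
    logits = model(x)              # x: [batch_size, 12, 120, 120]
    probs = torch.sigmoid(logits)  # per-class probabilities in [0, 1]
    preds = (probs > 0.5).float()  # 19-dimensional multi-hot prediction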