Customize Fusion Approach#

ConfigILM supports a variety of fusion approaches. In this notebook, we demonstrate how to use the fusion approaches provided by ConfigILM and how to implement a custom fusion approach and use it with ConfigILM.

Fusion Approaches provided by ConfigILM#

ConfigILM provides a wide range of fusion approaches; details can be found here. All of them are implemented as nn.Module with a harmonized interface. Depending on the fusion approach, the input to the fusion module is either a list of tensors or a single tensor, while the output is always a single tensor. An example of a fusion module is shown below:

import torch
from configilm.Fusion.TuckerFusion import Tucker

fusion_dim = 30
output_dim = 10

fusion = Tucker(
    input_dims=[fusion_dim, fusion_dim],
    output_dim=output_dim,
    mm_dim=25,
)

t1 = torch.randn(fusion_dim)
t2 = torch.randn(fusion_dim)
output = fusion(input_0=t1, input_1=t2)
assert output.shape == (output_dim,)

This fusion module takes two input tensors, each with dimension 30, and produces a single output tensor with dimension 10. The mm_dim parameter specifies the dimension of the middle mode in the Tucker decomposition.
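
Conceptually, Tucker fusion projects each input into a shared space of size mm_dim, combines the projections, and maps the result to output_dim. The following simplified sketch only illustrates this idea; it is a hypothetical stand-in, not the Tucker implementation shipped with ConfigILM (which, among other things, learns a core tensor):

import torch.nn as nn

class SimplifiedTuckerSketch(nn.Module):
    # Illustration only -- not the ConfigILM Tucker implementation.
    def __init__(self, input_dims, output_dim, mm_dim):
        super().__init__()
        # project each modality into a shared space of size mm_dim
        self.proj_0 = nn.Linear(input_dims[0], mm_dim)
        self.proj_1 = nn.Linear(input_dims[1], mm_dim)
        # map the combined representation to the output dimension
        self.out = nn.Linear(mm_dim, output_dim)

    def forward(self, input_0, input_1):
        fused = self.proj_0(input_0) * self.proj_1(input_1)  # element-wise interaction
        return self.out(fused)

sketch = SimplifiedTuckerSketch([fusion_dim, fusion_dim], output_dim=output_dim, mm_dim=25)
assert sketch(torch.randn(fusion_dim), torch.randn(fusion_dim)).shape == (output_dim,)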

This fusion can be used with ConfigILM as shown below:

from configilm.ConfigILM import ConfigILM, ILMConfiguration, ILMType

configurations = ILMConfiguration(
    timm_model_name="resnet18",
    hf_model_name="prajjwal1/bert-tiny",
    fusion_in=fusion_dim,
    fusion_out=output_dim,
    custom_fusion_method=("Tucker", fusion),
    network_type=ILMType.VQA_CLASSIFICATION,
)
model = ConfigILM(configurations)

v = torch.rand((3, 3, 224, 224))
q = torch.randint(0, 1000, (3, 10), dtype=torch.long)
output = model((v, q))
assert output.shape == (3, 10)
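
The output behaves like any other PyTorch classification output. Below is a minimal sketch of a training step, assuming the model returns raw logits and framing the task as multi-label classification; the targets are purely illustrative:

loss_fn = torch.nn.BCEWithLogitsLoss()
targets = torch.randint(0, 2, (3, output_dim), dtype=torch.float)  # dummy multi-label targets
loss = loss_fn(output, targets)
loss.backward()  # gradients propagate through the Tucker fusion as well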

Custom Fusion Approach#

To implement a custom fusion approach, you need to create a callable, e.g. a subclass of nn.Module. Its forward method should take the input tensors and return the output tensor. An example of a custom fusion module is shown below:

import torch

class CustomFusion(torch.nn.Module):
    def __init__(self):
        super(CustomFusion, self).__init__()
        self.operation = torch.mul

    def forward(self, input_0, input_1):
        return self.operation(input_0, input_1)

custom_fusion = CustomFusion()
assert custom_fusion(torch.tensor(2), torch.tensor(3)) == 6
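
Inside ConfigILM, the fusion will presumably receive batched feature vectors of size fusion_in rather than scalars, so a quick shape check on batched inputs is useful (the batch size of 3 is arbitrary):

batch_features_0 = torch.randn(3, fusion_dim)
batch_features_1 = torch.randn(3, fusion_dim)
# element-wise multiplication preserves the (batch, fusion_dim) shape
assert custom_fusion(batch_features_0, batch_features_1).shape == (3, fusion_dim)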

This custom fusion module takes two input tensors and multiplies them element-wise. To use it with ConfigILM, you need to provide it as a tuple containing the name of the custom fusion and the fusion definition itself, as shown below.

Note

It is important that the name given in the tuple is exactly the same as the name under which the custom fusion is defined, since this name is used to import the custom fusion module.

configurations = ILMConfiguration(
    timm_model_name="resnet18",
    hf_model_name="prajjwal1/bert-tiny",
    fusion_in=fusion_dim,
    fusion_out=fusion_dim,
    custom_fusion_method=("CustomFusion", custom_fusion),
    network_type=ILMType.VQA_CLASSIFICATION,
)
model = ConfigILM(configurations)

v = torch.rand((3, 3, 224, 224))
q = torch.randint(0, 1000, (3, 10), dtype=torch.long)
output = model((v, q))
assert output.shape == (3, 10)

model.config.fusion_method
CustomFusion()
# using a different name won't work
configurations = ILMConfiguration(
    timm_model_name="resnet18",
    hf_model_name="prajjwal1/bert-tiny",
    fusion_in=fusion_dim,
    fusion_out=fusion_dim,
    custom_fusion_method=("DifferentName", custom_fusion),
    network_type=ILMType.VQA_CLASSIFICATION,
)
model = ConfigILM(configurations)

v = torch.rand((3, 3, 224, 224))
q = torch.randint(0, 1000, (3, 10), dtype=torch.long)
try:
    output = model((v, q))
except NameError as e:
    print("NameError occurred")
NameError occurred

Using the correct name for the custom fusion module is important. If the name refers to a different, already defined fusion module, that module will be used instead of the one you passed. The model may still run, but it will not produce the expected results.

In the result below, we can see that the CustomFusion module is used for fusion, even though we passed (and expected) DifferentCustomFusion.

class DifferentCustomFusion(torch.nn.Module):
    def __init__(self):
        super(DifferentCustomFusion, self).__init__()
        self.operation = torch.add

    def forward(self, input_0, input_1):
        return self.operation(input_0, input_1)
    
different_custom_fusion = DifferentCustomFusion()
configurations = ILMConfiguration(
    timm_model_name="resnet18",
    hf_model_name="prajjwal1/bert-tiny",
    fusion_in=fusion_dim,
    fusion_out=fusion_dim,
    custom_fusion_method=("CustomFusion", different_custom_fusion),
    network_type=ILMType.VQA_CLASSIFICATION,
)

model = ConfigILM(configurations)

v = torch.rand((3, 3, 224, 224))
q = torch.randint(0, 1000, (3, 10), dtype=torch.long)
output = model((v, q))
assert output.shape == (3, 10)

model.config.fusion_method
CustomFusion()
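
Because CustomFusion multiplies its inputs while DifferentCustomFusion adds them, the mix-up can be made explicit by calling the stored fusion directly. This assumes that model.config.fusion_method holds the fusion callable that is actually used, as the repr above suggests:

# Multiplication (CustomFusion) yields 6; addition (DifferentCustomFusion) would yield 5.
assert model.config.fusion_method(torch.tensor(2), torch.tensor(3)) == 6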

The custom fusion does not have to be a subclass of nn.Module; it can also be a plain function that takes the input tensors and returns the output tensor. An example of a custom fusion function is shown below:

def custom_fusion_function(input_0, input_1):
    return input_0 + input_1

assert custom_fusion_function(torch.tensor(2), torch.tensor(3)) == 5

configurations = ILMConfiguration(
    timm_model_name="resnet18",
    hf_model_name="prajjwal1/bert-tiny",
    fusion_in=fusion_dim,
    fusion_out=fusion_dim,
    custom_fusion_method=("custom_fusion_function", custom_fusion_function),
    network_type=ILMType.VQA_CLASSIFICATION,
)
model = ConfigILM(configurations)

v = torch.rand((3, 3, 224, 224))
q = torch.randint(0, 1000, (3, 10), dtype=torch.long)
output = model((v, q))
assert output.shape == (3, 10)

model.config.fusion_method
<function __main__.custom_fusion_function(input_0, input_1)>

Alternatively, an already existing function such as torch.mul can be used directly:

configurations = ILMConfiguration(
    timm_model_name="resnet18",
    hf_model_name="prajjwal1/bert-tiny",
    fusion_in=fusion_dim,
    fusion_out=fusion_dim,
    custom_fusion_method=("torch.mul", torch.mul),
    network_type=ILMType.VQA_CLASSIFICATION,
)
model = ConfigILM(configurations)

v = torch.rand((3, 3, 224, 224))
q = torch.randint(0, 1000, (3, 10), dtype=torch.long)
output = model((v, q))
assert output.shape == (3, 10)

model.config.fusion_method
<function torch._VariableFunctionsClass.mul>