Model Configuration
The central object of ConfigILM is the dataclass ILMConfiguration. It determines which parts the model consists of, how they are combined, and which task the model should ultimately solve. A minimal configuration for supervised classification can look like this:
Note
Not every property of the object is used in every configuration. Which properties are unused depends on the specified network type. For classification there is no fusion or language part, so in this example all parameters associated with fusion or language modeling are unused.
from configilm.ConfigILM import ILMConfiguration, ILMType
from pprint import pprint

# Only timm_model_name is required; every other field keeps its default value.
model_config = ILMConfiguration(
    timm_model_name="resnet18",
)
All parameters and their default values are listed below.
Parameter name | Type | Default value
----------------------------------------------------------------------------------------------------------------------------------------
timm_model_name | <class 'str'> | <REQUIRED PARAM>
hf_model_name | typing.Optional[str] | None
image_size | <class 'int'> | 120
channels | <class 'int'> | 3
classes | <class 'int'> | 10
class_names | typing.Optional[typing.Sequence[str]] | None
network_type | <enum 'ILMType'> | IMAGE_CLASSIFICATION (value: 0)
visual_features_out | <class 'int'> | 512
fusion_in | <class 'int'> | 512
fusion_out | typing.Optional[int] | None
fusion_hidden | <class 'int'> | 256
v_dropout_rate | <class 'float'> | 0.25
t_dropout_rate | <class 'float'> | 0.25
fusion_dropout_rate | <class 'float'> | 0.25
fusion_method | typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | torch.mul
fusion_activation | typing.Callable[[torch.Tensor], torch.Tensor] | Tanh()
drop_rate | typing.Optional[float] | 0.2
use_pooler_output | <class 'bool'> | True
max_sequence_length | <class 'int'> | 32
load_pretrained_timm_if_available | <class 'bool'> | False
load_pretrained_hf_if_available | <class 'bool'> | True
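Because the configuration is a dataclass, an overview like the one above can in principle be produced with the standard library alone. The following is a minimal sketch, not the code ConfigILM itself uses: `DemoConfig` is a hypothetical stand-in that mirrors only a handful of the fields listed above, and `describe` is an illustrative helper name.

```python
from dataclasses import MISSING, dataclass, fields
from typing import Optional


@dataclass
class DemoConfig:
    # Hypothetical stand-in for ILMConfiguration; only a few fields are mirrored.
    timm_model_name: str
    hf_model_name: Optional[str] = None
    image_size: int = 120
    channels: int = 3
    classes: int = 10


def describe(config_cls):
    """Return (name, type, default) rows for every field of a dataclass."""
    rows = []
    for f in fields(config_cls):
        # Fields without a plain default are reported as required here.
        default = "<REQUIRED PARAM>" if f.default is MISSING else f.default
        rows.append((f.name, f.type, default))
    return rows


for name, type_, default in describe(DemoConfig):
    print(f"{name:<20} | {type_!s:<25} | {default}")
```

The same loop should work on any dataclass, so pointing it at the real configuration class is a natural next step once the library is installed.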
This class is ultimately used to create the model, but it also collects all related information, such as the image size, in a single object. This keeps the code organized and avoids scattering settings across many global variables. Currently, the configuration supports the following network types:
['IMAGE_CLASSIFICATION', 'VQA_CLASSIFICATION']
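A list of names like this is what iterating over an enum yields. The sketch below illustrates that with a hypothetical stand-in enum, `DemoILMType`, rather than the library's own `ILMType`. The value 0 for `IMAGE_CLASSIFICATION` is taken from the parameter table, while the value 1 for `VQA_CLASSIFICATION` is an assumption.

```python
from enum import Enum


class DemoILMType(Enum):
    # Hypothetical stand-in for ILMType. The value 0 matches the parameter
    # table; the value 1 for VQA_CLASSIFICATION is an assumption.
    IMAGE_CLASSIFICATION = 0
    VQA_CLASSIFICATION = 1


# Iterating an Enum yields its members in definition order.
supported = [member.name for member in DemoILMType]
print(supported)

# Members can also be looked up by name, e.g. when parsing a command-line flag.
chosen = DemoILMType["VQA_CLASSIFICATION"]
print(chosen.name, chosen.value)
```

Looking a member up by name raises a `KeyError` for unknown names, which gives an early failure if a configuration file contains an unsupported network type.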