alf.pretrained_models

alf.pretrained_models.pretrained_model

class PretrainedModel(model, adapter_cls=[], module_blacklist=None, module_whitelist=None, name='PretrainedModel')

Bases: torch.nn.modules.module.Module

A wrapper class for managing pretrained models.

A pretrained model is generally large and its weights will always be frozen. For finetuning, we can add small adapters to some of its layers and only the adapter weights will be trained on downstream tasks. See https://docs.adapterhub.ml/methods.html.
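To show how small an adapter can be relative to the layer it adapts, the sketch below counts weights for one of the methods listed on the AdapterHub page: a LoRA-style low-rank adapter that adds a product B A alongside a frozen d_out x d_in weight. The layer size (1024 x 1024), the rank (8), and the helper names are illustrative assumptions, not part of ALF.

```python
def linear_params(d_in: int, d_out: int) -> int:
    """Weight count of a dense d_in -> d_out layer (bias ignored)."""
    return d_in * d_out


def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Weight count of a low-rank adapter B @ A, with A: rank x d_in and B: d_out x rank."""
    return rank * d_in + d_out * rank


base = linear_params(1024, 1024)      # frozen weights of the pretrained layer
adapter = lora_params(1024, 1024, 8)  # trainable weights added by the adapter
print(f"{adapter / base:.2%}")        # 1.56%
```

Only the adapter's 16,384 weights would be trained here, while the 1,048,576 base weights stay frozen.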

Parameters
  • model (Module) – the base pretrained model, whose weights are kept frozen

  • adapter_cls (List[Callable]) – an optional list of adapter classes applied to the base model layers. For every qualified layer, one adapter instance of each class is created. The adapter weights are stored in the adapter instances.

  • module_blacklist (Optional[List[str]]) – an optional blacklist of modules not to be adapted. Each entry can be a regex or a substring of the module name.

  • module_whitelist (Optional[List[str]]) – an optional whitelist of modules to be adapted. Each entry can be a regex or a substring of the module name. By default this is None, which means all modules are valid. At most one of module_blacklist and module_whitelist can be provided.

  • name (str) – name of the pretrained model
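The blacklist/whitelist rules above can be sketched as a small name filter. This is a hypothetical helper (is_adaptable is not an ALF function), and it assumes an entry matches when it is either a plain substring of the module name or a regex found anywhere in it via re.search; ALF's exact matching logic may differ.

```python
import re
from typing import List, Optional


def is_adaptable(name: str,
                 blacklist: Optional[List[str]] = None,
                 whitelist: Optional[List[str]] = None) -> bool:
    """Hypothetical filter deciding whether the module `name` gets adapters."""
    assert blacklist is None or whitelist is None, (
        "At most one of blacklist and whitelist can be provided")

    def matches(patterns: List[str]) -> bool:
        # An entry matches if it is a substring of the name or a regex found in it.
        return any(p in name or re.search(p, name) is not None for p in patterns)

    if blacklist is not None:
        return not matches(blacklist)  # blacklisted modules are skipped
    if whitelist is not None:
        return matches(whitelist)      # only whitelisted modules are adapted
    return True                        # neither list given: every module is valid


names = ["encoder.layer.0.attn.q_proj", "encoder.layer.0.mlp.fc1", "lm_head"]
print([n for n in names if is_adaptable(n, whitelist=[r"attn\.(q|v)_proj"])])
print([n for n in names if is_adaptable(n, blacklist=["lm_head"])])
```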

property adapted_module_names: List[str]

Return a list of adapted module names, in the order in which the adapters were added.

Return type

List[str]

forward(input)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

merge_adapter()

Merge adapter weights into the model for efficient inference.

Note that even after merging, a checkpoint of this pretrained model still contains only the adapter weights. In other words, whether the adapters are merged or not is transparent to PyTorch's checkpointing.
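Merging is exact for adapters whose contribution is linear in the input. The sketch below assumes a LoRA-style adapter (an assumption; the adapter classes are supplied by the caller) computing W x + B (A x): merging folds B A into W, and unmerging subtracts it again, so the forward output is unchanged and the original W is recovered. A bottleneck adapter with a nonlinearity could not be folded into a single weight matrix this way. Plain Python lists stand in for tensors.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def add(a: Matrix, b: Matrix, sign: float = 1.0) -> Matrix:
    return [[x + sign * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


W = [[1.0, 2.0], [3.0, 4.0]]  # frozen weight (d_out x d_in)
A = [[1.0, 0.0]]              # adapter down-projection (rank 1 x d_in)
B = [[2.0], [0.0]]            # adapter up-projection (d_out x rank 1)
x = [[5.0], [6.0]]            # input column vector

unmerged_y = add(matmul(W, x), matmul(B, matmul(A, x)))  # W x + B (A x)

delta = matmul(B, A)                            # low-rank update B A
W_merged = add(W, delta)                        # merge: one matmul at inference
merged_y = matmul(W_merged, x)

W_restored = add(W_merged, delta, sign=-1.0)    # unmerge: recover frozen W

print(unmerged_y, merged_y)  # [[27.0], [39.0]] [[27.0], [39.0]]
```

Because A and B are kept after merging, unmerging is exact, which is consistent with saving only the adapter weights regardless of merge state.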

property model: torch.nn.modules.module.Module

Return the base model.

Return type

Module

remove_adapter()

Remove the adapters (if any).

After the removal, forward() will use only the frozen weights. This operation is irreversible: the adapters cannot be added back.

Returns

a list of the adapters, in case their weights are needed.

Return type

nn.ModuleList
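The bookkeeping behind adapted_module_names and remove_adapter() can be illustrated with a toy stand-in. ToyAdaptedModel and add_adapter are hypothetical names, not ALF's API, and scalar functions replace real layers; the sketch only shows the documented behavior: names are reported in insertion order, removal returns the adapters, forward then uses only the frozen computation, and adapters cannot be added back.

```python
from typing import Callable, Dict, List

Fn = Callable[[float], float]


class ToyAdaptedModel:
    """Hypothetical stand-in mimicking PretrainedModel's adapter bookkeeping."""

    def __init__(self, base: Fn):
        self._base = base                   # frozen computation, never modified
        self._adapters: Dict[str, Fn] = {}  # dicts preserve insertion order
        self._removed = False

    def add_adapter(self, name: str, adapter: Fn) -> None:
        if self._removed:
            raise RuntimeError("adapters were removed and cannot be added back")
        self._adapters[name] = adapter

    @property
    def adapted_module_names(self) -> List[str]:
        return list(self._adapters)         # names in the order they were added

    def forward(self, x: float) -> float:
        # Frozen output plus the additive contribution of every adapter.
        return self._base(x) + sum(a(x) for a in self._adapters.values())

    def remove_adapter(self) -> List[Fn]:
        removed = list(self._adapters.values())
        self._adapters.clear()
        self._removed = True                # removal is irreversible
        return removed


m = ToyAdaptedModel(lambda x: 2.0 * x)
m.add_adapter("layer.0", lambda x: 0.5 * x)
m.add_adapter("layer.1", lambda x: 1.0)
print(m.adapted_module_names)  # ['layer.0', 'layer.1']
print(m.forward(4.0))          # 11.0
adapters = m.remove_adapter()
print(m.forward(4.0))          # 8.0
```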

reset_adapter()

Reset the adapter weights.
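What "reset" means depends on the adapter class. As an assumption, the sketch below follows the initialization from the LoRA paper: A is drawn from a small Gaussian and B is set to zero, so a freshly reset adapter contributes nothing and the adapted layer initially reproduces the frozen output exactly. The helper names are hypothetical.

```python
import random
from typing import List, Tuple

Matrix = List[List[float]]


def reset_lora(d_in: int, d_out: int, rank: int, seed: int = 0) -> Tuple[Matrix, Matrix]:
    """Hypothetical LoRA-style reset: random A (rank x d_in), zero B (d_out x rank)."""
    rng = random.Random(seed)
    A = [[rng.gauss(0.0, 0.02) for _ in range(d_in)] for _ in range(rank)]
    B = [[0.0] * rank for _ in range(d_out)]
    return A, B


def lora_delta(A: Matrix, B: Matrix, x: List[float]) -> List[float]:
    """Adapter contribution B (A x) for an input vector x."""
    ax = [sum(a * xi for a, xi in zip(row, x)) for row in A]
    return [sum(b * h for b, h in zip(row, ax)) for row in B]


A, B = reset_lora(d_in=3, d_out=2, rank=2)
print(lora_delta(A, B, [1.0, -2.0, 0.5]))  # [0.0, 0.0]
```

Zeroing only B (rather than both matrices) keeps gradients with respect to B non-zero, so training can still move the adapter away from the identity update.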

training: bool

unmerge_adapter()

Unmerge adapter weights to enable training.