Source code for alf.pretrained_models.model_adapters.lora

# Copyright (c) 2023 Horizon Robotics and Hobot Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import types
import copy
from functools import partial
from typing import Callable, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

import alf


class LoRA(nn.Module):
    r"""Base class for LoRA (Low-Rank Adaptation).

    For any model layer expressed as a matrix multiplication of the form
    :math:`h=W_0x`, it performs a reparameterization such that:

    .. math::

        h = W_0x + \frac{\alpha}{r} BAx

    where :math:`A\in\mathbb{R}^{r\times k}` and :math:`B\in\mathbb{R}^{d\times r}`
    are the decomposition matrices and :math:`r` is the low-dim rank of the
    decomposition.

    The original LoRA paper doesn't adapt biases:

        "We leave the empirical investigation of adapting the MLP layers,
        LayerNorm layers, and biases to a future work."

    If a module has a weight matrix with dimensionality greater than 2 (e.g.,
    conv), we might need to first reinterpret it as a 2D matrix ``W_0``. This
    reinterpretation will differ for different modules.

    Subclasses must implement ``can_adapt()``, ``_decompose_weight_dims()`` and
    ``forward()`` (the adapter-only contribution that is added to the output of
    the base module).
    """

    # Class-level default, so that instances created without running this
    # ``__init__`` (e.g., restored from an older pickle) count as attached.
    _detached = False

    def __init__(self,
                 m: nn.Module,
                 rank: int = 16,
                 weight: float = 1.,
                 name: str = 'LoRA'):
        """
        Args:
            m: the module to be adapted. All its parameters are frozen
                (``requires_grad=False``) and its ``forward()`` is patched in
                place.
            rank: the low rank. Must be a positive integer.
            weight: weight for the low-rank weight matrix
            name: name of this adapter; stored as ``self._name``.

        Raises:
            ValueError: if ``rank`` is smaller than 1.
        """
        # Validate before touching ``m`` so that a bad argument leaves the
        # base module untouched.
        if rank < 1:
            raise ValueError(
                f"rank must be a positive integer, got {rank}")

        super().__init__()
        for para in m.parameters():
            para.requires_grad = False

        self._name = name
        self._r = rank
        self._alpha = weight
        # place ``m`` in a list to avoid pytorch recursive tracking
        self._m = [m]
        self._merged = False

        # Create the adapter weights with the same dtype *and* device as the
        # base module; otherwise the adapter output cannot be added to the
        # base output (or merged into its weight) when ``m`` lives on a
        # non-default device.
        ref = next(m.parameters())
        factory_kwargs = dict(dtype=ref.dtype, device=ref.device)

        # LoRA weight mats
        self._wA = None
        self._wB = None  # optional; only present if m.weight is huge

        nrows, ncols = self._decompose_weight_dims(m)
        # original size: nrows * ncols
        # adapter size: (nrows + ncols) * r
        # need low-rank adaptation: (nrows + ncols) * r < nrows * ncols
        if rank > ncols * nrows / (nrows + ncols):
            # A low-rank pair would not be smaller than the weight itself, so
            # use a single full-size additive matrix.
            self._wA = nn.Parameter(
                torch.zeros(nrows, ncols, **factory_kwargs))
        else:
            self._wA = nn.Parameter(
                torch.zeros(rank, ncols, **factory_kwargs))
            self._wB = nn.Parameter(
                torch.zeros(nrows, rank, **factory_kwargs))

        self.reset_parameters()
        self._adapt(m)

    def reset_parameters(self):
        """(Re-)initialize the adapter weights so that the adapter output is
        zero: either ``wA`` (full-size case) or ``wB`` (low-rank case) is
        zeroed.
        """
        if self._wB is None:
            nn.init.zeros_(self._wA)
        else:
            nn.init.xavier_uniform_(self._wA)
            nn.init.zeros_(self._wB)

    def _adapter_weight(self):
        """Return the scaled adapter weight, reshaped to ``m.weight.shape``."""
        w = self._wA
        if self._wB is not None:
            w = self._wB @ w
        w = w.reshape(self._m[0].weight.shape)
        return w * self.scaling

    @property
    def scaling(self) -> float:
        """Scalar multiplier applied to the adapter output/weight."""
        scaling = self._alpha
        if self._wB is not None:
            # If wB is used, wB @ wA will increase the output magnitude. LoRA
            # compensates this by dividing the result by ``self._r``.
            scaling /= self._r
        return scaling

    def _adapt(self, m: nn.Module):
        """Adapt a module *in place*.

        After this, ``m.forward()`` will be computed with the adapter weights.
        """
        assert not hasattr(m, '_forward0'), (
            "The module has already been adapted! You need to first remove the "
            "adapter.")

        forward0 = m.forward
        adapter_forward = self.forward

        def _unmerged_forward(self, input):
            """For unmerged forward, the base model and adapter forward
            separately, and their results are weighted combined. This forward
            way should be used for training.
            """
            return forward0(input) + adapter_forward(input)

        m.forward = types.MethodType(_unmerged_forward, m)
        m._unmerged_forward = m.forward
        m._forward0 = forward0

    def _check_attached(self, op: str):
        """Raise if the adapter has been detached from its base module."""
        if self._detached:
            raise RuntimeError(
                f"Cannot {op}: the adapter has been detached from its base "
                "module.")

    def merge(self):
        """Merge the adapter and base module weights.

        This operation will change the base module weight in place and restore
        its original ``forward()``. It should only be called in the inference
        mode.

        Raises:
            RuntimeError: if the adapter has been detached.
        """
        if self._merged:
            return
        # Fail *before* mutating the base weight; a detached module no longer
        # has ``_forward0`` to restore.
        self._check_attached('merge')
        m = self._m[0]
        # The delta is only added to ``.data``, so no autograd graph is needed.
        with torch.no_grad():
            delta = self._adapter_weight()
        m.weight.data.add_(delta)
        m.forward = m._forward0
        self._merged = True

    def unmerge(self):
        """Unmerge the adapter weights.

        The base module is still adapted, and the adapter weights can be
        trained after unmerge.
        """
        if not self._merged:
            return
        m = self._m[0]
        with torch.no_grad():
            delta = self._adapter_weight()
        m.weight.data.sub_(delta)
        m.forward = m._unmerged_forward
        self._merged = False

    def detach(self):
        """Detach from and recover the base module.

        Note that this operation is *irreversible*. Once detached, there is no
        way to add the adapter back to the module. We still keep the adapter
        weights after detachment. Calling it again is a no-op.
        """
        if self._detached:
            return
        if self._merged:
            self.unmerge()
        m = self._m[0]
        m.forward = m._forward0
        del m._forward0
        del m._unmerged_forward
        self._detached = True

    @classmethod
    def can_adapt(cls, m: nn.Module) -> bool:
        """Check if the adapter class can adapt a given module.
        """
        raise NotImplementedError()

    @classmethod
    def _decompose_weight_dims(cls, m: nn.Module) -> Tuple[int, int]:
        """Return two (reshaped) dimensions ``(nrows, ncols)`` so that
        ``nrows*ncols == np.prod(m.weight.shape)``. The reinterpreted shape
        will be used to create LoRA weights ``self._wA`` and ``self._wB``.

        Returns:
            tuple: a pair of ints ``(nrows, ncols)``.
        """
        raise NotImplementedError()

    def forward(self, input):
        """Compute the adapter-only output for ``input``; it is added to the
        base module output by the patched ``forward()`` of the base module.
        """
        raise NotImplementedError()
@alf.configurable
class EmbeddingAdapter(LoRA):
    """LoRA adapter for ``nn.Embedding`` layers.

    The embedding table of shape ``(num_embeddings, embedding_dim)`` is used
    directly as the 2D matrix to be decomposed.
    """

    @classmethod
    def can_adapt(cls, m: nn.Module) -> bool:
        """Return True iff ``m`` is an ``nn.Embedding``."""
        return isinstance(m, nn.Embedding)

    @classmethod
    def _decompose_weight_dims(cls, m: nn.Module):
        """Return ``(num_embeddings, embedding_dim)`` of the embedding layer."""
        n_rows = m.num_embeddings
        n_cols = m.embedding_dim
        return n_rows, n_cols

    def forward(self, input):
        """Look up the adapter embeddings for the indices in ``input``.

        In the low-rank case ``wB`` serves as the lookup table and the looked
        up vectors are then projected by ``wA``; otherwise ``wA`` itself is the
        lookup table. The lookup reuses the embedding options of the base
        module.
        """
        base = self._m[0]
        low_rank = self._wB is not None
        table = self._wB if low_rank else self._wA
        looked_up = F.embedding(
            input,
            table,
            padding_idx=base.padding_idx,
            max_norm=base.max_norm,
            norm_type=base.norm_type,
            scale_grad_by_freq=base.scale_grad_by_freq,
            sparse=base.sparse)
        if low_rank:
            looked_up = looked_up @ self._wA
        return looked_up * self.scaling
@alf.configurable
class LinearAdapter(LoRA):
    """LoRA adapter for ``nn.Linear`` layers.

    The weight of shape ``(out_features, in_features)`` is already a 2D matrix
    and is decomposed as is.
    """

    @classmethod
    def can_adapt(cls, m: nn.Module) -> bool:
        """Return True iff ``m`` is an ``nn.Linear``."""
        return isinstance(m, nn.Linear)

    @classmethod
    def _decompose_weight_dims(cls, m: nn.Module):
        """Return ``(out_features, in_features)`` of the linear layer."""
        n_rows = m.out_features
        n_cols = m.in_features
        return n_rows, n_cols

    def forward(self, input):
        """Apply the adapter matrices to ``input`` (no bias).

        ``input`` is first multiplied by ``wA^T`` and then, in the low-rank
        case, by ``wB^T``; the result is multiplied by ``self.scaling``.
        """
        projected = torch.matmul(input, self._wA.t())
        if self._wB is None:
            return projected * self.scaling
        expanded = torch.matmul(projected, self._wB.t())
        return expanded * self.scaling
@alf.configurable
class Conv2dAdapter(LoRA):
    """Adapter for Conv2d layers.

    The most natural way of a LoRA decomposition for Conv2d is

    .. code-block:: python

        (rank, kernel_size[0] * kernel_size[1] * in_channels) x (out_channels, rank)

    However, since ``out_channels`` is usually small, this decomposition is not
    low-rank actually and it won't save much memory. We can first reinterpret
    the weight matrix as a shape of

    .. code-block:: python

        (out_channels * kernel_size[0], in_channels * kernel_size[1])

    to make the in- and out-dimensions balanced, and then decompose.
    """

    @classmethod
    def can_adapt(cls, m: nn.Module) -> bool:
        """Return True iff ``m`` is an ``nn.Conv2d``."""
        return isinstance(m, nn.Conv2d)

    @classmethod
    def _decompose_weight_dims(cls, m: nn.Module):
        """Return ``(out_channels * kernel_size[0],
        in_channels // groups * kernel_size[1])``.
        """
        kernel_size = m.kernel_size
        # conv2d weight has a shape:
        # ``(out_channels, in_channels // groups, *kernel_size)``
        nrows = m.out_channels * kernel_size[0]
        ncols = m.in_channels // m.groups * kernel_size[1]
        return nrows, ncols

    def forward(self, input):
        """Compute the adapter-only convolution output (no bias).

        The result must agree with convolving ``input`` with
        ``self._adapter_weight()`` under the base module's settings, so that
        merged and unmerged forwards produce the same output.
        """
        m = self._m[0]
        padding = m.padding
        if m.padding_mode != 'zeros':
            # ``nn.Conv2d`` handles non-zero padding modes by explicitly
            # padding the input and then convolving without implicit padding.
            # Do the same here; zero-padding instead would make the adapter
            # output disagree with the merged weights near the borders.
            # NOTE: ``_reversed_padding_repeated_twice`` is the (private)
            # attribute ``nn.Conv2d`` itself uses for this purpose.
            input = F.pad(
                input,
                m._reversed_padding_repeated_twice,
                mode=m.padding_mode)
            padding = (0, 0)

        if isinstance(padding, str) or m.groups > 1 or self._wB is None:
            # Three scenarios when two-stage convolution is difficult:
            # 1. m.groups > 1: r has to be divisible by m.groups in order to
            #    preserve the correct input-output mapping
            # 2. the padding is a string: torch will compute paddings on the
            #    fly, so we won't know its values in advance.
            # 3. low-rank decomposition is not available
            # ``_adapter_weight()`` already includes ``self.scaling``.
            return F.conv2d(
                input,
                self._adapter_weight(),
                stride=m.stride,
                padding=padding,
                dilation=m.dilation,
                groups=m.groups)

        # Two-stage conv: first along the width with ``wA`` (kernel
        # ``1 x kernel_size[1]``), then along the height with ``wB`` (kernel
        # ``kernel_size[0] x 1``), going through ``rank`` hidden channels.
        hidden = F.conv2d(
            input,
            self._wA.reshape(self._r, m.in_channels, 1, m.kernel_size[1]),
            stride=(1, m.stride[1]),
            padding=(0, padding[1]),
            dilation=(1, m.dilation[1]))
        output = F.conv2d(
            hidden,
            self._wB.reshape(m.out_channels, self._r, m.kernel_size[0], 1),
            stride=(m.stride[0], 1),
            padding=(padding[0], 0),
            dilation=(m.dilation[0], 1))
        return output * self.scaling

    def _adapter_weight(self):
        """Return the scaled adapter weight with the shape of ``m.weight``,
        i.e. ``(out_channels, in_channels // groups, *kernel_size)``.
        """
        m = self._m[0]
        if self._wB is None:
            w = self._wA.reshape(m.weight.shape)
        else:
            # This weight tensor has to be consistent with the two-stage conv
            # in ``self.forward()``
            wa = self._wA.reshape(self._r, m.in_channels // m.groups,
                                  m.kernel_size[1])
            wb = self._wB.reshape(m.out_channels, self._r, m.kernel_size[0])
            w = torch.einsum('rik,org->oigk', wa, wb)
        return w * self.scaling