import torch
from torch import nn


def replace_linear_with_lora(
    module: nn.Module,
    max_rank: int,
    scale: float = 1.0,
) -> None:
    """Recursively replace every nn.Linear in `module` with a LinearLora of rank `max_rank`."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            new_lora = LinearLora(
                in_features=child.in_features,
                out_features=child.out_features,
                bias=child.bias,
                rank=max_rank,
                scale=scale,
                dtype=child.weight.dtype,
                device=child.weight.device,
            )

            # Reuse the pretrained parameters so the frozen base path is unchanged.
            new_lora.weight = child.weight
            new_lora.bias = child.bias if child.bias is not None else None

            setattr(module, name, new_lora)
        else:
            replace_linear_with_lora(
                module=child,
                max_rank=max_rank,
                scale=scale,
            )


class LinearLora(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool | torch.Tensor | None,
        rank: int,
        dtype: torch.dtype,
        device: torch.device,
        lora_bias: bool = True,
        scale: float = 1.0,
        *args,
        **kwargs,
    ) -> None:
        # The base Linear gets a bias only if the replaced layer had one
        # (`bias` may be the original layer's bias tensor or None).
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias is not None,
            device=device,
            dtype=dtype,
            *args,
            **kwargs,
        )

        assert isinstance(scale, float), "scale must be a float"

        self.scale = scale
        self.rank = rank
        self.lora_bias = lora_bias
        self.dtype = dtype
        self.device = device

        # Clamp the rank so the low-rank factors never exceed the base weight's dimensions.
        if rank > (new_rank := min(self.out_features, self.in_features)):
            self.rank = new_rank

        self.lora_A = nn.Linear(
            in_features=in_features,
            out_features=self.rank,
            bias=False,
            dtype=dtype,
            device=device,
        )
        self.lora_B = nn.Linear(
            in_features=self.rank,
            out_features=out_features,
            bias=self.lora_bias,
            dtype=dtype,
            device=device,
        )

    def set_scale(self, scale: float) -> None:
        assert isinstance(scale, float), "scale must be a float"
        self.scale = scale

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        base_out = super().forward(input)

        # Low-rank update: B(A(x)), scaled and added to the frozen base output.
        _lora_out_B = self.lora_B(self.lora_A(input))
        lora_update = _lora_out_B * self.scale

        return base_out + lora_update
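

# Minimal usage sketch (illustration only, not part of the module above):
# `TinyMLP` is a hypothetical model used to show how `replace_linear_with_lora`
# swaps every nn.Linear for a LinearLora while keeping the pretrained weights,
# and how the lora_A/lora_B factors would typically be left as the only
# trainable parameters. Names and shapes here are assumptions for the demo.
if __name__ == "__main__":

    class TinyMLP(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.net = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 8))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.net(x)

    model = TinyMLP()
    replace_linear_with_lora(model, max_rank=4, scale=1.0)

    # Freeze everything except the LoRA factors.
    for name, param in model.named_parameters():
        param.requires_grad = "lora_A" in name or "lora_B" in name

    out = model(torch.randn(2, 16))
    print(out.shape)  # torch.Size([2, 8])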