|
from peft.tuners.tuners_utils import BaseTunerLayer |
|
from typing import List, Any, Optional, Type |
|
|
|
|
|
class enable_lora:
    """Context manager that temporarily scales tuner layers by 0.

    When ``activated`` is true the manager does nothing at all. Otherwise,
    entering the context calls ``scale_layer(0)`` on every tuner layer, and
    leaving it writes back the per-adapter ``scaling`` values that were
    recorded when the manager was constructed.

    NOTE(review): in PEFT's LoRA layers ``scale_layer`` multiplies the
    scaling of the active adapters by the given factor, so a factor of 0
    switches the adapters' contribution off -- confirm against the
    installed PEFT version.

    Args:
        lora_modules: Candidate modules. Entries that are not
            ``BaseTunerLayer`` instances are ignored.
        activated: If true, the modules are left untouched on enter/exit.
    """

    def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
        self.activated: bool = activated
        # Define both attributes in every mode so the object never has
        # missing attributes; in the no-op mode they simply stay empty.
        self.lora_modules: List[BaseTunerLayer] = []
        self.scales: List[Any] = []
        if activated:
            return

        self.lora_modules = [
            each for each in lora_modules if isinstance(each, BaseTunerLayer)
        ]
        # One {adapter_name: scaling} snapshot per module, index-aligned with
        # ``self.lora_modules``. Only adapters that actually have an entry in
        # ``scaling`` are recorded, so an active adapter without one is
        # skipped instead of raising KeyError.
        self.scales = [
            {
                adapter: lora_module.scaling[adapter]
                for adapter in lora_module.active_adapters
                if adapter in lora_module.scaling
            }
            for lora_module in self.lora_modules
        ]

    def __enter__(self) -> None:
        """Scale every tracked module by 0 unless the manager is a no-op."""
        if self.activated:
            return

        # ``self.lora_modules`` was already filtered to BaseTunerLayer
        # instances in ``__init__``; no second type check is needed.
        for lora_module in self.lora_modules:
            lora_module.scale_layer(0)

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        """Write the recorded scaling values back; never suppresses exceptions."""
        if self.activated:
            return

        # Restore from the snapshot rather than from the module's *current*
        # active adapters: if the active set changed inside the ``with`` body,
        # the adapters that were scaled on entry must still get their values
        # back, and newly activated ones have no snapshot entry to look up.
        for lora_module, saved in zip(self.lora_modules, self.scales):
            for adapter, value in saved.items():
                # Do not resurrect an entry for an adapter that was removed
                # from the layer while the context was active.
                if adapter in lora_module.scaling:
                    lora_module.scaling[adapter] = value
|
|
|
|
|
class set_lora_scale:
    """Context manager that temporarily rescales tuner layers.

    Entering the context calls ``scale_layer(scale)`` on every tuner layer;
    leaving it writes back the per-adapter ``scaling`` values that were
    recorded when the manager was constructed.

    NOTE(review): in PEFT's LoRA layers ``scale_layer`` multiplies the
    current scaling by the factor rather than assigning it, so ``scale`` acts
    as a relative multiplier -- confirm this is the intended meaning.

    Args:
        lora_modules: Candidate modules. Entries that are not
            ``BaseTunerLayer`` instances are ignored.
        scale: Factor passed to each module's ``scale_layer``.
    """

    def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
        self.lora_modules: List[BaseTunerLayer] = [
            each for each in lora_modules if isinstance(each, BaseTunerLayer)
        ]
        # One {adapter_name: scaling} snapshot per module, index-aligned with
        # ``self.lora_modules``. Only adapters that actually have an entry in
        # ``scaling`` are recorded, so an active adapter without one is
        # skipped instead of raising KeyError.
        self.scales = [
            {
                adapter: lora_module.scaling[adapter]
                for adapter in lora_module.active_adapters
                if adapter in lora_module.scaling
            }
            for lora_module in self.lora_modules
        ]
        self.scale = scale

    def __enter__(self) -> None:
        """Apply ``scale`` to every tracked module via ``scale_layer``."""
        # ``self.lora_modules`` was already filtered to BaseTunerLayer
        # instances in ``__init__``; no second type check is needed.
        for lora_module in self.lora_modules:
            lora_module.scale_layer(self.scale)

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        """Write the recorded scaling values back; never suppresses exceptions."""
        # Restore from the snapshot rather than from the module's *current*
        # active adapters: if the active set changed inside the ``with`` body,
        # the adapters that were rescaled on entry must still get their values
        # back, and newly activated ones have no snapshot entry to look up.
        for lora_module, saved in zip(self.lora_modules, self.scales):
            for adapter, value in saved.items():
                # Do not resurrect an entry for an adapter that was removed
                # from the layer while the context was active.
                if adapter in lora_module.scaling:
                    lora_module.scaling[adapter] = value
|
|