JadenFK's picture
Added custom models and refactor of layout
5349660
raw
history blame
2.77 kB
import copy
import re
import torch
import util
class FineTunedModel(torch.nn.Module):
    """Wrap ``model`` with trainable deep copies of selected submodules.

    Every submodule of ``model`` whose qualified name (as yielded by
    ``named_modules()``) matches one of the ``modules`` regexes is deep-copied.
    The copies are kept in ``ft_modules`` and the originals in
    ``orig_modules``, both keyed by that qualified name. The wrapped model is
    passed to ``util.freeze`` and each copy to ``util.unfreeze``.

    Used as a context manager, the instance installs the copies into
    ``self.model`` on entry and the originals on exit, both via
    ``util.set_module``.
    """

    def __init__(self,
                 model,
                 modules,
                 frozen_modules=None
                 ):
        """Select, copy and (un)freeze the submodules to fine-tune.

        Args:
            model: The ``torch.nn.Module`` to wrap. It is passed to
                ``util.freeze`` before any copies are made.
            modules: A regex string or a list of regex strings. A submodule is
                selected when ``re.search`` finds any of them in its
                qualified name. Each selected submodule is copied once.
            frozen_modules: Optional regex string or list of regex strings.
                Inside each copy, any submodule whose name
                ``"<selected name>.<sub name>"`` matches one of them is passed
                to ``util.freeze``. For the copy's own root, ``<sub name>`` is
                ``""``, so that name ends with a trailing dot.
        """
        super().__init__()

        if isinstance(modules, str):
            modules = [modules]
        # Use ``None`` rather than a shared mutable ``[]`` default. Like
        # ``modules``, a bare string counts as one pattern rather than being
        # iterated character by character.
        if frozen_modules is None:
            frozen_modules = []
        elif isinstance(frozen_modules, str):
            frozen_modules = [frozen_modules]

        self.model = model
        self.ft_modules = {}
        self.orig_modules = {}

        util.freeze(self.model)

        for module_name, module in model.named_modules():
            # any() ensures a module matched by several patterns is still
            # copied only once, and always from the real ``module``.
            if not any(re.search(pattern, module_name) for pattern in modules):
                continue

            ft_module = copy.deepcopy(module)

            self.orig_modules[module_name] = module
            self.ft_modules[module_name] = ft_module

            util.unfreeze(ft_module)

            print(f"=> Finetuning {module_name}")

            # Use loop names distinct from ``module`` so the outer loop
            # variable is never rebound to a submodule of the copy.
            for sub_name, submodule in ft_module.named_modules():
                qualified_name = f"{module_name}.{sub_name}"
                if any(re.search(pattern, qualified_name)
                       for pattern in frozen_modules):
                    print(f"=> Freezing {qualified_name}")
                    util.freeze(submodule)

        # torch.nn.Module does not track modules stored in a plain dict, so
        # also register them through ModuleLists.
        self.ft_modules_list = torch.nn.ModuleList(self.ft_modules.values())
        self.orig_modules_list = torch.nn.ModuleList(self.orig_modules.values())

    @classmethod
    def from_checkpoint(cls, model, checkpoint, frozen_modules=None):
        """Rebuild a wrapper from a mapping produced by :meth:`state_dict`.

        Args:
            model: The base model to wrap.
            checkpoint: A path string passed to ``torch.load``, or an
                already-loaded mapping of qualified module name -> that
                module's state dict.
            frozen_modules: Forwarded unchanged to the constructor.

        Returns:
            A new instance (of ``cls``) whose copies hold the checkpoint
            weights.
        """
        if isinstance(checkpoint, str):
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(checkpoint)
        # Keys are exact module names. Escape them (``.`` is a regex
        # metacharacter) and anchor both ends, so that each key selects
        # exactly its own module.
        modules = [f"^{re.escape(key)}$" for key in checkpoint]
        ftm = cls(model, modules, frozen_modules=frozen_modules)
        ftm.load_state_dict(checkpoint)
        return ftm

    def __enter__(self):
        """Install each copy into ``self.model`` and return ``self``.

        Installation is done with ``util.set_module(self.model, name, copy)``,
        which presumably replaces the submodule at that dotted name.
        """
        for name, ft_module in self.ft_modules.items():
            util.set_module(self.model, name, ft_module)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Reinstall the original submodules via ``util.set_module``.

        Returns ``None``, so any exception raised in the ``with`` body
        propagates.
        """
        for name, module in self.orig_modules.items():
            util.set_module(self.model, name, module)

    def parameters(self):
        """Return a list of the parameters of the copied submodules only.

        Unlike ``torch.nn.Module.parameters``, this excludes ``self.model``
        and the originals, and it returns a list rather than a generator.
        Parameters of submodules frozen via ``frozen_modules`` are still
        included.
        """
        return [
            param
            for ft_module in self.ft_modules.values()
            for param in ft_module.parameters()
        ]

    def state_dict(self):
        """Return ``{qualified module name: copy.state_dict()}`` for every copy."""
        return {
            name: ft_module.state_dict()
            for name, ft_module in self.ft_modules.items()
        }

    def load_state_dict(self, state_dict):
        """Load each entry of a :meth:`state_dict`-style mapping into its copy.

        Raises:
            KeyError: If ``state_dict`` names a module that has no copy.
        """
        for name, module_state in state_dict.items():
            self.ft_modules[name].load_state_dict(module_state)