import importlib
import json
import os
from typing import List, Optional
import numpy as np
import torch
import torch.nn as nn
from transformers import (
PretrainedConfig,
PreTrainedModel,
AutoConfig, AutoModelForCausalLM,
)
from utils.constants import MISTRAL_7B
from utils.utils import _get_submodules
class Cats(nn.Module):
    """Wraps an activation module, optionally collects histograms of its
    outputs, and zeroes activations whose magnitude falls below a threshold."""

    def __init__(
        self,
        wrapped_module: nn.Module,
        threshold: float = 0,
        hist_num_bins: int = 1000,
        hist_min: float = -1,
        hist_max: float = 1,
    ):
super(Cats, self).__init__()
self.wrapped_module = wrapped_module
self.threshold = nn.Parameter(torch.tensor(threshold), requires_grad=False)
        # Bin edges covering [hist_min, hist_max]; -inf/+inf edges are appended
        # so that out-of-range activations land in the first and last bins.
        self.histogram_bins = torch.linspace(hist_min, hist_max, hist_num_bins - 2)
        self.histogram_bins = torch.cat(
            [torch.tensor([-torch.inf]), self.histogram_bins, torch.tensor([torch.inf])]
        )
        # hist_num_bins edges define hist_num_bins - 1 buckets.
        self.hist_counts = torch.zeros(hist_num_bins - 1)
        self.abs_hist_counts = torch.zeros(hist_num_bins - 1)
self.collect_stats = True
def disable_collect_stats(self):
self.collect_stats = False
def enable_collect_stats(self):
self.collect_stats = True
def set_threshold(self, threshold: float):
self.threshold = nn.Parameter(torch.tensor(threshold), requires_grad=False)
    def forward(self, x):
        x = self.wrapped_module(x)
        if self.collect_stats:
            # torch.histogram only supports CPU float tensors, so detach and
            # move the activations before accumulating the histogram counts.
            x_cpu = x.detach().float().cpu()
            self.hist_counts += torch.histogram(x_cpu, bins=self.histogram_bins)[0]
            self.abs_hist_counts += torch.histogram(
                torch.abs(x_cpu), bins=self.histogram_bins
            )[0]
        # Zero out activations whose magnitude is below the threshold.
        x[torch.abs(x) < self.threshold] = 0
        return x
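# Minimal usage sketch for Cats (illustrative only; the wrapped activation and
# tensor shapes below are made up for the example):
#
#   act = Cats(wrapped_module=nn.SiLU())
#   _ = act(torch.randn(4, 16))            # accumulates hist_counts / abs_hist_counts
#   act.disable_collect_stats()
#   act.set_threshold(0.1)
#   out = act(torch.randn(4, 16))          # entries with |x| < 0.1 are zeroed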
# Function to load existing data from a JSON file
def load_data(file_path):
try:
with open(file_path, "r") as json_file:
return json.load(json_file)
except FileNotFoundError:
return {} # Return an empty dictionary if the file does not exist
# Function to save the dictionary to a JSON file
def save_to_json(data, file_path):
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "w") as json_file:
json.dump(data, json_file, indent=4)
class CatsConfig(PretrainedConfig):
model_type = "cats_model"
    def __init__(
        self,
        wrapped_model_config=None,
        wrapped_model_class_name: str = "MistralForCausalLM",
        target_modules: Optional[List[str]] = None,
        target_sparsity: float = 0.5,
        **kwargs,
    ):
        # Resolve defaults lazily: a from_pretrained call or a mutable list as
        # a default argument would run at import time / be shared across instances.
        if wrapped_model_config is None:
            wrapped_model_config = AutoConfig.from_pretrained(MISTRAL_7B)
        if target_modules is None:
            target_modules = ["act_fn"]
        self.target_modules = target_modules
        self.target_sparsity = target_sparsity
        self.wrapped_model_class_name = wrapped_model_class_name
        # Expose the wrapped model's config fields on this config as well.
        self.__dict__.update(wrapped_model_config.__dict__)
        super().__init__(**kwargs)
class CatsModel(PreTrainedModel):
config_class = CatsConfig
    def __init__(self, config, wrapped_model_pretrained_dir: Optional[str] = None, **kwargs):
super().__init__(config)
transformers_module = importlib.import_module("transformers")
self.wrapped_model_class = getattr(transformers_module, config.wrapped_model_class_name)
self.wrapped_model = self.wrapped_model_class(config)
if wrapped_model_pretrained_dir is not None:
self.wrapped_model = self.wrapped_model_class.from_pretrained(wrapped_model_pretrained_dir)
print(self.__dict__)
self.inject_cats()
def inject_cats(self):
for name, module in self.wrapped_model.named_modules():
parent, target, target_name = _get_submodules(self.wrapped_model, name)
if target_name in self.config.target_modules:
print(f"{name} is replaced.")
# Replace target module with target module + CATS
cats = Cats(wrapped_module=target)
setattr(parent, target_name, cats)
    def enable_collect_stats(self):
        # modules() yields the submodules themselves; named_modules() yields
        # (name, module) tuples, so the isinstance check would never match.
        for module in self.wrapped_model.modules():
            if isinstance(module, Cats):
                module.enable_collect_stats()

    def disable_collect_stats(self) -> None:
        for module in self.wrapped_model.modules():
            if isinstance(module, Cats):
                module.disable_collect_stats()
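    # Sketch of how the collected histograms could be turned into thresholds
    # (an assumed helper, not part of the original class): read the
    # |activation| quantile matching config.target_sparsity off the empirical
    # CDF of abs_hist_counts.
    # def set_thresholds_from_stats(self):
    #     for module in self.wrapped_model.modules():
    #         if isinstance(module, Cats):
    #             counts = module.abs_hist_counts
    #             cdf = torch.cumsum(counts, dim=0) / counts.sum()
    #             idx = int(torch.searchsorted(cdf, self.config.target_sparsity))
    #             idx = min(idx, counts.numel() - 2)  # keep the chosen bin edge finite
    #             module.set_threshold(float(module.histogram_bins[idx + 1]))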
    # def __getattr__(self, name: str):
    #     """Forward missing attributes to the wrapped model."""
    #     try:
    #         return super().__getattr__(name)  # defer to nn.Module's logic
    #     except AttributeError:
    #         return getattr(self.wrapped_model, name)
def simple_exp():
model_dir = MISTRAL_7B
config = AutoConfig.from_pretrained(model_dir)
cats_config = CatsConfig(config, wrapped_model_class_name="MistralForCausalLM")
model = CatsModel(cats_config, wrapped_model_pretrained_dir=None)
print(model)
print(model.wrapped_model)
print(model.config)
CatsConfig.register_for_auto_class()
CatsModel.register_for_auto_class("AutoModelForCausalLM")
repo_id = "thrunlab/cats_exp"
model.push_to_hub(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
if __name__ == "__main__":
simple_exp()