cats_exp / cats.py
vxbrandon's picture
Upload model
ec539e9 verified
raw
history blame
5.15 kB
import importlib
import json
import os
from typing import List
import numpy as np
import torch
import torch.nn as nn
from transformers import (
PretrainedConfig,
PreTrainedModel,
AutoConfig, AutoModelForCausalLM,
)
from utils.constants import MISTRAL_7B
from utils.utils import _get_submodules
class Cats(nn.Module):
    """Wrap a module, zero small-magnitude outputs, and histogram the outputs.

    The wrapped module is called first. While ``collect_stats`` is true, its
    output (before thresholding) is accumulated into two running histograms:
    one of the raw values and one of their absolute values. Entries whose
    absolute value is strictly below ``threshold`` are then replaced by zero.

    Args:
        wrapped_module: Module whose output is thresholded.
        threshold: Magnitude below which outputs are zeroed.
        hist_num_bins: Number of histogram bin *edges*, including the two
            infinite outer edges; the histograms have ``hist_num_bins - 1``
            bins.
        hist_min: Smallest finite bin edge.
        hist_max: Largest finite bin edge.
    """

    def __init__(
        self,
        wrapped_module: nn.Module,
        threshold: float = 0,
        hist_num_bins: int = 1000,
        hist_min: int = -1,
        hist_max: int = 1,
    ):
        super().__init__()
        self.wrapped_module = wrapped_module
        # float(): the default ``0`` would otherwise create an int64 parameter.
        self.threshold = nn.Parameter(
            torch.tensor(float(threshold)), requires_grad=False
        )
        # Finite edges are padded with +/-inf so out-of-range values are
        # still counted (in the first / last bin).
        finite_edges = torch.linspace(hist_min, hist_max, hist_num_bins - 2)
        self.histogram_bins = torch.cat(
            [torch.tensor([-torch.inf]), finite_edges, torch.tensor([torch.inf])]
        )
        self.hist_counts = torch.zeros(hist_num_bins - 1)
        self.abs_hist_counts = torch.zeros(hist_num_bins - 1)
        self.collect_stats = True

    def disable_collect_stats(self):
        """Stop accumulating histograms in ``forward``."""
        self.collect_stats = False

    def enable_collect_stats(self):
        """Resume accumulating histograms in ``forward``."""
        self.collect_stats = True

    def set_threshold(self, threshold: float):
        """Replace the threshold, keeping it on the current parameter's device."""
        self.threshold = nn.Parameter(
            torch.tensor(float(threshold), device=self.threshold.device),
            requires_grad=False,
        )

    def forward(self, x):
        """Apply the wrapped module, update statistics, and threshold the result."""
        x = self.wrapped_module(x)
        if self.collect_stats:
            # The bin edges and counters are plain tensors that stay where
            # they were created, and torch.histogram requires the input to
            # share the bins' dtype, so take a detached copy that matches
            # the bins before counting.
            sample = x.detach().to(
                device=self.histogram_bins.device,
                dtype=self.histogram_bins.dtype,
            )
            self.hist_counts += torch.histogram(sample, bins=self.histogram_bins)[0]
            self.abs_hist_counts += torch.histogram(
                sample.abs(), bins=self.histogram_bins
            )[0]
        # Out-of-place so the wrapped module's output (which may be the
        # caller's own tensor) is never mutated.
        return x.masked_fill(x.abs() < self.threshold, 0)
# Read previously saved JSON data, tolerating a missing file
def load_data(file_path):
    """Return the parsed JSON content of *file_path*, or ``{}`` if it is absent."""
    try:
        handle = open(file_path, "r")
    except FileNotFoundError:
        # Nothing has been saved yet: start from an empty dictionary.
        return {}
    with handle:
        return json.load(handle)
# Function to save the dictionary to a JSON file
def save_to_json(data, file_path):
    """Write *data* to *file_path* as indented JSON, creating parent directories.

    Args:
        data: Any JSON-serialisable object.
        file_path: Destination path. May be a bare filename, in which case the
            file is written relative to the current working directory.
    """
    parent_dir = os.path.dirname(file_path)
    # dirname is "" for a bare filename, and os.makedirs("") raises.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(file_path, "w") as json_file:
        json.dump(data, json_file, indent=4)
class CatsConfig(PretrainedConfig):
    """Configuration for ``CatsModel``.

    Every attribute of the wrapped model's configuration is copied onto this
    object, alongside the CATS-specific fields below.

    Args:
        wrapped_model_config: Configuration of the model to wrap. When
            omitted, the ``MISTRAL_7B`` configuration is loaded on demand.
        wrapped_model_class_name: Name of the ``transformers`` class used to
            build the wrapped model.
        target_modules: Names of the submodules to wrap; defaults to
            ``["act_fn"]``.
        target_sparsity: Stored on the config as ``target_sparsity``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "cats_model"

    def __init__(
        self,
        wrapped_model_config=None,
        wrapped_model_class_name: str = "MistralForCausalLM",
        target_modules: List[str] = None,
        target_sparsity: float = 0.5,
        **kwargs,
    ):
        # Resolve defaults lazily: a default evaluated in the signature would
        # load the base config at import time and share one mutable list
        # between every instance.
        if wrapped_model_config is None:
            wrapped_model_config = AutoConfig.from_pretrained(MISTRAL_7B)
        if target_modules is None:
            target_modules = ["act_fn"]

        self.target_modules = target_modules
        self.target_sparsity = target_sparsity
        self.wrapped_model_class_name = wrapped_model_class_name
        # Mirror the wrapped config's fields so this object can be handed
        # directly to the wrapped model class.
        self.__dict__.update(wrapped_model_config.__dict__)
        super().__init__(**kwargs)
class CatsModel(PreTrainedModel):
    """A ``transformers`` model whose target submodules are wrapped in ``Cats``.

    The wrapped model class is looked up by name in the ``transformers``
    package, instantiated (from the config, or from a pretrained directory),
    and every submodule whose name is listed in ``config.target_modules`` is
    replaced by a ``Cats`` wrapper around it.
    """

    config_class = CatsConfig

    def __init__(self, config, wrapped_model_pretrained_dir: str = None, **kwargs):
        """Build the wrapped model and inject the ``Cats`` wrappers.

        Args:
            config: A ``CatsConfig``.
            wrapped_model_pretrained_dir: If given, the wrapped model is
                loaded with ``from_pretrained`` from this location instead of
                being constructed from ``config``.
            **kwargs: Accepted for interface compatibility; unused.
        """
        super().__init__(config)
        transformers_module = importlib.import_module("transformers")
        self.wrapped_model_class = getattr(
            transformers_module, config.wrapped_model_class_name
        )
        # Build exactly one model: constructing from config and then
        # overwriting it with the pretrained one wastes a full allocation.
        if wrapped_model_pretrained_dir is not None:
            self.wrapped_model = self.wrapped_model_class.from_pretrained(
                wrapped_model_pretrained_dir
            )
        else:
            self.wrapped_model = self.wrapped_model_class(config)
        self.inject_cats()

    def inject_cats(self):
        """Replace each target submodule with a ``Cats`` wrapper around it."""
        # Snapshot the module list first: the loop swaps children in place,
        # which must not happen under a live named_modules() generator.
        for name, module in list(self.wrapped_model.named_modules()):
            if isinstance(module, Cats):
                # Already wrapped (e.g. inject_cats called twice).
                continue
            parent, target, target_name = _get_submodules(self.wrapped_model, name)
            if target_name in self.config.target_modules:
                print(f"{name} is replaced.")
                # Replace target module with target module + CATS
                setattr(parent, target_name, Cats(wrapped_module=target))

    def forward(self, *args, **kwargs):
        """Delegate the forward pass to the wrapped model."""
        return self.wrapped_model(*args, **kwargs)

    def enable_collect_stats(self):
        """Turn histogram collection on for every ``Cats`` wrapper."""
        # modules() yields modules; named_modules() yields (name, module)
        # tuples, which never pass the isinstance check.
        for module in self.wrapped_model.modules():
            if isinstance(module, Cats):
                module.enable_collect_stats()

    def disable_collect_stats(self) -> None:
        """Turn histogram collection off for every ``Cats`` wrapper."""
        for module in self.wrapped_model.modules():
            if isinstance(module, Cats):
                module.disable_collect_stats()

    def disable_adapters(self) -> None:
        """Backward-compatible alias of ``disable_collect_stats``."""
        self.disable_collect_stats()
def simple_exp():
    """Build a CATS-wrapped model, push it to the Hub, and load it back.

    The wrapped model is constructed from the ``MISTRAL_7B`` configuration
    (no pretrained directory is passed), both classes are registered for
    auto-class loading, and the result is uploaded to ``thrunlab/cats_exp``
    and then re-loaded through ``AutoModelForCausalLM``.
    """
    base_config = AutoConfig.from_pretrained(MISTRAL_7B)
    model = CatsModel(
        CatsConfig(base_config, wrapped_model_class_name="MistralForCausalLM"),
        wrapped_model_pretrained_dir=None,
    )
    # Show the wrapper, the wrapped model, and the merged configuration.
    for shown in (model, model.wrapped_model, model.config):
        print(shown)

    # Register the custom classes so the Hub copy is loadable via Auto*.
    CatsConfig.register_for_auto_class()
    CatsModel.register_for_auto_class("AutoModelForCausalLM")

    repo_id = "thrunlab/cats_exp"
    model.push_to_hub(repo_id)
    # Round-trip check: the uploaded model must load with remote code enabled.
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
# Run the push-and-reload experiment only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    simple_exp()