import importlib
import json
import os
from typing import List, Optional

import torch
import torch.nn as nn
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    AutoConfig,
    AutoModelForCausalLM,
)

from utils.constants import MISTRAL_7B
from utils.utils import _get_submodules

class Cats(nn.Module):
    """Wraps an activation module and zeroes every activation whose magnitude
    falls below `threshold` (CATS). While `collect_stats` is on, histograms of
    the raw and absolute activation values are accumulated so a threshold can
    later be chosen for a target sparsity."""

    def __init__(
        self,
        wrapped_module: nn.Module,
        threshold: float = 0,
        hist_num_bins: int = 1000,
        hist_min: float = -1,
        hist_max: float = 1,
    ):
        super().__init__()
        self.wrapped_module = wrapped_module
        self.threshold = nn.Parameter(
            torch.tensor(float(threshold)), requires_grad=False
        )
        # hist_num_bins - 2 finite edges plus the two infinite end edges give
        # hist_num_bins edges in total, i.e. hist_num_bins - 1 bins.
        self.histogram_bins = torch.linspace(hist_min, hist_max, hist_num_bins - 2)
        self.histogram_bins = torch.cat(
            [torch.tensor([-torch.inf]), self.histogram_bins, torch.tensor([torch.inf])]
        )
        self.hist_counts = torch.zeros(hist_num_bins - 1)
        self.abs_hist_counts = torch.zeros(hist_num_bins - 1)
        self.collect_stats = True

    def disable_collect_stats(self):
        self.collect_stats = False

    def enable_collect_stats(self):
        self.collect_stats = True

    def set_threshold(self, threshold: float):
        self.threshold = nn.Parameter(
            torch.tensor(float(threshold)), requires_grad=False
        )

    def forward(self, x):
        x = self.wrapped_module(x)
        if self.collect_stats:
            # torch.histogram only supports CPU tensors, so detach the
            # activations and move them off-device before binning.
            vals = x.detach().float().cpu()
            self.hist_counts += torch.histogram(vals, bins=self.histogram_bins)[0]
            self.abs_hist_counts += torch.histogram(
                torch.abs(vals), bins=self.histogram_bins
            )[0]
        # Zero sub-threshold activations out-of-place so autograd stays intact.
        x = torch.where(torch.abs(x) < self.threshold, torch.zeros_like(x), x)
        return x
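

# ---------------------------------------------------------------------------
# Illustrative helper, not part of the original CATS code: pick a threshold
# from the |activation| histogram a Cats module has collected, so that
# roughly `target_sparsity` of the activations fall below it and get zeroed.
# Assumes stats were gathered beforehand with collect_stats enabled.
# ---------------------------------------------------------------------------
def compute_threshold_for_sparsity(cats: Cats, target_sparsity: float) -> float:
    counts = cats.abs_hist_counts
    total = counts.sum()
    if total == 0:
        return 0.0  # no stats collected yet; keep everything dense
    cdf = torch.cumsum(counts, dim=0) / total
    # Index of the first bin whose cumulative mass reaches the target sparsity.
    idx = int(torch.searchsorted(cdf, torch.tensor(target_sparsity)))
    # Use that bin's right edge as the cutoff, clamped away from the +/-inf
    # end edges so the returned threshold is always finite.
    idx = min(idx + 1, cats.histogram_bins.numel() - 2)
    return float(cats.histogram_bins[idx])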


# Function to load existing data from a JSON file
def load_data(file_path):
    try:
        with open(file_path, "r") as json_file:
            return json.load(json_file)
    except FileNotFoundError:
        return {}  # Return an empty dictionary if the file does not exist


# Function to save the dictionary to a JSON file
def save_to_json(data, file_path):
    dir_name = os.path.dirname(file_path)
    if dir_name:  # os.makedirs("") raises for bare filenames
        os.makedirs(dir_name, exist_ok=True)
    with open(file_path, "w") as json_file:
        json.dump(data, json_file, indent=4)


class CatsConfig(PretrainedConfig):
    model_type = "cats_model"

    def __init__(
        self,
        wrapped_model_config: Optional[PretrainedConfig] = None,
        wrapped_model_class_name: str = "MistralForCausalLM",
        target_modules: Optional[List[str]] = None,
        target_sparsity: float = 0.5,
        **kwargs,
    ):
        # Resolve defaults lazily so importing this module never triggers a
        # hub download and no mutable default is shared across instances.
        if wrapped_model_config is None:
            wrapped_model_config = AutoConfig.from_pretrained(MISTRAL_7B)
        if target_modules is None:
            target_modules = ["act_fn"]
        self.target_modules = target_modules
        self.target_sparsity = target_sparsity
        self.wrapped_model_class_name = wrapped_model_class_name
        self.__dict__.update(wrapped_model_config.__dict__)
        super().__init__(**kwargs)


class CatsModel(PreTrainedModel):
    config_class = CatsConfig

    def __init__(
        self, config, wrapped_model_pretrained_dir: Optional[str] = None, **kwargs
    ):
        super().__init__(config)
        transformers_module = importlib.import_module("transformers")
        self.wrapped_model_class = getattr(
            transformers_module, config.wrapped_model_class_name
        )
        # Load pretrained weights when a checkpoint directory is given;
        # otherwise initialize the wrapped model from the config alone.
        if wrapped_model_pretrained_dir is not None:
            self.wrapped_model = self.wrapped_model_class.from_pretrained(
                wrapped_model_pretrained_dir
            )
        else:
            self.wrapped_model = self.wrapped_model_class(config)
        self.inject_cats()

    def inject_cats(self):
        # Materialize the module list first: swapping modules while iterating
        # over the live named_modules() generator is unsafe.
        for name, _ in list(self.wrapped_model.named_modules()):
            parent, target, target_name = _get_submodules(self.wrapped_model, name)
            if target_name in self.config.target_modules:
                print(f"{name} is replaced.")
                # Replace the target module with target module + CATS.
                cats = Cats(wrapped_module=target)
                setattr(parent, target_name, cats)

    def enable_collect_stats(self):
        for module in self.wrapped_model.modules():
            if isinstance(module, Cats):
                module.enable_collect_stats()

    def disable_collect_stats(self) -> None:
        for module in self.wrapped_model.modules():
            if isinstance(module, Cats):
                module.disable_collect_stats()

    # def __getattr__(self, name: str):
    #     """Forward missing attributes to the wrapped model."""
    #     try:
    #         return super().__getattr__(name)  # defer to nn.Module's logic
    #     except AttributeError:
    #         return getattr(self.wrapped_model, name)
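

# ---------------------------------------------------------------------------
# Illustrative calibration loop, not part of the original file: after running
# a few batches with stats collection on, freeze each Cats threshold at the
# value matching config.target_sparsity, then stop collecting.
# ---------------------------------------------------------------------------
def calibrate_thresholds(model: CatsModel) -> None:
    for module in model.wrapped_model.modules():
        if isinstance(module, Cats):
            module.set_threshold(
                compute_threshold_for_sparsity(module, model.config.target_sparsity)
            )
            module.disable_collect_stats()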


def simple_exp():
    model_dir = MISTRAL_7B
    config = AutoConfig.from_pretrained(model_dir)
    cats_config = CatsConfig(config, wrapped_model_class_name="MistralForCausalLM")
    model = CatsModel(cats_config, wrapped_model_pretrained_dir=None)
    print(model)
    print(model.wrapped_model)
    print(model.config)

    CatsConfig.register_for_auto_class()
    CatsModel.register_for_auto_class("AutoModelForCausalLM")

    repo_id = "thrunlab/cats_exp"
    model.push_to_hub(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)



if __name__ == "__main__":
    simple_exp()