Nanobit committed on
Commit 96e8378 · 1 Parent(s): e9650d3

Delete extract_lora.py

Files changed (1)
scripts/extract_lora.py +0 -163
scripts/extract_lora.py DELETED
@@ -1,163 +0,0 @@
- # import logging
- # import os
- # import random
- # import signal
- # import sys
- # from pathlib import Path
-
- # import fire
- # import torch
- # import yaml
- # from addict import Dict
-
- # from peft import set_peft_model_state_dict, get_peft_model_state_dict
-
- # # add src to the pythonpath so we don't need to pip install this
- # project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
- # src_dir = os.path.join(project_root, "src")
- # sys.path.insert(0, src_dir)
-
- # from axolotl.utils.data import load_prepare_datasets
- # from axolotl.utils.models import load_model
- # from axolotl.utils.trainer import setup_trainer
- # from axolotl.utils.wandb import setup_wandb_env_vars
-
- # logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
-
-
- # def choose_device(cfg):
- #     def get_device():
- #         if torch.cuda.is_available():
- #             return "cuda"
- #         else:
- #             try:
- #                 if torch.backends.mps.is_available():
- #                     return "mps"
- #             except:
- #                 return "cpu"
-
- #     cfg.device = get_device()
- #     if cfg.device == "cuda":
- #         cfg.device_map = {"": cfg.local_rank}
- #     else:
- #         cfg.device_map = {"": cfg.device}
-
-
- # def choose_config(path: Path):
- #     yaml_files = [file for file in path.glob("*.yml")]
-
- #     if not yaml_files:
- #         raise ValueError(
- #             "No YAML config files found in the specified directory. Are you using a .yml extension?"
- #         )
-
- #     print("Choose a YAML file:")
- #     for idx, file in enumerate(yaml_files):
- #         print(f"{idx + 1}. {file}")
-
- #     chosen_file = None
- #     while chosen_file is None:
- #         try:
- #             choice = int(input("Enter the number of your choice: "))
- #             if 1 <= choice <= len(yaml_files):
- #                 chosen_file = yaml_files[choice - 1]
- #             else:
- #                 print("Invalid choice. Please choose a number from the list.")
- #         except ValueError:
- #             print("Invalid input. Please enter a number.")
-
- #     return chosen_file
-
-
- # def save_latest_checkpoint_as_lora(
- #     config: Path = Path("configs/"),
- #     prepare_ds_only: bool = False,
- #     **kwargs,
- # ):
- #     if Path(config).is_dir():
- #         config = choose_config(config)
-
- #     # load the config from the yaml file
- #     with open(config, "r") as f:
- #         cfg: Dict = Dict(lambda: None, yaml.load(f, Loader=yaml.Loader))
- #     # if there are any options passed in the cli, if it is something that seems valid from the yaml,
- #     # then overwrite the value
- #     cfg_keys = dict(cfg).keys()
- #     for k in kwargs:
- #         if k in cfg_keys:
- #             # handle booleans
- #             if isinstance(cfg[k], bool):
- #                 cfg[k] = bool(kwargs[k])
- #             else:
- #                 cfg[k] = kwargs[k]
-
- #     # setup some derived config / hyperparams
- #     cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
- #     cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
- #     cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
- #     assert cfg.local_rank == 0, "Run this with only one device!"
-
- #     choose_device(cfg)
- #     cfg.ddp = False
-
- #     if cfg.device == "mps":
- #         cfg.load_in_8bit = False
- #         cfg.tf32 = False
- #         if cfg.bf16:
- #             cfg.fp16 = True
- #         cfg.bf16 = False
-
- #     # Load the model and tokenizer
- #     logging.info("loading model, tokenizer, and lora_config...")
- #     model, tokenizer, lora_config = load_model(
- #         cfg.base_model,
- #         cfg.base_model_config,
- #         cfg.model_type,
- #         cfg.tokenizer_type,
- #         cfg,
- #         adapter=cfg.adapter,
- #         inference=True,
- #     )
-
- #     model.config.use_cache = False
-
- #     if torch.__version__ >= "2" and sys.platform != "win32":
- #         logging.info("Compiling torch model")
- #         model = torch.compile(model)
-
- #     possible_checkpoints = [str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")]
- #     if len(possible_checkpoints) > 0:
- #         sorted_paths = sorted(
- #             possible_checkpoints, key=lambda path: int(path.split("-")[-1])
- #         )
- #         resume_from_checkpoint = sorted_paths[-1]
- #     else:
- #         raise FileNotFoundError("Checkpoints folder not found")
-
- #     pytorch_bin_path = os.path.join(resume_from_checkpoint, "pytorch_model.bin")
-
- #     assert os.path.exists(pytorch_bin_path), "Bin not found"
-
- #     logging.info(f"Loading {pytorch_bin_path}")
- #     adapters_weights = torch.load(pytorch_bin_path, map_location="cpu")
-
- #     # d = get_peft_model_state_dict(model)
- #     print(model.load_state_dict(adapters_weights))
- #     # with open('b.log', "w") as f:
- #     #     f.write(str(d.keys()))
- #     assert False
-
- #     print((adapters_weights.keys()))
- #     with open("a.log", "w") as f:
- #         f.write(str(adapters_weights.keys()))
- #     assert False
-
- #     logging.info("Setting peft model state dict")
- #     set_peft_model_state_dict(model, adapters_weights)
-
- #     logging.info(f"Set Completed!!! Saving pre-trained model to {cfg.output_dir}")
- #     model.save_pretrained(cfg.output_dir)
-
-
- # if __name__ == "__main__":
- #     fire.Fire(save_latest_checkpoint_as_lora)