khoicrtp committed on
Commit d417650
1 Parent(s): dbf3b86

Upload 13 files

finetune_lora.py CHANGED
@@ -28,7 +28,7 @@ learning_rate = 3e-4
 batch_size = 128
 micro_batch_size = 4
 gradient_accumulation_steps = batch_size // micro_batch_size
-max_iters = 2 #50000 * 3 // micro_batch_size
+max_iters = 10000 #50000 * 3 // micro_batch_size
 weight_decay = 0.0
 max_seq_length = 256 # see scripts/prepare_alpaca.py
 lora_r = 8
@@ -44,7 +44,7 @@ def main(
 ):
 
     #fabric = L.Fabric(accelerator="cuda", precision="bf16-true")
-    fabric = L.Fabric(accelerator="cpu", devices=2, precision="bf16-true")
+    fabric = L.Fabric(accelerator="cpu", devices=1, precision="bf16-true")
     fabric.launch()
     fabric.seed_everything(1337 + fabric.global_rank)

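For context, a quick back-of-the-envelope sketch (not part of the commit) of how the new max_iters interacts with the gradient-accumulation settings shown above; the optimizer only steps once per accumulation window:

# Sketch only; values taken from the hyperparameters above.
batch_size = 128
micro_batch_size = 4
gradient_accumulation_steps = batch_size // micro_batch_size  # 32 micro-batches per optimizer step
max_iters = 10000                                             # micro-batch iterations (new value)
optimizer_steps = max_iters // gradient_accumulation_steps    # ~312 optimizer updates in total
print(gradient_accumulation_steps, optimizer_steps)
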
finetune_lora_origin.py ADDED
@@ -0,0 +1,212 @@
+"""
+Instruction-tuning with LoRA on the Alpaca dataset.
+
+Note: If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
+`torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
+"""
+import sys
+from pathlib import Path
+import os
+import time
+
+import lightning as L
+import numpy as np
+import torch
+
+# support running without installing as a package
+wd = Path(__file__).parent.parent.resolve()
+sys.path.append(str(wd))
+
+from generate import generate
+from lit_llama.lora import mark_only_lora_as_trainable, lora, lora_state_dict
+from lit_llama.model import LLaMA, LLaMAConfig
+from lit_llama.tokenizer import Tokenizer
+from scripts.prepare_alpaca import generate_prompt
+
+
+eval_interval = 100
+save_interval = 100
+eval_iters = 100
+log_interval = 1
+
+# Hyperparameters
+learning_rate = 3e-4
+batch_size = 128
+micro_batch_size = 4
+gradient_accumulation_steps = batch_size // micro_batch_size
+max_iters = 50000 * 3 // micro_batch_size
+weight_decay = 0.0
+max_seq_length = 256 # see scripts/prepare_alpaca.py
+lora_r = 8
+lora_alpha = 16
+lora_dropout = 0.05
+warmup_steps = 100
+
+
+def main(
+    data_dir: str = "data/alpaca",
+    pretrained_path: str = "checkpoints/lit-llama/7B/lit-llama.pth",
+    out_dir: str = "out/lora/alpaca",
+):
+
+    fabric = L.Fabric(accelerator="cpu", devices=1, precision="bf16-true")
+    # fabric = L.Fabric(accelerator="cuda", devices=1, precision="bf16-true")
+    fabric.launch()
+    fabric.seed_everything(1337 + fabric.global_rank)
+
+    if fabric.global_rank == 0:
+        os.makedirs(out_dir, exist_ok=True)
+
+    train_data, val_data = load_datasets(data_dir=data_dir)
+
+    config = LLaMAConfig.from_name("7B")
+    config.block_size = max_seq_length
+
+    checkpoint = torch.load(pretrained_path)
+
+    with fabric.init_module(), lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
+        model = LLaMA(config)
+        # strict=False because missing keys due to LoRA weights not contained in checkpoint state
+        model.load_state_dict(checkpoint, strict=False)
+
+    mark_only_lora_as_trainable(model)
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+    model, optimizer = fabric.setup(model, optimizer)
+    train(fabric, model, optimizer, train_data, val_data, out_dir)
+
+    # Save the final LoRA checkpoint at the end of training
+    checkpoint = lora_state_dict(model)
+    fabric.save(os.path.join(out_dir, "lit-llama-lora-finetuned.pth"), checkpoint)
+
+
+def train(
+    fabric: L.Fabric,
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer,
+    train_data: np.ndarray,
+    val_data: np.ndarray,
+    out_dir: str,
+) -> None:
+    """The training loop.
+
+    Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
+    """
+    step_count = 0
+
+    for iter_num in range(max_iters):
+
+        if step_count <= warmup_steps:
+            # linear warmup
+            lr = learning_rate * step_count / warmup_steps
+            for param_group in optimizer.param_groups:
+                param_group['lr'] = lr
+
+        t0 = time.time()
+
+        input_ids, targets = get_batch(fabric, train_data)
+        logits = model(input_ids)
+        loss = loss_fn(logits, targets)
+        fabric.backward(loss)
+
+        if (iter_num + 1) % gradient_accumulation_steps == 0:
+            optimizer.step()
+            optimizer.zero_grad()
+            step_count += 1
+
+            if step_count % eval_interval == 0:
+                val_loss = validate(fabric, model, val_data)
+                fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
+                fabric.barrier()
+
+            if step_count % save_interval == 0:
+                print(f"Saving LoRA weights to {out_dir}")
+                # We are only saving the LoRA weights
+                # TODO: Provide a function/script to merge the LoRA weights with pretrained weights
+                checkpoint = lora_state_dict(model)
+                fabric.save(os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"), checkpoint)
+
+        dt = time.time() - t0
+        if iter_num % log_interval == 0:
+            fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
+
+
+def generate_response(model, instruction):
+    tokenizer = Tokenizer("checkpoints/lit-llama/tokenizer.model")
+    sample = {"instruction": instruction, "input": ""}
+    prompt = generate_prompt(sample)
+    encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
+
+    output = generate(
+        model,
+        idx=encoded,
+        max_seq_length=max_seq_length,
+        max_new_tokens=100,
+    )
+    output = tokenizer.decode(output)
+    return output # output.split("### Response:")[1].strip()
+
+
+@torch.no_grad()
+def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor:
+    fabric.print("Validating ...")
+    model.eval()
+    losses = torch.zeros(eval_iters)
+    for k in range(eval_iters):
+        input_ids, targets = get_batch(fabric, val_data)
+        logits = model(input_ids)
+        loss = loss_fn(logits, targets)
+        losses[k] = loss.item()
+    out = losses.mean()
+
+    # produce an example:
+    instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
+
+    output = generate_response(model, instruction)
+    fabric.print(instruction)
+    fabric.print(output)
+
+    model.train()
+    return out.item()
+
+def loss_fn(logits, targets):
+    # shift the targets such that output n predicts token n+1
+    logits = logits[..., :-1, :].contiguous()
+    targets = targets[..., 1:].contiguous()
+    loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+    return loss
+
+
+def get_batch(fabric: L.Fabric, data: list):
+    ix = torch.randint(len(data), (micro_batch_size,))
+
+    input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
+    labels = [data[i]["labels"].type(torch.int64) for i in ix]
+
+    max_len = max(len(s) for s in input_ids)
+
+    def pad_right(x, pad_id):
+        # pad right based on the longest sequence
+        n = max_len - len(x)
+        return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))
+
+    x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
+    y = torch.stack([pad_right(x, pad_id=-1) for x in labels])
+    x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
+    return x, y
+
+
+def load_datasets(data_dir):
+    train_data = torch.load(os.path.join(data_dir, "train.pt"))
+    val_data = torch.load(os.path.join(data_dir, "test.pt"))
+    return train_data, val_data
+
+
+if __name__ == "__main__":
+    # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
+    # torch.backends.cuda.enable_flash_sdp(False)
+    torch.set_float32_matmul_precision("high")
+
+    from jsonargparse.cli import CLI
+
+    CLI(main)
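
As a side note, a minimal standalone sketch (not part of the commit) of how loss_fn and the -1 label padding produced by get_batch fit together; the tensor values below are made up for illustration:

import torch
import torch.nn.functional as F

vocab_size = 8
logits = torch.randn(1, 5, vocab_size)        # (batch, seq, vocab), random for illustration
targets = torch.tensor([[3, 5, 1, -1, -1]])   # right-padded labels, pad_id=-1 as in get_batch

# shift so that position n predicts token n+1, as in loss_fn above
shift_logits = logits[..., :-1, :].contiguous()
shift_targets = targets[..., 1:].contiguous()
loss = F.cross_entropy(
    shift_logits.view(-1, vocab_size),
    shift_targets.view(-1),
    ignore_index=-1,  # padded positions are excluded from the loss
)
print(loss.item())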
scripts/__pycache__/prepare_alpaca.cpython-311.pyc CHANGED
Binary files a/scripts/__pycache__/prepare_alpaca.cpython-311.pyc and b/scripts/__pycache__/prepare_alpaca.cpython-311.pyc differ
 
scripts/prepare_alpaca.py CHANGED
@@ -22,8 +22,8 @@ IGNORE_INDEX = -1
 def prepare(
     destination_path: Path = Path("data/alpaca"),
     tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
-    #test_split_size: int = 2000,
-    test_split_size: int = 2,
+    test_split_size: int = 200,
+    #test_split_size: int = 2,
     max_seq_length: int = 256,
     seed: int = 42,
     mask_inputs: bool = False, # as in alpaca-lora
scripts/prepare_alpaca_origin.py ADDED
@@ -0,0 +1,130 @@
+"""Implementation derived from https://github.com/tloen/alpaca-lora"""
+import sys
+from pathlib import Path
+
+# support running without installing as a package
+wd = Path(__file__).parent.parent.resolve()
+sys.path.append(str(wd))
+
+import torch
+import requests
+import json
+from torch.utils.data import random_split
+from lit_llama.tokenizer import Tokenizer
+from tqdm import tqdm
+
+
+DATA_FILE = "https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json"
+DATA_FILE_NAME = "alpaca_data_cleaned_archive.json"
+IGNORE_INDEX = -1
+
+
+def prepare(
+    destination_path: Path = Path("data/alpaca"),
+    tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
+    test_split_size: int = 2000,
+    max_seq_length: int = 256,
+    seed: int = 42,
+    mask_inputs: bool = False, # as in alpaca-lora
+    data_file_name: str = DATA_FILE_NAME
+) -> None:
+    """Prepare the Alpaca dataset for instruction tuning.
+
+    The output is a training and validation dataset saved as `train.pt` and `val.pt`,
+    which stores the preprocessed and tokenized prompts and labels.
+    """
+
+    destination_path.mkdir(parents=True, exist_ok=True)
+    file_path = destination_path / data_file_name
+    download(file_path)
+
+    # TODO: If we don't have the Meta weights, where do we get the tokenizer from?
+    tokenizer = Tokenizer(tokenizer_path)
+
+    with open(file_path, "r") as file:
+        data = json.load(file)
+
+    # Partition the dataset into train and test
+    train_split_size = len(data) - test_split_size
+    train_set, test_set = random_split(
+        data,
+        lengths=(train_split_size, test_split_size),
+        generator=torch.Generator().manual_seed(seed),
+    )
+    train_set, test_set = list(train_set), list(test_set)
+
+    print(f"train has {len(train_set):,} samples")
+    print(f"val has {len(test_set):,} samples")
+
+    print("Processing train split ...")
+    train_set = [prepare_sample(sample, tokenizer, max_seq_length, mask_inputs) for sample in tqdm(train_set)]
+    torch.save(train_set, file_path.parent / "train.pt")
+
+    print("Processing test split ...")
+    test_set = [prepare_sample(sample, tokenizer, max_seq_length, mask_inputs) for sample in tqdm(test_set)]
+    torch.save(test_set, file_path.parent / "test.pt")
+
+
+def download(file_path: Path):
+    """Downloads the raw json data file and saves it in the given destination."""
+    if file_path.exists():
+        return
+    with open(file_path, "w") as f:
+        f.write(requests.get(DATA_FILE).text)
+
+
+def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool = True):
+    """Processes a single sample.
+
+    Each sample in the dataset consists of:
+    - instruction: A string describing the task
+    - input: A string holding a special input value for the instruction.
+        This only applies to some samples, and in others this is empty.
+    - output: The response string
+
+    This function processes this data to produce a prompt text and a label for
+    supervised training. The prompt text is formed as a single message including both
+    the instruction and the input. The label/target is the same message but with the
+    response attached.
+
+    Finally, both the prompt and the label get tokenized. If desired, all tokens
+    in the label that correspond to the original input prompt get masked out (default).
+    """
+    full_prompt = generate_prompt(example)
+    full_prompt_and_response = full_prompt + example["output"]
+    encoded_full_prompt = tokenize(tokenizer, full_prompt, max_length=max_length, eos=False)
+    encoded_full_prompt_and_response = tokenize(tokenizer, full_prompt_and_response, eos=True, max_length=max_length)
+
+    # The labels are the full prompt with response, but with the prompt masked out
+    labels = encoded_full_prompt_and_response.clone()
+    if mask_inputs:
+        labels[:len(encoded_full_prompt)] = IGNORE_INDEX
+
+    return {**example, "input_ids": encoded_full_prompt_and_response, "input_ids_no_response": encoded_full_prompt, "labels": labels}
+
+
+def tokenize(tokenizer: Tokenizer, string: str, max_length: int, eos=True) -> torch.Tensor:
+    return tokenizer.encode(string, bos=True, eos=eos, max_length=max_length)
+
+
+def generate_prompt(example):
+    """Generates a standardized message to prompt the model with an instruction, optional input and a
+    'response' field."""
+
+    if example["input"]:
+        return (
+            "Below is an instruction that describes a task, paired with an input that provides further context. "
+            "Write a response that appropriately completes the request.\n\n"
+            f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:"
+        )
+    return (
+        "Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request.\n\n"
+        f"### Instruction:\n{example['instruction']}\n\n### Response:"
+    )
+
+
+if __name__ == "__main__":
+    from jsonargparse import CLI
+
+    CLI(prepare)
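
Finally, a minimal standalone sketch (not part of the commit) of the label masking that prepare_sample applies when mask_inputs=True; the token ids below are hypothetical:

import torch

IGNORE_INDEX = -1
encoded_full_prompt = torch.tensor([1, 407, 312, 9])                          # hypothetical prompt tokens
encoded_full_prompt_and_response = torch.tensor([1, 407, 312, 9, 88, 23, 2])  # prompt + response + eos

labels = encoded_full_prompt_and_response.clone()
labels[:len(encoded_full_prompt)] = IGNORE_INDEX  # only response tokens contribute to the loss
print(labels)  # tensor([-1, -1, -1, -1, 88, 23, 2])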