alfraser committed on
Commit 08aae17 · 1 Parent(s): f87d5b8

Adding the fine-tuning offline script

Files changed (1)
  1. src/training/prep_finetuning.py +252 -0
src/training/prep_finetuning.py ADDED
@@ -0,0 +1,252 @@
+ """
+ Not used as part of the Streamlit app; run offline to prepare the training data for fine-tuning.
+ """
+
+ import argparse
+ import os
+ import stat
+ import pandas as pd
+
+ from abc import ABC, abstractmethod
+ from copy import copy
+ from random import choice, shuffle
+ from time import time
+ from typing import Tuple, Generator
+
+ from src.common import data_dir, join_items_comma_and, random_true_false, pop_n
+ from src.datatypes import *  # provides Category, Product, Feature and DataLoader used below
+
+ parser = argparse.ArgumentParser(prog="prep_finetuning",
+                                  description="Prepare training data and a launch script to fine-tune a Llama 2 model and push it to the Hugging Face hub for serving")
+ parser.add_argument('-base_model', required=True, help="The base model to fine-tune")
+ parser.add_argument('-products_db', required=True, help="The products sqlite database to train on")
+ parser.add_argument('-fine_tuned_model', required=True, help="The target model name in the Hugging Face hub")
+ parser.add_argument('-hf_user', required=True, help="The Hugging Face user to write the model to the hub")
+ parser.add_argument('-hf_token', required=True, help="The Hugging Face token used to write the model to the hub")
+ args = parser.parse_args()
+
+
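All five flags are required. Assuming the script is run as a module from the repository root (so the src. imports resolve), an invocation might look like the following, with placeholder values for the user-specific arguments:

    python -m src.training.prep_finetuning \
        -base_model meta-llama/Llama-2-7b-chat-hf \
        -products_db products.sqlite \
        -fine_tuned_model my-electrohome-llama \
        -hf_user <your-hf-username> \
        -hf_token <your-hf-write-token>
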
+ class TrainingDataGenerator(ABC):
+     """
+     Abstract class to generate fine-tuning training data. Implemented
+     as a generator to minimise passing around large lists unnecessarily.
+     """
+     @abstractmethod
+     def generate(self) -> Generator[Tuple[str, str], None, None]:
+         """
+         Required to be implemented by each concrete generator.
+         :return: should yield pairs of training data as (question, answer)
+         """
+         pass
+
+
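The generators below lean on two helpers imported from src.common whose definitions are not part of this file. Judging from their call sites they plausibly behave like this sketch; an assumption for illustration, not the committed implementation:

    def pop_n(items: list, n: int) -> list:
        # Assumed: remove and return up to n items from the front of the (already shuffled) list
        popped = items[:n]
        del items[:n]
        return popped


    def join_items_comma_and(items: list) -> str:
        # Assumed: ["a"] -> "a", ["a", "b"] -> "a and b", ["a", "b", "c"] -> "a, b and c"
        if len(items) == 1:
            return items[0]
        return ", ".join(items[:-1]) + " and " + items[-1]
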
+ class CategoryDataGenerator(TrainingDataGenerator):
+     """
+     Concrete implementation to build training data about product categories
+     """
+     def generate(self) -> Generator[Tuple[str, str], None, None]:
+         # 1. First build "what do you offer" type queries
+         cat_names = [c.name for c in Category.all.values()]
+         category_synonyms = ["types", "categories", "kinds"]
+         for _ in range(5):
+             shuffle(cat_names)
+             cat = choice(category_synonyms)
+             q = f"What {cat} of products do you offer?"
+             a = f"ElectroHome offers {join_items_comma_and(cat_names)}."
+             yield q, a
+
+         # 2. Now build some "product in category" type queries
+         for c in Category.all.values():
+             prod_names = [p.name for p in c.products]
+             total_prod_count = len(prod_names)
+             for _ in range(2):
+                 working_prod_names = copy(prod_names)
+                 shuffle(working_prod_names)
+                 while len(working_prod_names) > 0:
+                     ans_product_names = pop_n(working_prod_names, 3)
+                     q = f"What {c.name} do you have?"
+                     # c.name[:-1] crudely singularises the plural category name by dropping the trailing 's'
+                     a = (f"We have {total_prod_count} {c.name}. For example {join_items_comma_and(ans_product_names)}. "
+                          f"What are you looking for in a {c.name[:-1]}, so I can help guide you?")
+                     yield q, a
+
+
+ class ProductDescriptionDataGenerator(TrainingDataGenerator):
+     """
+     Concrete implementation to build training data from the marketing description and price
+     """
+     def generate(self) -> Generator[Tuple[str, str], None, None]:
+         for p in Product.all.values():
+             question_templates = [
+                 "Tell me about the #.",
+                 "Describe the # for me.",
+                 "What can you tell me about the #?"
+             ]
+             q = choice(question_templates).replace("#", p.name)
+
+             # Mix up the answers to include the price or not, and at the start or end of the response
+             if random_true_false():
+                 a = p.description
+             elif random_true_false():
+                 a = f"{p.description} It typically retails for ${p.price}."
+             else:
+                 a = f"The {p.name} typically retails for ${p.price}. {p.description}"
+
+             yield q, a
+
+
+ class PriceDataGenerator(TrainingDataGenerator):
+     """
+     Concrete implementation to build training data just for pricing
+     """
+     def generate(self) -> Generator[Tuple[str, str], None, None]:
+         for p in Product.all.values():
+             question_templates = [
+                 "How much is the #?",
+                 "What does the # cost?",
+                 "What's the price of the #?"
+             ]
+             q = choice(question_templates).replace("#", p.name)
+
+             answer_templates = [
+                 "It typically retails for $#.",
+                 "Our recommended retail price is $#, but do check with your stockist.",
+                 "The list price is $# but check in our online store if there are any offers on (www.electrohome.com)."
+             ]
+             a = choice(answer_templates).replace("#", str(p.price))
+
+             yield q, a
+
+
+ class FeatureDataGenerator(TrainingDataGenerator):
+     """
+     Concrete implementation to build training data just for features
+     """
+     def generate(self) -> Generator[Tuple[str, str], None, None]:
+         # 1. First generate Q&A for what features are available by category
+         for c in Category.all.values():
+             cat_features = [f.name for f in c.features]
+             for _ in range(2):
+                 working_cat_features = copy(cat_features)
+                 shuffle(working_cat_features)
+                 while len(working_cat_features) > 0:
+                     some_features = pop_n(working_cat_features, 3)
+                     feature_clause = join_items_comma_and(some_features)
+
+                     question_templates = [
+                         "What features do your # have?",
+                         "What should I think about when I'm considering #?",
+                         "What sort of things differentiate your #?"
+                     ]
+                     q = choice(question_templates).replace("#", c.name)
+
+                     answer_templates = [
+                         "Our # have features like ##.",
+                         "You might want to consider things like ## which # offer.",
+                         "# have lots of different features, like ##."
+                     ]
+                     # Replace the longer placeholder first so "##" is not corrupted by the "#" pass
+                     a = choice(answer_templates).replace("##", feature_clause).replace("#", c.name)
+
+                     yield q, a
+
+         # 2. Now generate questions the other way around - i.e. search products by feature
+         for f in Feature.all.values():
+             cat_name = f.category.name
+             prod_names = [p.name for p in f.products]
+
+             for _ in range(2):
+                 working_prod_names = copy(prod_names)
+                 shuffle(working_prod_names)  # mirror pass 1 so the two passes group products differently
+                 while len(working_prod_names) > 0:
+                     some_prods = pop_n(working_prod_names, 3)
+                     if len(some_prods) > 1:  # Single-product examples mess up some training data
+                         prod_clause = join_items_comma_and(some_prods)
+
+                         q = f"Which {cat_name} offer {f.name}?"
+                         answer_templates = [
+                             "## are # which offer ###.",
+                             "## have ###.",
+                             "We have some great # which offer ### including our ##."
+                         ]
+                         a = choice(answer_templates).replace("###", f.name).replace("##", prod_clause).replace("#", cat_name)
+
+                         yield q, a
+                     else:
+                         q = f"Which {cat_name} offer {f.name}?"
+                         a = f"The {some_prods[0]} has {f.name}."
+                         yield q, a
+
+
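One load-bearing detail in the template handling above: the placeholders must be substituted longest-first ("###", then "##", then "#"), because replacing "#" first would corrupt the longer tokens. A quick illustration with hypothetical product names:

    template = "## are # which offer ###."
    # Wrong order: replacing "#" first mangles the longer placeholders
    template.replace("#", "fridges")
    # -> "fridgesfridges are fridges which offer fridgesfridgesfridges."
    # Correct order, as the code above does it:
    template.replace("###", "ice makers").replace("##", "the FrostKing and ChillMaster").replace("#", "fridges")
    # -> "the FrostKing and ChillMaster are fridges which offer ice makers."
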
+ def training_string_from_q_and_a(q: str, a: str, sys_prompt: str = None) -> str:
+     """
+     Build a single Llama-formatted training string from a question/answer pair
+     """
+     if sys_prompt is None:
+         sys_prompt = "You are a helpful domestic appliance advisor for the ElectroHome company. Please answer customer questions and do not mention other brands. If you cannot answer please say so."
+     return f"[INST] <<SYS>>\n{sys_prompt}\n<</SYS>>\n\n{q} [/INST] {a}"
+
+
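This follows the Llama 2 chat convention of a <<SYS>> block nested inside the [INST] turn, with the answer after [/INST] as the completion to learn. With the default system prompt shortened and a hypothetical product and price, one training row renders as:

    [INST] <<SYS>>
    You are a helpful domestic appliance advisor for the ElectroHome company. ...
    <</SYS>>

    How much is the FrostKing? [/INST] It typically retails for $899.
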
+ def fine_tuning_out_dir(out_model: str) -> str:
+     """
+     Utility to generate the full path to the output directory, creating it if it does not exist
+     """
+     out_dir = os.path.join(data_dir, 'fine_tuning', out_model)
+     os.makedirs(out_dir, exist_ok=True)
+     return out_dir
+
+
+ def generate_dataset(out_model: str) -> int:
+     """
+     Coordinator to build the training data. Generates all available Q&A pairs,
+     formats them to Llama format and saves them to the training csv file.
+     :return: Count of lines written to the training data
+     """
+     training_file = os.path.join(fine_tuning_out_dir(out_model), 'train.csv')
+
+     lines = []
+     generators = [
+         CategoryDataGenerator(),
+         ProductDescriptionDataGenerator(),
+         PriceDataGenerator(),
+         FeatureDataGenerator()
+     ]
+     for g in generators:
+         for q, a in g.generate():
+             lines.append(training_string_from_q_and_a(q, a))
+
+     df = pd.DataFrame(lines, columns=['text'])
+     df.to_csv(training_file, index=False)
+     return len(lines)
+
+
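The resulting train.csv is a single text column with one formatted example per row. pandas' default QUOTE_MINIMAL quoting wraps any field containing commas or newlines in double quotes, and these strings contain both, so the multi-line examples survive the round trip into the trainer intact.
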
+ def generate_training_scripts(base_model: str, out_model: str, hf_user: str, hf_token: str) -> None:
+     """
+     Generates the shell script to run to actually train the model
+     """
+     shell_file = os.path.join(fine_tuning_out_dir(out_model), 'train.zsh')
+     with open(shell_file, "w") as f:
+         f.write("#!/bin/zsh\n\n")
+         f.write("# DO NOT COMMIT THIS FILE TO GIT AS IT CONTAINS THE HUGGING FACE WRITE TOKEN FOR THE REPO\n\n")
+         f.write('echo "STARTING TRAINING AND PUSH TO HUB"\n')
+         f.write("start=$(date +%s)\n")
+         f.write(f"autotrain llm --train --project-name {out_model} --model {base_model} --data-path . --peft --lr 2e-4 --batch-size 12 --epochs 3 --trainer sft --push-to-hub --username {hf_user} --token {hf_token}\n")
+         f.write("end=$(date +%s)\n")
+         f.write('echo "TRAINING AND PUSH TOOK $(($end-$start)) seconds"\n')
+     # Make the generated script executable for the owner
+     stats = os.stat(shell_file)
+     os.chmod(shell_file, stats.st_mode | stat.S_IEXEC)
+
+
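For reference, with the placeholder invocation sketched earlier, the generated train.zsh comes out like this (token redacted):

    #!/bin/zsh

    # DO NOT COMMIT THIS FILE TO GIT AS IT CONTAINS THE HUGGING FACE WRITE TOKEN FOR THE REPO

    echo "STARTING TRAINING AND PUSH TO HUB"
    start=$(date +%s)
    autotrain llm --train --project-name my-electrohome-llama --model meta-llama/Llama-2-7b-chat-hf --data-path . --peft --lr 2e-4 --batch-size 12 --epochs 3 --trainer sft --push-to-hub --username <your-hf-username> --token <redacted>
    end=$(date +%s)
    echo "TRAINING AND PUSH TOOK $(($end-$start)) seconds"
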
+ if __name__ == "__main__":
+     start = time()
+     if args.products_db == DataLoader.active_db:
+         DataLoader.load_data()
+     else:
+         DataLoader.set_db_name(args.products_db)
+     line_count = generate_dataset(args.fine_tuned_model)
+     generate_training_scripts(args.base_model, args.fine_tuned_model, args.hf_user, args.hf_token)
+     end = time()
+     elapsed = round(end - start, 1)  # round to 1dp
+     print(f"Generated {line_count} training examples and the training script in {elapsed} seconds.")