Adding the fine-tuning offline script
src/training/prep_finetuning.py
ADDED
@@ -0,0 +1,252 @@
"""
Not used as part of the Streamlit app; run offline to prepare the training data for fine-tuning.
"""

import argparse
import os
import stat
import pandas as pd

from abc import ABC, abstractmethod
from copy import copy
from random import choice, choices, shuffle
from time import time
from typing import Tuple, Generator

from src.common import data_dir, join_items_comma_and, random_true_false, pop_n
from src.datatypes import *

parser = argparse.ArgumentParser(prog="prep_finetuning",
                                 description="Prepare training data and a script to fine-tune a Llama 2 model and push it to the Hugging Face hub for serving")
parser.add_argument('-base_model', required=True, help="The base model to use")
parser.add_argument('-products_db', required=True, help="The products sqlite to train on")
parser.add_argument('-fine_tuned_model', required=True, help="The target model name in the Hugging Face hub")
parser.add_argument('-hf_user', required=True, help="The Hugging Face user to write the model to the hub")
parser.add_argument('-hf_token', required=True, help="The Hugging Face token to write the model to the hub")
args = parser.parse_args()
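# Example invocation (illustrative; all argument values below except the base
# model are hypothetical):
#   python -m src.training.prep_finetuning \
#       -base_model meta-llama/Llama-2-7b-chat-hf \
#       -products_db products.db \
#       -fine_tuned_model electrohome-advisor \
#       -hf_user my-hf-user \
#       -hf_token hf_xxx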


class TrainingDataGenerator(ABC):
    """
    Abstract class to generate fine-tuning training data. Implemented
    as a generator to minimise passing around large lists unnecessarily
    """
    @abstractmethod
    def generate(self) -> Generator[Tuple[str, str], None, None]:
        """
        Required to be implemented by each concrete generator.
        :return: should yield pairs of training data as (question, answer)
        """
        pass


class CategoryDataGenerator(TrainingDataGenerator):
    """
    Concrete implementation to build training data about product categories
    """
    def generate(self) -> Generator[Tuple[str, str], None, None]:
        # 1. First build "what do you offer" type queries
        cat_names = [c.name for c in Category.all.values()]
        category_synonyms = ["types", "categories", "kinds"]
        for _ in range(5):
            shuffle(cat_names)
            cat = choice(category_synonyms)
            q = f"What {cat} of products do you offer?"
            a = f"ElectroHome offers {join_items_comma_and(cat_names)}."
            yield q, a

        # 2. Now build some "products in category" type queries
        for c in Category.all.values():
            prod_names = [p.name for p in c.products]
            total_prod_count = len(prod_names)
            for _ in range(2):
                working_prod_names = copy(prod_names)
                shuffle(working_prod_names)
                while len(working_prod_names) > 0:
                    ans_product_names = pop_n(working_prod_names, 3)
                    q = f"What {c.name} do you have?"
                    a = f"We have {total_prod_count} {c.name}. For example {join_items_comma_and(ans_product_names)}. What are you looking for in a {c.name[:-1]}? I can help to guide you."
                    yield q, a
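# An illustrative pair this generator might yield (the category and product
# names here are hypothetical; real values come from the products database):
#   ("What Fridges do you have?",
#    "We have 6 Fridges. For example the FrostFree 900, the CoolMax 55 and the
#    IceKing 200. What are you looking for in a Fridge? I can help to guide you.")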


class ProductDescriptionDataGenerator(TrainingDataGenerator):
    """
    Concrete implementation to build training data from the marketing description and price
    """
    def generate(self) -> Generator[Tuple[str, str], None, None]:
        for p in Product.all.values():
            question_templates = [
                "Tell me about the #.",
                "Describe the # for me.",
                "What can you tell me about the #?"
            ]
            q = choice(question_templates).replace("#", p.name)

            # Vary the answers: sometimes omit the price, otherwise mention it
            # at the start or at the end of the response
            if random_true_false():
                a = p.description
            else:
                if random_true_false():
                    a = f"{p.description} It typically retails for ${p.price}."
                else:
                    a = f"The {p.name} typically retails for ${p.price}. {p.description}"

            yield q, a


class PriceDataGenerator(TrainingDataGenerator):
    """
    Concrete implementation to build training data just for pricing
    """
    def generate(self) -> Generator[Tuple[str, str], None, None]:
        for p in Product.all.values():
            question_templates = [
                "How much is the #?",
                "What does the # cost?",
                "What's the price of the #?"
            ]
            q = choice(question_templates).replace("#", p.name)

            answer_templates = [
                "It typically retails for $#.",
                "Our recommended retail price is $#, but do check with your stockist.",
                "The list price is $# but check in our online store if there are any offers on (www.electrohome.com)."
            ]
            a = choice(answer_templates).replace("#", str(p.price))

            yield q, a


class FeatureDataGenerator(TrainingDataGenerator):
    """
    Concrete implementation to build training data about product features
    """
    def generate(self) -> Generator[Tuple[str, str], None, None]:
        # 1. First generate Q&A for what features are available by category
        for c in Category.all.values():
            cat_features = [f.name for f in c.features]
            for _ in range(2):
                working_cat_features = copy(cat_features)
                shuffle(working_cat_features)
                while len(working_cat_features) > 0:
                    some_features = pop_n(working_cat_features, 3)
                    feature_clause = join_items_comma_and(some_features)

                    question_templates = [
                        "What features do your # have?",
                        "What should I think about when I'm considering #?",
                        "What sort of things differentiate your #?"
                    ]
                    q = choice(question_templates).replace("#", c.name)

                    answer_templates = [
                        "Our # have features like ##.",
                        "You might want to consider things like ## which # offer.",
                        "# have lots of different features, like ##."
                    ]
                    a = choice(answer_templates).replace("##", feature_clause).replace("#", c.name)

                    yield q, a

        # 2. Now generate questions the other way around - i.e. search products by feature
        for f in Feature.all.values():
            cat_name = f.category.name
            prod_names = [p.name for p in f.products]

            for _ in range(2):
                working_prod_names = copy(prod_names)
                while len(working_prod_names) > 0:
                    some_prods = pop_n(working_prod_names, 3)
                    if len(some_prods) > 1:  # Single-product examples mess up some training data
                        prod_clause = join_items_comma_and(some_prods)

                        q = f"Which {cat_name} offer {f.name}?"
                        answer_templates = [
                            "## are # which offer ###.",
                            "## have ###.",
                            "We have some great # which offer ### including our ##."
                        ]
                        a = choice(answer_templates).replace("###", f.name).replace("##", prod_clause).replace("#", cat_name)

                        yield q, a
                    else:
                        q = f"Which {cat_name} offer {f.name}?"
                        a = f"The {some_prods[0]} has {f.name}."
                        yield q, a


def training_string_from_q_and_a(q: str, a: str, sys_prompt: str = None) -> str:
    """
    Build the single llama formatted training string from a question
    and answer pair
    """
    if sys_prompt is None:
        sys_prompt = "You are a helpful domestic appliance advisor for the ElectroHome company. Please answer customer questions and do not mention other brands. If you cannot answer please say so."
    return f"[INST] <<SYS>> {sys_prompt} <</SYS>> {q} [/INST] {a}"
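# An illustrative formatted example (the product, price and the elided prompt
# text are hypothetical) -- the Llama 2 chat markers wrap the system prompt and
# question, with the answer following as the completion:
#   [INST] <<SYS>> You are a helpful domestic appliance advisor ... <</SYS>>
#   How much is the FrostFree 900? [/INST] It typically retails for $899.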


def fine_tuning_out_dir(out_model: str) -> str:
    """
    Utility to generate the full path to the output directory, creating it if it is not there
    """
    out_dir = os.path.join(data_dir, 'fine_tuning', out_model)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    return out_dir


def generate_dataset(out_model: str) -> int:
    """
    Coordinator to build the training data. Generates all available Q&A pairs
    then formats them to llama format and saves them to the training csv file.
    :return: Count of lines written to the training data
    """
    training_file = os.path.join(fine_tuning_out_dir(out_model), 'train.csv')

    lines = []
    generators = [
        CategoryDataGenerator(),
        ProductDescriptionDataGenerator(),
        PriceDataGenerator(),
        FeatureDataGenerator()
    ]
    line_count = 0
    for g in generators:
        for q, a in g.generate():
            line = training_string_from_q_and_a(q, a)
            lines.append(line)
            line_count += 1

    df = pd.DataFrame(lines, columns=['text'])
    df.to_csv(training_file, index=False)
    return line_count
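# The resulting train.csv has a single 'text' column with one fully formatted
# example per row -- which, as far as I can tell, is the default column name the
# autotrain sft trainer looks for (an assumption about autotrain's defaults).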


def generate_training_scripts(out_model: str, base_model: str, hf_user: str, hf_token: str) -> None:
    """
    Generates the shell script to run to actually train the model
    """
    shell_file = os.path.join(fine_tuning_out_dir(out_model), 'train.zsh')
    with open(shell_file, "w") as f:
        f.write("#!/bin/zsh\n\n")
        f.write("# DO NOT COMMIT THIS FILE TO GIT AS IT CONTAINS THE HUGGING FACE WRITE TOKEN FOR THE REPO\n\n")
        f.write('echo "STARTING TRAINING AND PUSH TO HUB"\n')
        f.write("start=$(date +%s)\n")
        f.write(f"autotrain llm --train --project-name {out_model} --model {base_model} --data-path . --peft --lr 2e-4 --batch-size 12 --epochs 3 --trainer sft --push-to-hub --username {hf_user} --token {hf_token}\n")
        f.write("end=$(date +%s)\n")
        f.write('echo "TRAINING AND PUSH TOOK $(($end-$start)) seconds"')
    # Make the generated script executable for the owner
    stats = os.stat(shell_file)
    os.chmod(shell_file, stats.st_mode | stat.S_IEXEC)
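# Note that --data-path is '.', so the generated train.zsh is expected to be
# run from inside the output directory, next to train.csv, e.g. (illustrative):
#   cd <data_dir>/fine_tuning/<fine_tuned_model> && ./train.zsh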


if __name__ == "__main__":
    start = time()
    if args.products_db == DataLoader.active_db:
        DataLoader.load_data()
    else:
        DataLoader.set_db_name(args.products_db)
    line_count = generate_dataset(args.fine_tuned_model)
    generate_training_scripts(args.fine_tuned_model, args.base_model, args.hf_user, args.hf_token)
    end = time()
    elapsed = (int((end - start) * 10)) / 10  # round to 1dp
    print(f"Generated {line_count} training examples and the training script in {elapsed} seconds.")
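For a quick sanity check of the generated pairs without writing any files, a few examples can be printed from a REPL with a sketch like this (it assumes the DataLoader has already been pointed at a products database, as in the __main__ block above):

    from itertools import islice

    # Preview the first few question/answer pairs from one generator
    for q, a in islice(CategoryDataGenerator().generate(), 3):
        print(q, "->", a)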