llm-arch / src /data_synthesis /test_question_generator.py
alfraser's picture
Updated from random.choices to random.sample wherever a distinct random set is needed: choices samples with replacement, so the same item can be returned twice. Discovered in pricing testing.
b897a48
raw
history blame
3.21 kB
"""
Script to generate test questions from the data
"""
import os
import json
import sys
from random import choice, sample, randint
from typing import Dict
from src.common import join_items_comma_and, data_dir
from src.datatypes import *
# Placeholder tokens embedded in the question templates. get_random_values
# returns a replacement for each of them and merge_template_values swaps the
# tokens for those replacements.
category_key = "#CATEGORY#"
product_name_key = "#PRODUCT_NAME#"
features_key = "#FEATURE_STRING#"
characteristic_key = "#CHARACTERISTIC#"
def get_random_template() -> str:
    """
    Pick one question template at random

    The templates contain the placeholder tokens defined at module level
    (category, product name, feature string, characteristic); they are
    returned unsubstituted.

    :return: A template string still containing its placeholder tokens
    """
    templates = [
        f"I'm looking for a {characteristic_key} {category_key}. What would you recommend?",
        f"What features should I think about when buying a {category_key}?",
        f"What {category_key} would you recommend that has {features_key}?",
        f"What can you tell me about the {product_name_key} {category_key}?",
        # Fixed wording: this template previously read "How much is a the ...".
        f"How much is the ElectroHome {product_name_key}?",
        f"What's your best rated {category_key}?",
        f"What {category_key}s does ElectroHome offer?",
        f"Is the {product_name_key} a good {category_key}?",
        f"I'm considering buying a new {category_key}. What should I think about and what models would you recommend?",
        f"Tell me about your {category_key}s.",
    ]
    return choice(templates)
def get_random_values() -> Dict[str, str]:
    """
    Generates a random set of entries for a question

    Chooses a random category, one product from that category, between one
    and four distinct features of the category (fewer if the category has
    fewer than four) and one characteristic adjective.

    :return: Mapping from each template placeholder token to its replacement text
    :raises ValueError: If the chosen category has no features at all
    :raises IndexError: If there are no categories, or the category has no products
    """
    category: Category = choice(list(Category.all.values()))
    # Drop the last character of the category name and lower-case it
    # (presumably turning a plural into a singular - confirm against the data).
    category_name = category.name[:-1].lower()
    if category_name == "tv":
        # Restore the capitalisation lost by lower() for this acronym.
        category_name = "TV"
    product: Product = choice(category.products)
    # sample() raises ValueError when k exceeds the population size, so cap
    # the upper bound at the number of features the category actually has.
    max_features = min(4, len(category.features))
    features: List[Feature] = sample(category.features, k=randint(1, max_features))
    characteristic: str = choice([
        "big",
        "durable",
        "budget",
        "family",
        "compact",
        "sleek",
    ])
    return {
        category_key: category_name,
        product_name_key: product.name,
        features_key: join_items_comma_and([f.name for f in features]),
        characteristic_key: characteristic,
    }
def merge_template_values(template: str, values: Dict[str, str]) -> str:
    """
    Substitute every placeholder found in the template

    Replacements are applied one after another in the iteration order of
    values, each operating on the result of the previous one.

    :param template: Text containing zero or more placeholder tokens
    :param values: Mapping of placeholder token to replacement text
    :return: The template with all occurrences of each token replaced
    """
    merged = template
    for placeholder in values:
        merged = merged.replace(placeholder, values[placeholder])
    return merged
def generate_question() -> str:
    """
    Build one random question

    Calls DataLoader.load_data(), then picks a random template and a random
    set of values and merges the two.

    :return: The template text with its placeholders substituted
    """
    DataLoader.load_data()
    # Keyword arguments are evaluated left to right, so the template is drawn
    # before the values.
    return merge_template_values(
        template=get_random_template(),
        values=get_random_values(),
    )
def save_questions_to_json(questions: List[str]) -> None:
    """
    Write the questions to json/test_questions.json under the data directory

    The file is opened in write mode, so any existing content is replaced.
    The questions are stored under a single "questions" key.

    :param questions: The question strings to persist
    """
    target_path = os.path.join(data_dir, 'json', 'test_questions.json')
    payload = {'questions': questions}
    with open(target_path, 'w') as out_file:
        json.dump(payload, out_file, indent=2)
def generate_questions(n: int = 100):
    """
    Create n random questions and save them as the test question bank,
    overwriting the existing test_questions.json

    :param n: How many questions to generate
    """
    question_bank = []
    for _ in range(n):
        question_bank.append(generate_question())
    save_questions_to_json(questions=question_bank)
if __name__ == "__main__":
    # Run from the command line with the number of questions to generate;
    # with no argument the default count of generate_questions is used.
    #
    # The argument is checked explicitly rather than by catching IndexError
    # around the whole call: random.choice raises IndexError on an empty
    # sequence, and that must not be mistaken for a missing argument.
    if len(sys.argv) > 1:
        generate_questions(int(sys.argv[1]))
    else:
        generate_questions()