Spaces:
Runtime error
Runtime error
""" | |
Script to generate test questions from the data | |
""" | |
import os | |
import json | |
import sys | |
from random import choice, sample, randint | |
from typing import Dict | |
from src.common import join_items_comma_and, data_dir | |
from src.datatypes import * | |
# Placeholder tokens embedded in the question templates. Each one is swapped
# for a concrete value (see get_random_values) by merge_template_values.
category_key = "#CATEGORY#"  # -> singular, lower-cased category name ("TV" kept upper-case)
product_name_key = "#PRODUCT_NAME#"  # -> name of a random product in that category
features_key = "#FEATURE_STRING#"  # -> comma/"and"-joined list of feature names
characteristic_key = "#CHARACTERISTIC#"  # -> a descriptive adjective such as "compact"
def get_random_template() -> str:
    """
    Pick one question template at random.

    The returned string still contains the placeholder tokens
    (category_key, product_name_key, features_key, characteristic_key);
    merge_template_values fills them in with concrete values.
    """
    templates = [
        f"I'm looking for a {characteristic_key} {category_key}. What would you recommend?",
        f"What features should I think about when buying a {category_key}?",
        f"What {category_key} would you recommend that has {features_key}?",
        f"What can you tell me about the {product_name_key} {category_key}?",
        f"How much is the ElectroHome {product_name_key}?",
        f"What's your best rated {category_key}?",
        f"What {category_key}s does ElectroHome offer?",
        f"Is the {product_name_key} a good {category_key}?",
        f"I'm considering buying a new {category_key}. What should I think about and what models would you recommend?",
        f"Tell me about your {category_key}s.",
    ]
    return choice(templates)
def get_random_values() -> Dict[str, str]:
    """
    Generate a random set of substitution values for a question template.

    Picks a random category, a random product from that category, between one
    and four of the category's features (never more than it has), and a
    descriptive adjective.

    Returns:
        A mapping from each placeholder token (category_key, product_name_key,
        features_key, characteristic_key) to its replacement text.
    """
    category: Category = choice(list(Category.all.values()))
    # Category names are presumably plural (e.g. "TVs"); dropping the last
    # character yields the singular form used inside the templates.
    category_name = category.name[:-1].lower()
    if category_name == "tv":
        # Acronym: keep it upper-case rather than "tv".
        category_name = "TV"
    product: Product = choice(category.products)
    # sample() raises ValueError when k exceeds the population, so cap the
    # upper bound at the number of features the category actually has.
    # For categories with >= 4 features this is the same randint(1, 4) call.
    max_features = min(4, len(category.features))
    features: List[Feature] = sample(category.features, k=randint(1, max_features))
    characteristic: str = choice([
        "big",
        "durable",
        "budget",
        "family",
        "compact",
        "sleek"
    ])
    return {
        category_key: category_name,
        product_name_key: product.name,
        features_key: join_items_comma_and([f.name for f in features]),
        characteristic_key: characteristic
    }
def merge_template_values(template: str, values: Dict[str, str]) -> str:
    """
    Substitute every placeholder in ``template`` with its value from ``values``.

    Replacements are applied one key at a time, in the mapping's iteration
    order, so a substituted value that itself contains a later key will be
    replaced again.
    """
    merged = template
    for placeholder in values:
        merged = merged.replace(placeholder, values[placeholder])
    return merged
def generate_question() -> str:
    """
    Build one random question by filling a random template with random values.

    DataLoader.load_data() is invoked first on every call.
    NOTE(review): confirm load_data is cheap/idempotent, since
    generate_questions calls this once per question.
    """
    DataLoader.load_data()
    # Keyword arguments are evaluated left to right, so the template is drawn
    # before the values, matching the original order of random draws.
    return merge_template_values(
        template=get_random_template(),
        values=get_random_values(),
    )
def save_questions_to_json(questions: List[str]) -> None:
    """
    Write ``questions`` to <data_dir>/json/test_questions.json.

    The file is overwritten and holds a single object of the form
    {"questions": [...]}, pretty-printed with a 2-space indent.
    """
    target_path = os.path.join(data_dir, 'json', 'test_questions.json')
    payload = {'questions': questions}
    with open(target_path, 'w') as out_file:
        json.dump(payload, fp=out_file, indent=2)
def generate_questions(n: int = 100):
    """
    Create ``n`` random questions and replace the test_questions.json
    question bank with them.
    """
    save_questions_to_json(questions=[generate_question() for _ in range(n)])
if __name__ == "__main__":
    # Command-line entry point: an optional first argument sets how many
    # questions to generate; without it, generate_questions' default is used.
    # The argument is checked *before* generating so that an IndexError raised
    # during generation (e.g. random.choice on an empty list) is not mistaken
    # for a missing argument and silently retried with the default count.
    if len(sys.argv) > 1:
        try:
            count = int(sys.argv[1])
        except ValueError:
            sys.exit(f"Expected an integer number of questions, got {sys.argv[1]!r}")
        generate_questions(count)
    else:
        generate_questions()