mesop-jeopardy-live / scripts /generate_clues.py
Richard
Add custom clue generation
28ab31e
raw
history blame
8.5 kB
import json
import os
import typing
import time
from datetime import datetime
import random
import argparse
from dotenv import load_dotenv
import google.generativeai as genai
# Load variables from a local .env file (python-dotenv) so that
# GOOGLE_API_KEY can be read from the environment below.
load_dotenv()
# Flash 1.5 has a request limit of 15 RPM
# Delay, in seconds, intended to keep generation requests under that limit.
GENERATE_DELAY = 5
# Dataset file used when no --dataset argument is given. The path is relative,
# so it resolves against the current working directory.
DEFAULT_JEOPARDY_DATA = "../data/custom_jeopardy.json"
# API key for the Gemini client; os.getenv returns None when it is not set.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
# Global variable to store the current dataset path
# (main() rebinds it when --dataset is passed).
current_dataset_path = DEFAULT_JEOPARDY_DATA
genai.configure(api_key=GOOGLE_API_KEY)
# Shape of a single generated clue. list[JeopardyQuestion] is passed to the
# Gemini generation config as the response schema (see question_gen_model).
class JeopardyQuestion(typing.TypedDict):
    # Text of the clue.
    question: str
    # Expected response to the clue.
    answer: str
    # Clue value as a string; print_questions() strips "$" and "," from it
    # before converting it to an int for sorting.
    value: str
# Gemini model used to generate clues. The generation config asks for a JSON
# response shaped like list[JeopardyQuestion], which is why
# generate_questions_by_category() can pass the response text to json.loads().
question_gen_model = genai.GenerativeModel(
    "gemini-1.5-flash",
    generation_config=genai.GenerationConfig(
        # Sampling parameters.
        temperature=1,
        top_p=0.95,
        top_k=64,
        # Structured output: JSON constrained by the schema below.
        response_mime_type="application/json",
        response_schema=list[JeopardyQuestion],
    ),
)
# Prompt template for clue generation. The {category} placeholder is filled in
# with str.format() by generate_questions_by_category(); .strip() removes the
# leading and trailing newlines introduced by the triple-quoted literal.
_JEOPARDY_QUESTION_GENERATE_PROMPT = """
You are a Jeopardy! expert who specializes in crafting great questions.
Generate Jeopardy! questions for the following category: {category}.
A Jeopardy! category has 5 questions of increasing difficulty. The values are $200,
$400, $600, $800, $1000.
""".strip()
def get_existing_categories() -> set[str]:
    """Return the lower-cased names of every category stored in the dataset.

    The dataset is read from ``current_dataset_path``. A missing file yields
    an empty set; a file that is not valid JSON prints a warning and also
    yields an empty set.
    """
    try:
        with open(current_dataset_path, "r") as dataset_file:
            entries = json.load(dataset_file)
    except FileNotFoundError:
        return set()
    except json.JSONDecodeError:
        print(f"Warning: Error reading {current_dataset_path}. Treating as empty file.")
        return set()

    # Names are lower-cased so callers can compare case-insensitively.
    seen = set()
    for entry in entries:
        seen.add(entry["category"].lower())
    return seen
def read_custom_jeopardy_questions_dataset():
    """Load the dataset at ``current_dataset_path`` and return the parsed JSON.

    Returns an empty list when the file does not exist. Invalid JSON is not
    handled here, so ``json.JSONDecodeError`` propagates to the caller.
    """
    try:
        dataset_file = open(current_dataset_path, "r")
    except FileNotFoundError:
        return []
    with dataset_file:
        return json.load(dataset_file)
def write_custom_jeopardy_questions_dataset(data, overwrite=False):
    """Write questions to the dataset file at ``current_dataset_path``.

    Args:
        data: List of questions to write
        overwrite: If True, replace existing data. If False, append to existing data.

    In append mode a missing, empty or otherwise invalid-JSON file is treated
    as having no existing data, so only ``data`` is written.
    """
    # os.path.dirname() returns "" for a bare filename such as "clues.json",
    # and os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when the path actually has one.
    parent_dir = os.path.dirname(current_dataset_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    to_write = data
    if not overwrite:
        try:
            with open(current_dataset_path, "r") as f:
                existing_data = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            # Nothing usable on disk yet: just write the new data.
            pass
        else:
            existing_data.extend(data)
            to_write = existing_data

    with open(current_dataset_path, "w") as f:
        json.dump(to_write, f, indent=2)
def generate_questions_by_category(category) -> list[dict[str, str]]:
    """Generate Jeopardy questions for a category using Gemini.

    Sleeps for ``GENERATE_DELAY`` seconds before sending the request so that
    back-to-back calls stay under the model's requests-per-minute limit.

    Args:
        category: Category name substituted into the generation prompt.

    Returns:
        Generated jeopardy data set in the expected format: one dict per
        question with the keys "question", "answer", "value", "category",
        "air_date", "show_number" and "round".
    """
    # Use the module-level constant instead of a duplicated literal so the
    # throttle can be tuned in one place.
    time.sleep(GENERATE_DELAY)

    response = question_gen_model.generate_content(
        _JEOPARDY_QUESTION_GENERATE_PROMPT.format(category=category)
    )
    questions = json.loads(response.text)

    # Every question generated in this call shares today's date and one
    # randomly chosen show number.
    air_date = datetime.now().strftime("%Y-%m-%d")
    show_number = str(random.randint(1, 2000))

    # Format the questions like the data set.
    return [
        {
            "question": question["question"],
            "answer": question["answer"],
            "value": question["value"],
            "category": category,
            "air_date": air_date,
            "show_number": show_number,
            "round": "Jeopardy!",
        }
        for question in questions
    ]
def print_questions(questions: list[dict[str, str]], category: str):
    """Print the generated questions in a readable format.

    Questions are listed in ascending order of their dollar value. A value
    that cannot be parsed as a number no longer raises; such questions are
    printed after the parsable ones, in their original relative order.

    Args:
        questions: Question dicts with "value", "question" and "answer" keys.
        category: Category name shown in the header.
    """
    print(f"\nCategory: {category}\n")
    print("-" * 50)
    # Sort questions by value
    sorted_questions = sorted(questions, key=_question_value_sort_key)
    for q in sorted_questions:
        print(f"Value: {q['value']}")
        print(f"Question: {q['question']}")
        print(f"Answer: {q['answer']}")
        print("-" * 50)


def _question_value_sort_key(question):
    """Return a sort key ordering by numeric value, with unparsable values last."""
    cleaned = str(question["value"]).replace("$", "").replace(",", "")
    try:
        return (0, int(cleaned))
    except ValueError:
        # The value comes from model output; a malformed one must not abort
        # printing (and, in main(), the save that follows it).
        return (1, 0)
def get_categories_from_input() -> list[str]:
    """Get multiple categories from user input.

    Prints a prompt, then reads one category per line from standard input
    until a blank line or end-of-input is reached. Surrounding whitespace is
    stripped from each category.

    Returns:
        The categories in the order they were entered (possibly empty).
    """
    categories = []
    print("Enter categories (one per line). Press Enter twice when done:")
    while True:
        try:
            category = input().strip()
        except EOFError:
            # Piped or redirected input can end without a blank line; treat
            # end-of-input like the terminating blank line instead of crashing.
            break
        if not category:
            break
        categories.append(category)
    return categories
def read_categories_from_file(filename: str) -> list[str]:
    """Read categories from a text file, one category per line.

    The file is decoded as UTF-8 regardless of the platform's default
    encoding, and a leading byte-order mark is ignored. Blank lines are
    skipped and surrounding whitespace is stripped from each line.

    Args:
        filename: Path to the text file.

    Returns:
        The categories in file order, or an empty list if the file is
        missing or cannot be read (an error message is printed).
    """
    try:
        # "utf-8-sig" reads plain UTF-8 and also drops a BOM if one is
        # present, so the first category never starts with "\ufeff".
        with open(filename, "r", encoding="utf-8-sig") as f:
            # Read lines and remove empty lines and whitespace
            return [line.strip() for line in f if line.strip()]
    except FileNotFoundError:
        print(f"Error: File '{filename}' not found")
        return []
    except (OSError, ValueError) as e:
        # OSError: permissions, path is a directory, other I/O failures.
        # ValueError: undecodable bytes (UnicodeDecodeError) or an invalid path.
        print(f"Error reading file: {e}")
        return []
def filter_existing_categories(categories: list[str]) -> list[str]:
    """Return only the categories that are not yet present in the dataset.

    Names are compared case-insensitively against the dataset's categories.
    Any categories that are dropped are listed on standard output.
    """
    known = get_existing_categories()
    already_present = [name for name in categories if name.lower() in known]
    remaining = [name for name in categories if name.lower() not in known]

    if already_present:
        print("\nSkipping the following existing categories:")
        for name in already_present:
            print(f"- {name}")
    return remaining
def main():
    """Command-line entry point: generate and store clues for a set of categories.

    Categories come from ``--categories``, from ``--file`` (one per line) or,
    when neither is given, from interactive input. Unless ``--overwrite`` is
    set, categories already present in the dataset are skipped. With
    ``--overwrite`` the user is asked to confirm and the dataset is emptied
    before generation starts.

    Each category is saved as soon as it has been generated. A category that
    fails is reported and skipped, and the final summary lists the failures
    instead of claiming that everything was saved.
    """
    global current_dataset_path

    parser = argparse.ArgumentParser(
        description="Generate Jeopardy questions for multiple categories"
    )
    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "--categories", nargs="+", type=str, help="List of Jeopardy categories to generate questions"
    )
    group.add_argument(
        "--file", type=str, help="Path to text file containing categories (one per line)"
    )
    parser.add_argument(
        "--overwrite", action="store_true", help="Overwrite existing questions instead of appending"
    )
    parser.add_argument(
        "--dataset", type=str, help=f"Path to the dataset file (default: {DEFAULT_JEOPARDY_DATA})"
    )
    args = parser.parse_args()

    # Set the dataset path
    if args.dataset:
        current_dataset_path = args.dataset
        print(f"Using custom dataset path: {current_dataset_path}")

    # Determine which source to use for categories
    if args.file:
        categories = read_categories_from_file(args.file)
        if not categories:
            print("No valid categories found in file. Exiting.")
            return
    elif args.categories:
        categories = args.categories
    else:
        categories = get_categories_from_input()
        if not categories:
            print("No categories provided. Exiting.")
            return

    # Filter out existing categories
    if not args.overwrite:
        categories = filter_existing_categories(categories)
        if not categories:
            print("\nAll categories already exist in the dataset. Nothing to do.")
            return

    print(f"\nPreparing to generate questions for {len(categories)} categories:")
    for i, category in enumerate(categories, 1):
        print(f"{i}. {category}")
    print()

    if args.overwrite:
        print("Warning: This will overwrite all existing questions!")
        try:
            confirm = input("Do you want to continue? (y/N): ")
        except EOFError:
            # No interactive input available: fall back to the default (No).
            confirm = ""
        if confirm.lower() != "y":
            print("Operation cancelled.")
            return
        # If overwrite mode, initialize empty dataset
        write_custom_jeopardy_questions_dataset([], overwrite=True)
        print("Initialized empty dataset for overwrite mode")

    failed_categories = []
    for i, category in enumerate(categories, 1):
        print(f"\nGenerating questions for category: {category} ({i}/{len(categories)})")
        try:
            questions = generate_questions_by_category(category)
            print_questions(questions, category)
            # Save after each category
            write_custom_jeopardy_questions_dataset(questions, overwrite=False)
            print(f"✓ Saved questions for {category}")
        except Exception as e:
            # Top-level boundary: report the failure and keep going so one bad
            # category does not abort the whole batch.
            print(f"Error generating questions for {category}: {str(e)}")
            print("Skipping to next category...")
            failed_categories.append(category)

    print(f"\nCompleted processing {len(categories)} categories.")
    if failed_categories:
        print(f"Failed to generate questions for {len(failed_categories)} categories:")
        for category in failed_categories:
            print(f"- {category}")
    if len(failed_categories) < len(categories):
        print(f"All generated questions have been saved to {current_dataset_path}")
    else:
        print("No questions were saved.")
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()