flashcard-studio/app/pipeline.py
import logging
import torch
from transformers import pipeline
logger = logging.getLogger(__name__)
logging.basicConfig(
    filename="pipeline.log",
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
class Pipeline:
    def __init__(self, model_name: str = "Qwen/Qwen2.5-7B-Instruct"):
        # Load the chat model once; device placement is handled by device_map="auto".
        self.torch_pipe = pipeline(
            "text-generation",
            model=model_name,
            torch_dtype="auto",
            device_map="auto"
        )
        self.device = self._determine_device()
        logger.info(f"device type: {self.device}")
        # System prompt shared by every extraction request.
        self.messages = [
            {"role": "system", "content": """You are an expert flashcard creator.
- You ALWAYS include a single knowledge item per flashcard.
- You ALWAYS respond in valid JSON format.
- You ALWAYS make flashcards accurate and comprehensive.
- If the text includes code snippets, you treat each snippet as a knowledge item testing the user's understanding of how to write the code and what it does.
Format responses like the example below.
EXAMPLE:
[
    {"question": "What is AI?", "answer": "Artificial Intelligence."},
    {"question": "What is ML?", "answer": "Machine Learning."}
]
"""},
        ]
    def extract_flashcards(self, content: str = "", max_new_tokens: int = 1024) -> str:
        user_prompt = {"role": "user", "content": content}
        self.messages.append(user_prompt)
        try:
            # The chat pipeline returns the full conversation; the last message is the
            # assistant reply, and its "content" field holds the generated JSON string.
            response_message = self.torch_pipe(
                self.messages,
                max_new_tokens=max_new_tokens
            )[0]["generated_text"][-1]
            return response_message["content"]
        except Exception as e:
            logger.error(f"Error extracting flashcards: {str(e)}")
            raise ValueError(f"Error extracting flashcards: {str(e)}") from e
    def _determine_device(self) -> torch.device:
        if torch.cuda.is_available():
            return torch.device("cuda")
        elif torch.backends.mps.is_available():
            return torch.device("mps")
        else:
            return torch.device("cpu")
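

# Example usage: a minimal sketch, assuming the model's reply parses as the JSON
# list described in the system prompt (malformed output would raise at json.loads).
if __name__ == "__main__":
    import json

    pipe = Pipeline()
    raw = pipe.extract_flashcards(
        "AI (Artificial Intelligence) is the simulation of human intelligence by "
        "machines. ML (Machine Learning) is a subset of AI that learns patterns from data."
    )
    cards = json.loads(raw)  # expected shape: [{"question": ..., "answer": ...}, ...]
    for card in cards:
        print(f"Q: {card['question']}\nA: {card['answer']}\n")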