|
import torch |
|
from transformers import T5Tokenizer, T5ForConditionalGeneration |
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
class TopicGenerator:
    """Generate topic sentences from free-form user input with a seq2seq LM.

    Wraps a Hugging Face T5-family model (FLAN-T5 large by default) behind a
    simple ``generate_topics`` API. The model and tokenizer are loaded once at
    construction time.
    """

    def __init__(self, model_name: str = "google/flan-t5-large", device: str = device):
        """Load the tokenizer and model and move the model to *device*.

        Args:
            model_name: Hugging Face model identifier of a T5-family
                checkpoint. Defaults to FLAN-T5 large (original behavior).
            device: Torch device string ("cuda" or "cpu"). Defaults to the
                module-level auto-detected device.

        NOTE: ``from_pretrained`` downloads weights on first use (network I/O).
        """
        self.device = device
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(self.device)

    def generate_topics(self, user_input: str, num_topics: int = 3, max_length: int = 50) -> list[str]:
        """Generate topic sentences based on the user input.

        Args:
            user_input: The input text provided by the user.
            num_topics: Number of topics to generate. Defaults to 3.
            max_length: Maximum token length of each generated sequence.
                Defaults to 50 (original hard-coded value).

        Returns:
            A list of ``num_topics`` generated topic sentences.
        """
        prompt_text = f"Generate a topic sentence based on the following input: {user_input}"
        input_ids = self.tokenizer(prompt_text, return_tensors="pt").input_ids.to(self.device)

        # Sampling-based decoding; no_grad avoids any autograd bookkeeping
        # during inference. num_return_sequences requires do_sample=True here
        # (no beam search), which the original already set.
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                do_sample=True,
                top_k=50,
                temperature=0.7,
                max_length=max_length,
                num_return_sequences=num_topics,
            )

        return [self.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
|
|