# TopicGPT / app.py
import os

import gradio as gr
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Get the Hugging Face token from the environment variable
HF_TOKEN = os.environ.get("HF_TOKEN")
# Load the tokenizer and model
# Note: token= replaces the deprecated use_auth_token= argument in recent
# versions of transformers.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', token=HF_TOKEN)
model = GPT2LMHeadModel.from_pretrained('skylersterling/TopicGPT', token=HF_TOKEN)
model.eval()
model.to('cpu')
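
# Hedged alternative (an assumption, not part of the original Space): pick the
# device dynamically so the same code can use a GPU when one is available.
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# model.to(device)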

# Generate a topic for the given prompt, streaming partial results as a generator
def generate_text(prompt, temperature):
    print(prompt)
    # Wrap the prompt in the control markers the model expects: the conversation
    # follows "#CONTEXT#" and the model completes the topic after "#TOPIC#"
    prompt_with_markers = "#CONTEXT# " + prompt + " #TOPIC# "
    input_tokens = tokenizer.encode(prompt_with_markers, return_tensors='pt')
    input_tokens = input_tokens.to('cpu')

    generated_text = prompt_with_markers  # Start with the wrapped prompt
    prompt_length = len(generated_text)  # Character offset used to strip the prompt from the output

    for _ in range(80):  # Adjust the range to control the maximum number of generated tokens
        with torch.no_grad():
            outputs = model(input_tokens)
        # Scale the logits for the last position by the temperature and sample the next token
        predictions = outputs.logits[:, -1, :] / temperature
        next_token = torch.multinomial(torch.softmax(predictions, dim=-1), 1)
        input_tokens = torch.cat((input_tokens, next_token), dim=1)
        decoded_token = tokenizer.decode(next_token.item())
        generated_text += decoded_token  # Append the new token to the generated text
        if decoded_token == "#":  # "#" opens the next marker, so stop generating
            break
        yield generated_text[prompt_length:]  # Stream the text generated so far, without the prompt
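
# Minimal local sanity check (a sketch; the env-var guard and the sample
# conversation are assumptions, not part of the original Space). Because
# generate_text is a generator, each iteration yields the topic text streamed
# so far.
if os.environ.get("TOPICGPT_LOCAL_TEST"):
    for partial in generate_text("A: My screen keeps flickering. B: Try updating your GPU driver.", 0.3):
        print(partial)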

# Create a Gradio interface with a text input and a slider for temperature
interface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.3, label="Temperature"),
    ],
    outputs=gr.Textbox(),
    live=False,
    description="TopicGPT processes the input and returns a reasonably accurate guess of the topic/theme of a given conversation.",
)
interface.launch()