Spaces:
Sleeping
Sleeping
File size: 2,130 Bytes
c95822b 14ac587 db9d4db 14ac587 db9d4db 5e2cdce db9d4db 22fb6e4 ad4d3e1 db9d4db a819a8c 5cc20f3 51227f1 f7d422b 5cc20f3 013060a 38c9c31 1a13068 403b4a5 5cc20f3 46d003d b83610f 5cc20f3 14ac587 3aeba0a 14ac587 5cc20f3 c4dece9 5cc20f3 db9d4db a819a8c c4dece9 5cc20f3 c4dece9 5cc20f3 6532f26 5cc20f3 c4dece9 6e4b802 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import gradio as gr
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import os
# Get the Hugging Face token from the environment variable (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
# Load the tokenizer and model. `token=` replaces the deprecated
# `use_auth_token=` argument of `from_pretrained` (NOTE(review): needs a
# transformers release that accepts `token=` — confirm against the Space's
# installed version).
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', token=HF_TOKEN)
model = GPT2LMHeadModel.from_pretrained('skylersterling/TopicGPT', token=HF_TOKEN)
# Inference only: eval() disables dropout; keep the weights on the CPU.
model.eval()
model.to('cpu')
# Define the function that generates text from a prompt
def generate_text(prompt, temperature, max_new_tokens=80):
    """Stream a topic guess for ``prompt`` by sampling tokens from the model.

    The prompt is wrapped in the ``"#CONTEXT# " ... " #TOPIC# "`` template and
    tokens are sampled one at a time. After each accepted token, the text
    generated so far (excluding the template and the prompt) is yielded, so
    a Gradio output box can update while generation is in progress.

    Args:
        prompt: Conversation text whose topic should be guessed.
        temperature: Softmax temperature applied to the next-token logits.
            Values <= 0 fall back to greedy (argmax) decoding instead of
            dividing by zero.
        max_new_tokens: Upper bound on the number of tokens generated.

    Yields:
        The cumulative generated text after each token. Generation stops
        (without yielding the stop token) when a token decoding to ``"#"``
        or the tokenizer's end-of-text token is produced.
    """
    print(prompt)
    prompt_with_template = "#CONTEXT# " + prompt + " #TOPIC# "
    input_tokens = tokenizer.encode(prompt_with_template, return_tensors='pt')
    input_tokens = input_tokens.to('cpu')

    # GPT-2 has a fixed positional window (config.n_positions); indexing past
    # it fails. Keep the tail of an over-long prompt (it holds the "#TOPIC#"
    # marker the model completes) and cap the budget so everything fits.
    n_positions = getattr(model.config, "n_positions", None)
    if n_positions is not None:
        input_tokens = input_tokens[:, -max(1, n_positions - max_new_tokens):]
        max_new_tokens = min(max_new_tokens, n_positions - input_tokens.shape[1])

    eos_token_id = tokenizer.eos_token_id
    generated_ids = []
    past_key_values = None
    next_input = input_tokens
    for _ in range(max_new_tokens):
        # no_grad wraps only the forward pass: grad mode is thread-local
        # state, so it must not stay disabled while this generator is
        # suspended at a yield.
        with torch.no_grad():
            # Reuse the cached attention keys/values so each step only
            # feeds the newest token instead of re-running the whole sequence.
            outputs = model(next_input, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values
        logits = outputs.logits[:, -1, :]
        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, 1)
        else:
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
        token_id = next_token.item()
        # "#" delimits the topic; also honour the real end-of-text token.
        if token_id == eos_token_id or tokenizer.decode(token_id) == "#":
            break
        generated_ids.append(token_id)
        next_input = next_token
        # Decode the accumulated ids rather than concatenating per-token
        # decodes: byte-level BPE may split one UTF-8 character across tokens.
        yield tokenizer.decode(generated_ids)
# Create a Gradio interface with a text input and a slider for temperature
interface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.3, label="Temperature"),
    ],
    outputs=gr.Textbox(),
    live=False,
    description="TopicGPT processes the input and returns a reasonably accurate guess of the topic/theme of a given conversation."
)
# generate_text is a generator that streams partial output. Gradio 3.x only
# accepts generator functions when the queue is enabled; Gradio 4+ enables it
# by default, so enabling it explicitly works on both.
interface.queue()
interface.launch()
|