import gradio as gr
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import os

# Get the Hugging Face token from the environment variable
HF_TOKEN = os.environ.get("HF_TOKEN")
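# The token is assumed to be provided as a Space secret or exported in the shell
# (e.g. `export HF_TOKEN=hf_...`); os.environ.get returns None if it is unset.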

# Load the base GPT-2 tokenizer and the fine-tuned TopicGPT model
# (`token=` supersedes the deprecated `use_auth_token=` argument in recent
# transformers releases; on older releases, pass `use_auth_token=HF_TOKEN` instead)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', token=HF_TOKEN)
model = GPT2LMHeadModel.from_pretrained('skylersterling/TopicGPT', token=HF_TOKEN)
model.eval()     # Inference mode: disables dropout
model.to('cpu')  # Run inference on CPU

# Define the generator function that streams generated text from a prompt
def generate_text(prompt, temperature):
    print(prompt)  # Log the incoming prompt to the console

    # Wrap the prompt in the control tags the fine-tuned model expects
    tagged_prompt = "#CONTEXT# " + prompt + " #TOPIC# "
    input_tokens = tokenizer.encode(tagged_prompt, return_tensors='pt')

    input_tokens = input_tokens.to('cpu')

    generated_text = tagged_prompt  # Start with the tagged prompt
    prompt_length = len(generated_text)

    for _ in range(80):  # Adjust the range to control the number of tokens generated
        with torch.no_grad():
            outputs = model(input_tokens)
            # Temperature-scale the logits for the last position, then sample
            predictions = outputs.logits[:, -1, :] / temperature
            next_token = torch.multinomial(torch.softmax(predictions, dim=-1), 1)

        input_tokens = torch.cat((input_tokens, next_token), dim=1)

        decoded_token = tokenizer.decode(next_token.item())
        generated_text += decoded_token  # Append the new token to the generated text
        if decoded_token == "#":  # "#" marks the end of the topic, so stop generating
            break
        yield generated_text[prompt_length:]  # Yield only the completion, excluding the tagged prompt

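# For comparison, a minimal sketch of the same sampling using the built-in
# `model.generate` helper instead of the manual token loop above. This is an
# illustrative alternative (the name `generate_text_oneshot` is hypothetical and
# the function is not wired into the interface): it returns the full completion
# at once rather than streaming it token by token.
def generate_text_oneshot(prompt, temperature):
    tagged_prompt = "#CONTEXT# " + prompt + " #TOPIC# "
    input_tokens = tokenizer.encode(tagged_prompt, return_tensors='pt').to('cpu')
    with torch.no_grad():
        output = model.generate(
            input_tokens,
            max_new_tokens=80,                     # same token budget as the loop above
            do_sample=True,                        # multinomial sampling, as above
            temperature=temperature,               # same temperature scaling
            pad_token_id=tokenizer.eos_token_id,   # silence the missing-pad-token warning
        )
    # Drop the prompt tokens and decode only the continuation
    completion = tokenizer.decode(output[0][input_tokens.shape[1]:])
    # Trim at the "#" end marker, mirroring the stop condition above
    return completion.split("#")[0]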
# Create a Gradio interface with a text input and a slider for temperature
interface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.3, label="Temperature"),
    ],
    outputs=gr.Textbox(),
    live=False,
    description="TopicGPT processes the input and returns a reasonably accurate guess of the topic/theme of a given conversation."
)

interface.launch()
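
# Usage sketch (assuming this file is saved as app.py):
#   HF_TOKEN=<your Hugging Face token> python app.py
# Open the printed local URL; because generate_text is a generator, the topic
# guess streams into the output box as tokens are sampled.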