Spaces:
Sleeping
Sleeping
skylersterling
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -14,17 +14,18 @@ model.to('cpu')
|
|
14 |
|
15 |
# Define the function that generates text from a prompt
|
16 |
def generate_text(prompt, temperature):
|
|
|
17 |
print(prompt)
|
18 |
|
19 |
-
prompt_with_eos = "#CONTEXT# " + prompt + " #TOPIC# "
|
20 |
input_tokens = tokenizer.encode(prompt_with_eos, return_tensors='pt')
|
21 |
|
22 |
input_tokens = input_tokens.to('cpu')
|
23 |
|
24 |
-
generated_text = prompt_with_eos
|
25 |
prompt_length = len(generated_text)
|
26 |
|
27 |
-
for _ in range(80):
|
28 |
with torch.no_grad():
|
29 |
outputs = model(input_tokens)
|
30 |
predictions = outputs.logits[:, -1, :] / temperature
|
@@ -33,37 +34,21 @@ def generate_text(prompt, temperature):
|
|
33 |
input_tokens = torch.cat((input_tokens, next_token), dim=1)
|
34 |
|
35 |
decoded_token = tokenizer.decode(next_token.item())
|
36 |
-
generated_text += decoded_token
|
37 |
-
if decoded_token == "#":
|
38 |
break
|
39 |
-
yield generated_text[prompt_length:]
|
40 |
-
|
41 |
-
# Custom CSS for a modern look
|
42 |
-
custom_css = """
|
43 |
-
body {font-family: 'Arial', sans-serif; background-color: #f0f2f5; margin: 0; padding: 0;}
|
44 |
-
.gradio-container {width: 100%; max-width: 900px; margin: 5% auto; background: white; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); border-radius: 10px; overflow: hidden;}
|
45 |
-
input, textarea {width: 100%; padding: 10px; border: 1px solid #ccc; border-radius: 5px; margin-top: 5px; font-size: 16px;}
|
46 |
-
button {background-color: #4CAF50; color: white; padding: 10px 20px; border: none; border-radius: 5px; font-size: 16px; cursor: pointer;}
|
47 |
-
button:hover {background-color: #45a049;}
|
48 |
-
h1, h3 {color: #333;}
|
49 |
-
.slider {width: 100%; margin-top: 20px;}
|
50 |
-
footer {text-align: center; padding: 20px; background-color: #4CAF50; color: white; font-size: 16px;}
|
51 |
-
"""
|
52 |
|
53 |
# Create a Gradio interface with a text input and a slider for temperature
|
54 |
interface = gr.Interface(
|
55 |
fn=generate_text,
|
56 |
inputs=[
|
57 |
-
gr.Textbox(lines=2, placeholder="Enter your prompt here..."
|
58 |
-
gr.Slider(minimum=0.1, maximum=1.0, value=0.3, label="Temperature"
|
59 |
],
|
60 |
-
outputs=gr.Textbox(
|
61 |
live=False,
|
62 |
-
description="TopicGPT processes the input and returns a reasonably accurate guess of the topic/theme of a given conversation."
|
63 |
-
title="TopicGPT: Theme & Topic Guessing",
|
64 |
-
layout="vertical",
|
65 |
-
theme="huggingface",
|
66 |
-
css=custom_css,
|
67 |
)
|
68 |
|
69 |
-
interface.launch()
|
|
|
14 |
|
15 |
# Define the function that generates text from a prompt
|
16 |
def generate_text(prompt, temperature):
|
17 |
+
|
18 |
print(prompt)
|
19 |
|
20 |
+
prompt_with_eos = "#CONTEXT# " + prompt + " #TOPIC# " # Add the string "EOS" to the end of the prompt
|
21 |
input_tokens = tokenizer.encode(prompt_with_eos, return_tensors='pt')
|
22 |
|
23 |
input_tokens = input_tokens.to('cpu')
|
24 |
|
25 |
+
generated_text = prompt_with_eos # Start with the initial prompt plus "EOS"
|
26 |
prompt_length = len(generated_text)
|
27 |
|
28 |
+
for _ in range(80): # Adjust the range to control the number of tokens generated
|
29 |
with torch.no_grad():
|
30 |
outputs = model(input_tokens)
|
31 |
predictions = outputs.logits[:, -1, :] / temperature
|
|
|
34 |
input_tokens = torch.cat((input_tokens, next_token), dim=1)
|
35 |
|
36 |
decoded_token = tokenizer.decode(next_token.item())
|
37 |
+
generated_text += decoded_token # Append the new token to the generated text
|
38 |
+
if decoded_token == "#": # Stop if the end of sequence token is generated
|
39 |
break
|
40 |
+
yield generated_text[prompt_length:] # Yield the generated text excluding the initial prompt plus "EOS"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
# Create a Gradio interface with a text input and a slider for temperature
|
43 |
interface = gr.Interface(
|
44 |
fn=generate_text,
|
45 |
inputs=[
|
46 |
+
gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
|
47 |
+
gr.Slider(minimum=0.1, maximum=1.0, value=0.3, label="Temperature"),
|
48 |
],
|
49 |
+
outputs=gr.Textbox(),
|
50 |
live=False,
|
51 |
+
description="TopicGPT processes the input and returns a reasonably accurate guess of the topic/theme of a given conversation."
|
|
|
|
|
|
|
|
|
52 |
)
|
53 |
|
54 |
+
interface.launch()
|