skylersterling committed
Commit 5cc20f3 · verified · 1 Parent(s): 51227f1

Update app.py

Files changed (1): app.py (+12 -27)
app.py CHANGED
@@ -14,17 +14,18 @@ model.to('cpu')
 
 # Define the function that generates text from a prompt
 def generate_text(prompt, temperature):
+
     print(prompt)
 
-    prompt_with_eos = "#CONTEXT# " + prompt + " #TOPIC# "
+    prompt_with_eos = "#CONTEXT# " + prompt + " #TOPIC# "  # Add the string "EOS" to the end of the prompt
     input_tokens = tokenizer.encode(prompt_with_eos, return_tensors='pt')
 
     input_tokens = input_tokens.to('cpu')
 
-    generated_text = prompt_with_eos
+    generated_text = prompt_with_eos  # Start with the initial prompt plus "EOS"
     prompt_length = len(generated_text)
 
-    for _ in range(80):
+    for _ in range(80):  # Adjust the range to control the number of tokens generated
         with torch.no_grad():
             outputs = model(input_tokens)
             predictions = outputs.logits[:, -1, :] / temperature
@@ -33,37 +34,21 @@ def generate_text(prompt, temperature):
         input_tokens = torch.cat((input_tokens, next_token), dim=1)
 
         decoded_token = tokenizer.decode(next_token.item())
-        generated_text += decoded_token
-        if decoded_token == "#":
+        generated_text += decoded_token  # Append the new token to the generated text
+        if decoded_token == "#":  # Stop if the end of sequence token is generated
             break
-        yield generated_text[prompt_length:]
-
-# Custom CSS for a modern look
-custom_css = """
-body {font-family: 'Arial', sans-serif; background-color: #f0f2f5; margin: 0; padding: 0;}
-.gradio-container {width: 100%; max-width: 900px; margin: 5% auto; background: white; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); border-radius: 10px; overflow: hidden;}
-input, textarea {width: 100%; padding: 10px; border: 1px solid #ccc; border-radius: 5px; margin-top: 5px; font-size: 16px;}
-button {background-color: #4CAF50; color: white; padding: 10px 20px; border: none; border-radius: 5px; font-size: 16px; cursor: pointer;}
-button:hover {background-color: #45a049;}
-h1, h3 {color: #333;}
-.slider {width: 100%; margin-top: 20px;}
-footer {text-align: center; padding: 20px; background-color: #4CAF50; color: white; font-size: 16px;}
-"""
+        yield generated_text[prompt_length:]  # Yield the generated text excluding the initial prompt plus "EOS"
 
 # Create a Gradio interface with a text input and a slider for temperature
 interface = gr.Interface(
     fn=generate_text,
     inputs=[
-        gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
-        gr.Slider(minimum=0.1, maximum=1.0, value=0.3, label="Temperature", step=0.1),
+        gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
+        gr.Slider(minimum=0.1, maximum=1.0, value=0.3, label="Temperature"),
     ],
-    outputs=gr.Textbox(label="Generated Text"),
+    outputs=gr.Textbox(),
     live=False,
-    description="TopicGPT processes the input and returns a reasonably accurate guess of the topic/theme of a given conversation.",
-    title="TopicGPT: Theme & Topic Guessing",
-    layout="vertical",
-    theme="huggingface",
-    css=custom_css,
+    description="TopicGPT processes the input and returns a reasonably accurate guess of the topic/theme of a given conversation."
 )
 
-interface.launch()
+interface.launch()
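
Note: the two hunks skip the lines between old 30 and 33 where next_token is produced from the temperature-scaled predictions, so the sampling step itself is not visible in this diff. Below is a minimal sketch of what that step typically looks like, assuming softmax-plus-multinomial sampling; the helper name sample_next_token is hypothetical and not part of the commit.

import torch

def sample_next_token(predictions):
    # predictions: logits for the last position, already divided by temperature, shape (1, vocab_size)
    probs = torch.softmax(predictions, dim=-1)            # turn scaled logits into probabilities
    next_token = torch.multinomial(probs, num_samples=1)  # draw one token id, shape (1, 1)
    return next_token

Dividing the logits by temperature before the softmax is what the Temperature slider controls: values toward 0.1 sharpen the distribution around the most likely tokens, while values near 1.0 sample closer to the model's raw distribution. Because generate_text yields partial strings instead of returning once, Gradio can stream the growing topic guess into the output textbox as tokens are produced.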