skylersterling committed on
Commit
51227f1
·
verified ·
1 Parent(s): f7d422b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -18
app.py CHANGED
@@ -14,24 +14,17 @@ model.to('cpu')
14
 
15
  # Define the function that generates text from a prompt
16
  def generate_text(prompt, temperature):
17
- word_count = len(prompt.split())
18
- if word_count < 5:
19
- return "Please provide at least 5 words in the prompt."
20
 
21
- # Tokenize the prompt to check the number of tokens
22
- input_tokens = tokenizer.encode(prompt, return_tensors='pt')
23
- if input_tokens.size(1) > 512:
24
- return "Please provide an input with fewer than 512 tokens."
25
-
26
- prompt_with_eos = "#CONTEXT# " + prompt + " #TOPIC# " # Add the string "EOS" to the end of the prompt
27
  input_tokens = tokenizer.encode(prompt_with_eos, return_tensors='pt')
28
 
29
  input_tokens = input_tokens.to('cpu')
30
 
31
- generated_text = prompt_with_eos # Start with the initial prompt plus "EOS"
32
  prompt_length = len(generated_text)
33
 
34
- for _ in range(80): # Adjust the range to control the number of tokens generated
35
  with torch.no_grad():
36
  outputs = model(input_tokens)
37
  predictions = outputs.logits[:, -1, :] / temperature
@@ -40,21 +33,37 @@ def generate_text(prompt, temperature):
40
  input_tokens = torch.cat((input_tokens, next_token), dim=1)
41
 
42
  decoded_token = tokenizer.decode(next_token.item())
43
- generated_text += decoded_token # Append the new token to the generated text
44
- if decoded_token == "#": # Stop if the end of sequence token is generated
45
  break
46
- yield generated_text[prompt_length:] # Yield the generated text excluding the initial prompt plus "EOS"
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  # Create a Gradio interface with a text input and a slider for temperature
49
  interface = gr.Interface(
50
  fn=generate_text,
51
  inputs=[
52
- gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
53
- gr.Slider(minimum=0.1, maximum=1.0, value=0.3, label="Temperature"),
54
  ],
55
- outputs=gr.Textbox(),
56
  live=False,
57
- description="TopicGPT processes the input and returns a reasonably accurate guess of the topic/theme of a given conversation."
 
 
 
 
58
  )
59
 
60
  interface.launch()
 
14
 
15
  # Define the function that generates text from a prompt
16
  def generate_text(prompt, temperature):
17
+ print(prompt)
 
 
18
 
19
+ prompt_with_eos = "#CONTEXT# " + prompt + " #TOPIC# "
 
 
 
 
 
20
  input_tokens = tokenizer.encode(prompt_with_eos, return_tensors='pt')
21
 
22
  input_tokens = input_tokens.to('cpu')
23
 
24
+ generated_text = prompt_with_eos
25
  prompt_length = len(generated_text)
26
 
27
+ for _ in range(80):
28
  with torch.no_grad():
29
  outputs = model(input_tokens)
30
  predictions = outputs.logits[:, -1, :] / temperature
 
33
  input_tokens = torch.cat((input_tokens, next_token), dim=1)
34
 
35
  decoded_token = tokenizer.decode(next_token.item())
36
+ generated_text += decoded_token
37
+ if decoded_token == "#":
38
  break
39
+ yield generated_text[prompt_length:]
40
+
41
+ # Custom CSS for a modern look
42
+ custom_css = """
43
+ body {font-family: 'Arial', sans-serif; background-color: #f0f2f5; margin: 0; padding: 0;}
44
+ .gradio-container {width: 100%; max-width: 900px; margin: 5% auto; background: white; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); border-radius: 10px; overflow: hidden;}
45
+ input, textarea {width: 100%; padding: 10px; border: 1px solid #ccc; border-radius: 5px; margin-top: 5px; font-size: 16px;}
46
+ button {background-color: #4CAF50; color: white; padding: 10px 20px; border: none; border-radius: 5px; font-size: 16px; cursor: pointer;}
47
+ button:hover {background-color: #45a049;}
48
+ h1, h3 {color: #333;}
49
+ .slider {width: 100%; margin-top: 20px;}
50
+ footer {text-align: center; padding: 20px; background-color: #4CAF50; color: white; font-size: 16px;}
51
+ """
52
 
53
  # Create a Gradio interface with a text input and a slider for temperature
54
  interface = gr.Interface(
55
  fn=generate_text,
56
  inputs=[
57
+ gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
58
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.3, label="Temperature", step=0.1),
59
  ],
60
+ outputs=gr.Textbox(label="Generated Text"),
61
  live=False,
62
+ description="TopicGPT processes the input and returns a reasonably accurate guess of the topic/theme of a given conversation.",
63
+ title="TopicGPT: Theme & Topic Guessing",
64
+ layout="vertical",
65
+ theme="huggingface",
66
+ css=custom_css,
67
  )
68
 
69
  interface.launch()