skylersterling committed (verified)
Commit b83610f · 1 Parent(s): 5e2cdce

Update app.py

Files changed (1)
  1. app.py +29 -7
app.py CHANGED
@@ -2,6 +2,7 @@
 import gradio as gr
 import transformers
 from transformers import GPT2Tokenizer, GPT2LMHeadModel
+import torch
 import os
 
 HF_TOKEN = os.environ.get("HF_TOKEN")
@@ -12,12 +13,33 @@ model = GPT2LMHeadModel.from_pretrained('skylersterling/TopicGPT', use_auth_toke
 model.eval()
 
 # Define the function that generates text from a prompt
-def generate_text(prompt):
-    input_ids = tokenizer.encode(prompt, return_tensors='pt')
-    output = model.generate(input_ids, max_new_tokens=80, do_sample=True)
-    text = tokenizer.decode(output[0], skip_special_tokens=True)
-    return text
+def generate_text(prompt, temperature, top_p):
+    input_ids = tokenizer.encode(prompt, return_tensors='pt')
+    generated_text = prompt
+    model.eval()
+
+    with torch.no_grad():
+        for _ in range(80):  # Generate up to 80 tokens
+            outputs = model(input_ids)
+            next_token_logits = outputs.logits[:, -1, :] / temperature
+            filtered_logits = transformers.TopPLogitsWarper(top_p)(input_ids, next_token_logits)
+            next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
+            input_ids = torch.cat([input_ids, next_token], dim=-1)
+            generated_text += tokenizer.decode(next_token[0], skip_special_tokens=True)
+            if tokenizer.decode(next_token[0], skip_special_tokens=True).strip() == tokenizer.eos_token:
+                break
+
+    return generated_text
+
+# Create a gradio interface with a text input, sliders for temperature and top-p, and a text output
+interface = gr.Interface(
+    fn=generate_text,
+    inputs=[
+        gr.inputs.Textbox(lines=2, placeholder='Enter your prompt here...'),
+        gr.inputs.Slider(minimum=0.1, maximum=1.0, default=1.0, label='Temperature'),
+        gr.inputs.Slider(minimum=0.0, maximum=1.0, default=0.9, label='Top-p')
+    ],
+    outputs='text'
+)
 
-# Create a gradio interface with a text input and a text output
-interface = gr.Interface(fn=generate_text, inputs='text', outputs='text')
 interface.launch()
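
For reference, a minimal sketch (not part of the commit) of how the same temperature and top-p controls map onto the built-in sampling path of transformers' generate(). It assumes the tokenizer and model objects already loaded in app.py; the function name generate_text_builtin is made up for the sketch.

import torch

def generate_text_builtin(prompt, temperature, top_p):
    # Encode the prompt the same way generate_text does (assumes `tokenizer`
    # and `model` are the objects loaded earlier in app.py)
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=80,        # same token budget as the manual loop
            do_sample=True,           # sample instead of greedy decoding
            temperature=temperature,  # scales the next-token logits
            top_p=top_p,              # nucleus (top-p) filtering
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 defines no pad token
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)

Wiring this into the interface would only mean passing fn=generate_text_builtin instead of fn=generate_text; the textbox and the two sliders stay the same.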