skylersterling commited on
Commit
c95822b
·
verified ·
1 Parent(s): a01c238

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

4
# Checkpoint identifier on the Hugging Face Hub.
model_name = "skylersterling/TopicGPT"

# Load the tokenizer and causal-LM weights once at import time so that
# every generation request reuses the same in-memory model.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
9
# Greedy token-by-token generation with the TopicGPT prompt format.
def generate_text(context, max_tokens):
    """Generate a topic continuation for *context* via greedy decoding.

    Parameters
    ----------
    context : str
        The passage to derive a topic from. It is wrapped in the model's
        ``#CONTEXT# ... #TOPIC#`` prompt template.
    max_tokens : int | float
        Upper bound on generated tokens. Gradio sliders may deliver a
        float, so the value is coerced to ``int`` before looping.

    Returns
    -------
    str
        The decoded prompt plus the generated continuation.
    """
    input_text = f"#CONTEXT# {context} #TOPIC#"
    input_ids = tokenizer.encode(input_text, return_tensors='pt')

    # "#TOPIC#" doubles as the stop marker; its first sub-token id is the
    # sentinel we watch for. Hoisted out of the loop — it is loop-invariant.
    stop_token_id = tokenizer.encode("#TOPIC#", add_special_tokens=False)[0]

    generated_ids = input_ids
    # Inference only: no_grad avoids building an autograd graph per step,
    # which would otherwise grow memory with every generated token.
    with torch.no_grad():
        for _ in range(int(max_tokens)):
            outputs = model(generated_ids)
            # Greedy decoding: take the argmax over the last position's logits.
            next_token_id = outputs.logits[:, -1, :].argmax(dim=-1)
            generated_ids = torch.cat([generated_ids, next_token_id.unsqueeze(0)], dim=1)
            # .item() makes this an int-vs-int comparison rather than a
            # one-element tensor truthiness check.
            if next_token_id.item() == stop_token_id:
                break

    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text
26
# Build and launch the Gradio UI.
def gradio_interface():
    """Construct the Gradio Interface around ``generate_text`` and launch it.

    Blocks until the Gradio server is stopped. Returns ``None``.
    """
    # NOTE: gr.inputs.* / gr.outputs.* and Slider's `default=` keyword were
    # removed in Gradio 3+; modern versions expose components at the top
    # level and use `value=` for the initial setting.
    context_input = gr.Textbox(lines=5, placeholder="Enter the context here...")
    max_tokens_input = gr.Slider(minimum=1, maximum=200, value=50, step=1)
    output_textbox = gr.Textbox()

    interface = gr.Interface(
        fn=generate_text,
        inputs=[context_input, max_tokens_input],
        outputs=output_textbox,
        title="TopicGPT Text Generation",
        description="Generate text token-by-token using the TopicGPT model. The input should start with #CONTEXT# and end with #TOPIC#."
    )

    interface.launch()
42
# Script entry point: start the web UI when run directly (not on import).
if __name__ == "__main__":
    gradio_interface()