MohamedTalaat91 committed on
Commit
16f0fc5
·
verified ·
1 Parent(s): a87ff18

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -18
app.py CHANGED
@@ -1,18 +1,15 @@
 
1
  from transformers import AutoTokenizer, AutoModelForCausalLM
2
  import gradio as gr
3
 
4
- # Load the model and tokenizer
5
  model = AutoModelForCausalLM.from_pretrained("MohamedTalaat91/gpt2-wikitext2")
6
  tokenizer = AutoTokenizer.from_pretrained("MohamedTalaat91/gpt2-tokenizer")
7
 
8
- # Updated generate function
9
- def generate(messages, state):
10
- # The last message content is a dictionary with "role" and "content"
11
- input_text = messages[-1] # Extracting the content of the last message
12
-
13
- # Tokenize the input
14
  inputs = tokenizer(input_text, return_tensors="pt")
15
-
16
  # Generate text based on the input
17
  generated_ids = model.generate(
18
  inputs['input_ids'],
@@ -22,16 +19,21 @@ def generate(messages, state):
22
  top_k=50, # Top-k sampling to introduce diversity
23
  temperature=0.7 # Controls randomness in sampling
24
  )
25
-
26
- # Decode the generated text
27
  generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
28
-
29
- # Prepare the response in the correct format
30
- bot_message = {"role": "bot", "content": generated_text}
31
- messages.append(bot_message) # Add the bot's message to the conversation
32
-
33
- return messages, state # Return the updated messages and state
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- # Launch Gradio interface with the updated function
37
- gr.ChatInterface(generate, type="messages").launch(share=True)
 
1
+ # Load model directly
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import gradio as gr
4
 
5
+
6
  model = AutoModelForCausalLM.from_pretrained("MohamedTalaat91/gpt2-wikitext2")
7
  tokenizer = AutoTokenizer.from_pretrained("MohamedTalaat91/gpt2-tokenizer")
8
 
9
+
10
+
11
+ def generate(input_text) :
 
 
 
12
  inputs = tokenizer(input_text, return_tensors="pt")
 
13
  # Generate text based on the input
14
  generated_ids = model.generate(
15
  inputs['input_ids'],
 
19
  top_k=50, # Top-k sampling to introduce diversity
20
  temperature=0.7 # Controls randomness in sampling
21
  )
 
 
22
  generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
 
 
 
 
 
 
23
 
24
+ return generated_text
25
+
26
+
27
+ import gradio as gr
28
+
29
+ with gr.Blocks() as demo:
30
+ gr.Markdown("# GPT-2 WikiText2")
31
+ with gr.Row():
32
+ with gr.Column():
33
+ input_text = gr.Textbox(label="Input Text")
34
+ generate_button = gr.Button("Generate")
35
+ output_text = gr.Textbox(label="Generated Text")
36
+
37
+ generate_button.click(fn=generate, inputs=input_text, outputs=output_text)
38
 
39
+ demo.launch(share=True)