ctn8176 committed
Commit 0a2c880 · verified · 1 Parent(s): d903275

Update app.py

Files changed (1): app.py (+25 -4)
app.py CHANGED
@@ -1,5 +1,5 @@
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import gradio as gr
 
 model_name = "Writer/palmyra-small"
@@ -7,7 +7,28 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
 
-text_generator = TextGenerationPipeline(model=model, tokenizer=tokenizer)
+def generate_response(prompt):
+    input_text_template = (
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions. "
+        f"USER: {prompt} "
+        "ASSISTANT:"
+    )
 
-iface = gr.Interface(fn=text_generator, inputs="text", outputs="text")
-iface.launch()
+    model_inputs = tokenizer(input_text_template, return_tensors="pt").to(device)
+
+    gen_conf = {
+        "top_k": 20,
+        "max_length": 200,
+        "temperature": 0.6,
+        "do_sample": True,
+        "eos_token_id": tokenizer.eos_token_id,
+    }
+
+    output = model.generate(**model_inputs, **gen_conf)
+
+    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+    return generated_text
+
+iface = gr.Interface(fn=generate_response, inputs="text", outputs="text")
+iface.launch()
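
A minimal usage sketch, not part of the commit: because tokenizer.decode(output[0], skip_special_tokens=True) on a causal LM returns the prompt followed by the completion, a caller may want to strip everything up to the "ASSISTANT:" marker that generate_response() builds into its template. The example prompt below is hypothetical, and generate_response is assumed to be in scope as defined in the updated app.py.

# Assumes generate_response() from the updated app.py is defined in this session.
# A causal LM echoes the prompt, so split on the "ASSISTANT:" marker
# (appended by generate_response's template) to keep only the model's answer.
full_text = generate_response("What does the transformers library do?")  # hypothetical prompt
answer = full_text.split("ASSISTANT:", 1)[-1].strip()
print(answer)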