ctn8176 commited on
Commit
d903275
·
verified ·
1 Parent(s): 78a066f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -13
app.py CHANGED
@@ -1,17 +1,13 @@
1
- from transformers import pipeline, Conversation
 
2
  import gradio as gr
3
 
4
- chatbot = pipeline(model="Writer/palmyra-small")
 
 
 
5
 
6
- message_list = []
7
- response_list = []
8
 
9
- def vanilla_chatbot(message, history):
10
- conversation = Conversation(text=message, past_user_inputs=message_list, generated_responses=response_list)
11
- conversation = chatbot(conversation)
12
-
13
- return conversation.generated_responses[-1]
14
-
15
- demo_chatbot = gr.ChatInterface(vanilla_chatbot, title="Writer - Palmyra-small Chatbot", description="Enter text to start chatting.")
16
-
17
- demo_chatbot.launch()
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
3
  import gradio as gr
4
 
5
+ model_name = "Writer/palmyra-small"
6
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+ model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
9
 
10
+ text_generator = TextGenerationPipeline(model=model, tokenizer=tokenizer)
 
11
 
12
+ iface = gr.Interface(fn=text_generator, inputs="text", outputs="text")
13
+ iface.launch()