# app.py, uploaded by Yuchan5386.
# Revision e069dc2 (verified), commit message "Update app.py", file size 1.45 kB.
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Identifier of the pretrained checkpoint handed to ``from_pretrained`` below.
MODEL_NAME = "lcw99/ko-dialoGPT-korean-chit-chat"

# Load the causal language model and its tokenizer once, at module level;
# chat_with_ai reads both of these globals on every call.
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_NAME)
# ์ฑ—๋ด‡ ์‘๋‹ต ํ•จ์ˆ˜
def chat_with_ai(history, message):
input_text = message + tokenizer.eos_token
input_ids = tokenizer.encode(input_text, return_tensors="pt")
response_ids = model.generate(input_ids, max_length=100, pad_token_id=tokenizer.eos_token_id)
response_text = tokenizer.decode(response_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
history.append((message, response_text)) # Gradio Chatbot ํ˜•์‹ ์œ ์ง€
return history, "" # ์ž…๋ ฅ์ฐฝ ๋น„์šฐ๊ธฐ
# Gradio ์ธํ„ฐํŽ˜์ด์Šค ์„ค์ •
with gr.Blocks() as demo:
gr.Markdown("## ๐Ÿ—จ๏ธ Ko-DialoGPT Chatbot")
chatbot = gr.Chatbot(label="Ko-DialoGPT Chatbot")
message = gr.Textbox(label="์ž…๋ ฅ ๋ฉ”์‹œ์ง€", placeholder="๋ฉ”์‹œ์ง€๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”...")
clear_btn = gr.Button("์ดˆ๊ธฐํ™”")
# ๋ฉ”์‹œ์ง€ ์ž…๋ ฅ ์‹œ ์ฑ„ํŒ… ๊ธฐ๋ก ์—…๋ฐ์ดํŠธ ๋ฐ ์ž…๋ ฅ์ฐฝ ์ดˆ๊ธฐํ™”
message.submit(chat_with_ai, [chatbot, message], [chatbot, message])
# ์ดˆ๊ธฐํ™” ๋ฒ„ํŠผ ํด๋ฆญ ์‹œ ์ฑ„ํŒ… ๊ธฐ๋ก ์‚ญ์ œ
clear_btn.click(lambda: [], [], chatbot)
# โœ… ์„œ๋ฒ„ ํฌํŠธ ๋ฐ ์ฃผ์†Œ ์ถ”๊ฐ€
demo.launch(server_name="0.0.0.0", server_port=7860)