Spaces:
Running
Running
import os

import gradio as gr
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from openai import OpenAI

from translate_utils import translate_ko_to_en
# The OpenAI API key is read from the environment (e.g. a Hugging Face Space
# secret) instead of being committed to source control.
# NOTE(security): a key literal was previously hard-coded on this line; treat
# that key as leaked and revoke it in the OpenAI dashboard.
# If the variable is unset this is None, and the OpenAI client falls back to
# its own OPENAI_API_KEY lookup (raising if that is missing too).
YOUR_OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# Load the sentence-embedding model; it is handed to FAISS below so that
# queries are embedded the same way as the stored documents.
embedding_model_name = "snunlp/KR-SBERT-V40K-klueNLI-augSTS"
embedding_model = HuggingFaceEmbeddings(
    model_name=embedding_model_name,
    model_kwargs={"device": "cpu"},  # change to "cuda" to run on a GPU
    encode_kwargs={"normalize_embeddings": True},
)

# Load the prebuilt vector DB from disk.
save_path = "./version-2024-12-22"
# allow_dangerous_deserialization opts in to pickle-based loading of the
# stored docstore. That is only acceptable when save_path is trusted (it
# appears to be a directory bundled with this app -- confirm); never point it
# at files from an untrusted source.
vectorstore = FAISS.load_local(
    save_path, embedding_model, allow_dangerous_deserialization=True
)
def chatbot(input_question, eng_trans=True, num_ref=3): | |
"""์ฑ๋ด ํจ์""" | |
retriever = vectorstore.as_retriever(search_kwargs={"k": num_ref}) | |
# ํ๊ตญ์ด๋ก ์ง์ | |
if not eng_trans: | |
basic_docs = retriever.invoke(input_question) | |
# ์์ด๋ก ๋ฒ์ญ ํ ์ง์(์์ด, ํ๊ธ ๋ ๋ค ๊ฒ์) | |
else: | |
eng = translate_ko_to_en(input_question) | |
basic_docs = retriever.invoke(input_question) | |
eng_docs = retriever.invoke(eng) | |
basic_docs = basic_docs + eng_docs | |
context = "\n".join([doc.page_content for doc in basic_docs]) | |
client = OpenAI(api_key=YOUR_OPENAI_API_KEY) | |
# GPT-4 or GPT-4o-mini ๋ฑ ๋ชจ๋ธ ์ง์ | |
completion = client.chat.completions.create( | |
model="gpt-4o-mini", | |
messages=[ | |
{ | |
"role": "user", | |
"content": f"""๋น์ ์ ๋ฐ๋์ฒด์ ์ฐจ์ธ๋๋ฐ๋์ฒดํ๊ณผ์ ๋ํด์ ์ค๋ช ํ๋ Assistant์ ๋๋ค. | |
์ฐจ์ธ๋๋ฐ๋์ฒดํ๊ณผ๋ ํ๊ตญ ์์ธ์ ์ค์๋ํ๊ต์ ์ฐฝ์ICT๊ณต๊ณผ๋ํ์ ์ค๋ฆฝ๋ ํ๊ณผ์ ๋๋ค. ์ด ํ๊ณผ๋ ์์ธ๋ํ๊ต, ์ค์๋ํ๊ต, ํฌํญ๊ณต๊ณผ๋ํ๊ต, ์ญ์ค๋ํ๊ต, ๊ฐ์๋ํ๊ต, ๋๊ตฌ๋ํ๊ต, ์กฐ์ ์ด๊ณต๋ํ๊ต๊ฐ ํ์ ๊ต๋ฅ๋ฅผ ํตํด ์๊ฐํ ์ ์๋ ํ๊ณผ์ ๋๋ค. | |
๋ค์ ๋งฅ๋ฝ์ ๋ง๊ฒ ์ง๋ฌธ์ ํ๊ธ๋ก ๋ตํ์ธ์. ๋์ , ๋ฐ๋์ฒด ์ ๋ฌธ์ฉ์ด๋ ์์ด๋ก ํด๋ ๋ฉ๋๋ค. | |
๋งฅ๋ฝ: {context} | |
์ง๋ฌธ: {input_question} | |
""" | |
} | |
] | |
) | |
return completion.choices[0].message.content | |
# UI built with the Gradio Blocks layout API.
with gr.Blocks() as demo:
    # Header image at the top of the page.
    gr.Image(
        value="head.png",  # image path or URL
        elem_id="top-image",
        label=None,
        show_label=False,  # label=None alone can still render an empty label
    )

    # Usage notes. Kept flush-left inside the string so Markdown does not
    # treat indented lines as a code block.
    gr.Markdown(
        """
# 차세대반도체학과 특화 챗봇
- 이 챗봇은 한국어 질의에 대해 반도체 관련 정보, 차세대반도체학과 관련 정보를 친절하게 제공합니다.<br>
- 반도체 전문용어는 일부 영어로 답변될 수 있습니다.<br>
### 영어 번역 사용 여부 기능
- 영어 번역 사용 여부를 켜면, 정확도가 상승될 수 있으나, 추론 시간이 길어질 수 있습니다. <br>
### 검색 문서 개수 변경 기능
- 검색할 문서 개수를 늘리면 정확도가 상승될 수 있으나, 추론 시간이 길어질 수 있습니다. <br>
""",
        elem_id="description",
    )

    # Main UI.
    with gr.Group():
        with gr.Row():
            # Input column: question, options, and submit button.
            with gr.Column():
                input_question = gr.Textbox(
                    label="질문 입력",
                    placeholder="반도체와 차세대반도체학과에 대해 궁금한 점을 입력하세요.",
                )
                eng_trans = gr.Checkbox(
                    label="영어 번역 사용 여부",
                    value=True,
                )
                num_ref = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=3,
                    step=1,
                    label="검색할 문서 개수",
                )
                submit_btn = gr.Button("질문하기")
            # Output column: the model's answer.
            with gr.Column():
                output_answer = gr.Textbox(
                    label="답변",
                    placeholder="답변이 여기에 표시됩니다...",
                    lines=10,
                )

    # Connect the handler to both the button and Enter in the question box.
    qa_inputs = [input_question, eng_trans, num_ref]
    submit_btn.click(fn=chatbot, inputs=qa_inputs, outputs=output_answer)
    input_question.submit(fn=chatbot, inputs=qa_inputs, outputs=output_answer)

# Only start the server when run as a script, so importing this module
# (tests, `gradio` reload mode) does not block on launch.
if __name__ == "__main__":
    demo.launch()