Spaces:
Sleeping
Sleeping
import os | |
import subprocess | |
import gradio as gr | |
import json | |
from tqdm import tqdm | |
from langchain_community.vectorstores import FAISS | |
from langchain_google_genai import GoogleGenerativeAIEmbeddings | |
import google.generativeai as genai | |
# from playwright._impl._driver import get_driver_dir | |
from helpers import ( | |
list_docx_files, get_splits, get_json_splits_only, prompt_order, log_message | |
) | |
from file_loader import get_vectorstore | |
# import asyncio | |
# Google GenAI credentials.
#
# The key is taken from the environment only. A credential committed to the
# source file is readable by anyone who can see the file, and the previous
# code also ignored an externally supplied GOOGLE_API_KEY when configuring
# genai because `key` was always the embedded literal.
# NOTE(review): the key that used to be embedded here should be revoked/rotated.
key = os.environ.get("GOOGLE_API_KEY")
if not key:
    raise RuntimeError(
        "The GOOGLE_API_KEY environment variable is not set; "
        "set it before starting the application."
    )

# Configure the API key for Google GenAI.
genai.configure(api_key=key)

# Shared vector store queried by augment_prompt().
vectorstore = get_vectorstore()
# Build the context-augmented prompt for a query.
def augment_prompt(query: str, k: int = 10) -> str:
    """Return a prompt embedding the passages retrieved for *query*.

    Args:
        query: Text used to search the module-level ``vectorstore``.
        k: Number of documents requested from the retriever.

    Returns:
        A prompt string containing the ``page_content`` of every retrieved
        document separated by blank lines, or a short fallback message when
        the retriever returns nothing.
    """
    retriever = vectorstore.as_retriever(search_kwargs={"k": k})
    results = retriever.invoke(query)

    # Guard clause: nothing retrieved, so there is no context to embed.
    if not results:
        return "No relevant context found.\n."

    source_knowledge = "\n\n".join(doc.page_content for doc in results)
    # NOTE(review): the template lines are kept flush-left, matching the view
    # of the source this was reconstructed from — confirm the intended
    # whitespace inside the prompt.
    return f"""Using the contexts below, answer the query.
Contexts:
{source_knowledge}
"""
def get_answer(query, queries_list=None):
    """Stream a Gemini answer to *query*, yielding the text accumulated so far.

    Args:
        query: The user's latest message.
        queries_list: Earlier entries handed to ``prompt_order`` together
            with *query*; ``None`` means there are none. The argument is
            never modified.
            NOTE(review): when this function is driven by ``gr.ChatInterface``
            the second positional argument is presumably the chat history —
            confirm that this is what ``prompt_order`` expects.

    Yields:
        The cumulative response text, once for every streamed chunk that
        carries text.

    After the stream is exhausted, the instruction message, the context
    message and the model reply are passed to ``log_message`` (the
    ``prompt_order`` message is sent to the model but not logged).
    """
    # Copy instead of appending in place so a list owned by the caller
    # (for example UI-managed history) is not altered by this call.
    ordered_queries = list(queries_list) if queries_list is not None else []

    messages = [
        {"role": "user", "parts": [{"text": "IMPORTANT: You are a super energetic, helpful, polite, Vietnamese-speaking assistant. If you can not see the answer in contexts, try to search it up online by yourself but remember to give the source."}]},
        {"role": "user", "parts": [{"text": augment_prompt(query)}]},
    ]

    ordered_queries.append(query)
    queries = {"role": "user", "parts": [{"text": prompt_order(ordered_queries)}]}
    # The query-history message goes to the model only; `messages` itself is
    # what gets logged at the end.
    messages_with_queries = [*messages, queries]

    # Configure API key
    genai.configure(api_key=key)
    # Initialize the Gemini model
    model = genai.GenerativeModel("gemini-2.0-flash")
    response = model.generate_content(contents=messages_with_queries, stream=True)

    response_text = ""
    for chunk in response:
        try:
            piece = chunk.text
        except ValueError:
            # The `.text` accessor raises ValueError when a streamed chunk
            # holds no text part (e.g. a blocked or finish-only chunk);
            # skip it rather than aborting the whole answer.
            continue
        response_text += piece
        yield response_text

    messages.append({"role": "model", "parts": [{"text": response_text}]})
    log_message(messages)
# Dropdown options; the first entry of each list is the default selection.
institutions = ['Tất cả', 'Trường Công Nghệ']
categories = ['Tất cả', 'Đề án', 'Chương trình đào tạo']

# NOTE(review): nesting below is reconstructed (dropdowns inside the row,
# chat interface directly under the Blocks root) — confirm against the
# deployed layout.
with gr.Blocks() as demo:
    with gr.Row():
        # These dropdowns are displayed but not forwarded to get_answer:
        # no additional_inputs are wired into the chat interface.
        category1 = gr.Dropdown(
            label="Trường",
            choices=institutions,
            value=institutions[0],
        )
        category2 = gr.Dropdown(
            label="Bạn quan tâm tới",
            choices=categories,
            value=categories[0],
        )

    # Input box handed to the chat interface.
    question_box = gr.Textbox(
        placeholder="Đặt câu hỏi tại đây",
        container=False,
        autoscroll=True,
        scale=7,
    )
    chat_interface = gr.ChatInterface(
        get_answer,
        textbox=question_box,
        type="messages",
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()