File size: 4,033 Bytes
587c522
ba964ad
dfabed7
d802f7b
e0f269f
d802f7b
e0f269f
07b2dc9
c104d72
e0f269f
 
245e246
e0f269f
d802f7b
c56e053
62319f5
2bd39e5
587c522
 
 
dfabed7
587c522
dfabed7
587c522
 
c56e053
62319f5
c56e053
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
587c522
 
 
 
 
 
 
 
 
 
 
 
 
68d8ac5
587c522
 
 
c104d72
587c522
68d8ac5
62319f5
3d50ef0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import subprocess
import gradio as gr
import json
from tqdm import tqdm
from langchain_community.vectorstores import FAISS
from langchain_google_genai import GoogleGenerativeAIEmbeddings
import google.generativeai as genai
# from playwright._impl._driver import get_driver_dir

from helpers import (
    list_docx_files, get_splits, get_json_splits_only, prompt_order, log_message
)

from file_loader import get_vectorstore
# import asyncio

# SECURITY NOTE(review): a Google API key is committed to source control here.
# It should be revoked and supplied only via the GOOGLE_API_KEY environment
# variable; the literal is kept solely as a backward-compatible fallback.
_FALLBACK_GOOGLE_API_KEY = "AIzaSyDJ4vIKuIBIPNHATLxnoHlagXWbsAz-vRs"

# Prefer a non-empty key from the environment, otherwise use the fallback.
# Previously `key` was always the literal, even when GOOGLE_API_KEY was set,
# so genai and clients reading the environment could use different keys.
key = os.environ.get("GOOGLE_API_KEY") or _FALLBACK_GOOGLE_API_KEY
# Export the resolved key so anything reading GOOGLE_API_KEY sees the same value.
os.environ["GOOGLE_API_KEY"] = key

# Configure the Google GenAI client with the resolved key.
genai.configure(api_key=key)

# Vector store queried by augment_prompt for retrieval (built by file_loader).
vectorstore = get_vectorstore()
def augment_prompt(query: str, k: int = 10) -> str:
    """Build a context prompt from documents retrieved for *query*.

    Documents are fetched with
    ``vectorstore.as_retriever(search_kwargs={"k": k}).invoke(query)`` using the
    module-level ``vectorstore``. Their ``page_content`` values are joined with
    blank lines between them.

    Args:
        query: The user's question, used as the retrieval query.
        k: Value passed as ``search_kwargs["k"]`` to the retriever (number of
            documents to retrieve).

    Returns:
        A prompt string containing the retrieved contexts. If retrieval returns
        nothing, a fixed "No relevant context found." message is returned
        instead. The query itself is not embedded in the returned text.
    """
    retriever = vectorstore.as_retriever(search_kwargs={"k": k})
    results = retriever.invoke(query)

    if not results:
        # NOTE(review): the trailing "\n." looks accidental; kept unchanged so
        # the prompt sent to the model does not change.
        return "No relevant context found.\n."

    source_knowledge = "\n\n".join(doc.page_content for doc in results)
    return f"""Using the contexts below, answer the query.
Contexts:
{source_knowledge}
"""

def get_answer(query, queries_list=None):
    """Stream a Gemini answer to *query*, grounded in retrieved contexts.

    This is used as the ``fn`` of ``gr.ChatInterface`` below. Each yielded
    value is the full response text accumulated so far.

    Args:
        query: The user's latest message.
        queries_list: Prior items supplied by the caller. Under
            ``gr.ChatInterface`` this is presumably the chat history (its
            second positional argument); confirm ``prompt_order`` handles that
            format. The list is copied, never mutated. ``query`` is appended
            to the copy, which is then formatted by ``prompt_order``.

    Yields:
        The cumulative response text after each streamed chunk.

    After the stream is fully consumed, the instruction/context messages plus
    the model's reply are passed to ``log_message``.
    """
    messages = [
        {"role": "user", "parts": [{"text": "IMPORTANT: You are a super energetic, helpful, polite, Vietnamese-speaking assistant. If you can not see the answer in contexts, try to search it up online by yourself but remember to give the source."}]},
        {"role": "user", "parts": [{"text": augment_prompt(query)}]},
    ]
    # NOTE(review): disabled extra-sources footer, kept for reference. The email
    # address below was mangled (likely by email obfuscation when the file was
    # copied); restore it before re-enabling.
    #     bonus = '''
    # Bạn tham kháo thêm các nguồn thông tin tại:
    # Trang thông tin điện tử: https://neu.edu.vn ; https://daotao.neu.edu.vn
    # Trang mạng xã hội có thông tin tuyển sinh: https://www.facebook.com/ktqdNEU ; https://www.facebook.com/tvtsneu ;
    # Email tuyển sinh: [email protected]
    # Số điện thoại tuyển sinh: 0888.128.558
    #   '''

    # Copy instead of appending in place. The original mutated the caller's
    # list (Gradio's history object) on every request.
    queries_list = [] if queries_list is None else list(queries_list)
    queries_list.append(query)
    queries_message = {"role": "user", "parts": [{"text": prompt_order(queries_list)}]}

    # The API key is configured once at module import; no per-request
    # genai.configure call is needed.
    model = genai.GenerativeModel("gemini-2.0-flash")
    response = model.generate_content(contents=[*messages, queries_message], stream=True)

    response_text = ""
    for chunk in response:
        response_text += chunk.text
        # Yield the running total: the chat UI replaces the displayed message
        # with each value rather than appending to it.
        yield response_text

    messages.append({"role": "model", "parts": [{"text": response_text}]})
    # NOTE(review): `queries_message` (the user's question) is not part of
    # `messages`, so the log contains instructions, contexts and the reply only.
    log_message(messages)

# Dropdown options; the first entry in each list means "all".
institutions = ['Tất cả', 'Trường Công Nghệ']
categories = ['Tất cả', 'Đề án', 'Chương trình đào tạo']

# Page layout: a row with two filter dropdowns, then the streaming chat.
# NOTE(review): the dropdowns are not passed to the chat interface (no
# additional_inputs), so their selections do not affect answers.
with gr.Blocks() as demo:
    with gr.Row():
        category1 = gr.Dropdown(label="Trường", choices=institutions, value='Tất cả')
        category2 = gr.Dropdown(label="Bạn quan tâm tới", choices=categories, value='Tất cả')

    chat_interface = gr.ChatInterface(
        get_answer,
        type="messages",
        textbox=gr.Textbox(
            placeholder="Đặt câu hỏi tại đây",
            scale=7,
            autoscroll=True,
            container=False,
        ),
    )

if __name__ == "__main__":
    demo.launch()