Commit
·
96b116a
1
Parent(s):
ed0720f
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__import__('pysqlite3')
|
2 |
+
import sys
|
3 |
+
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
|
4 |
+
import chromadb
|
5 |
+
import torch
|
6 |
+
import gradio as gr
|
7 |
+
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline
|
8 |
+
|
9 |
+
from langchain.llms import OpenAI, GigaChat
|
10 |
+
from langchain.chains import LLMChain
|
11 |
+
from langchain.prompts import PromptTemplate
|
12 |
+
|
13 |
+
chatgpt = OpenAI(
|
14 |
+
api_key='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',
|
15 |
+
)
|
16 |
+
|
17 |
+
gigachat = GigaChat(
|
18 |
+
credentials='Y2Y4Yjk5ODUtNThmMC00ODdjLTk5ODItNDdmYzhmNDdmNzE0OjQ5Y2RjNTVkLWFmMGQtNGJlYy04OGNiLTI1Yzc3MmJkMzYwYw==',
|
19 |
+
scope='GIGACHAT_API_PERS',
|
20 |
+
verify_ssl_certs=False
|
21 |
+
)
|
22 |
+
|
23 |
+
llms = {
|
24 |
+
'ChatGPT': chatgpt,
|
25 |
+
'GigaChat': gigachat,
|
26 |
+
}
|
27 |
+
|
28 |
+
# задаем формат вывода модели
|
29 |
+
answer_task_types = {
|
30 |
+
'Развернутый ответ': 'Ответь достаточно подробно, но не используй ничего лишнего.',
|
31 |
+
'Только цифры штрафа': 'Ответь в виде <количество> рублей или <диапазон> рублей, и больше ничего не пиши.'
|
32 |
+
}
|
33 |
+
|
34 |
+
# проверяем с помощью LLM валидность запроса, исключая обработку бессмысленного входа
|
35 |
+
validity_template = '{query}\n\nЭто валидный запрос? Ответь да или нет, больше ничего не пиши.'
|
36 |
+
validity_prompt = PromptTemplate(template=validity_template, input_variables=['query'])
|
37 |
+
|
38 |
+
# получаем ответ модели на запрос, используем его для более качественного поиска Retriever'ом и Cross-Encoder'ом
|
39 |
+
query_template = '{query} Ответь текстом, похожим на закон, не пиши ничего лишнего. Не используй в ответе слово КоАП РФ. Не используй слово "Россия".'
|
40 |
+
query_prompt = PromptTemplate(template=query_template, input_variables=['query'])
|
41 |
+
|
42 |
+
# просим LLM выбрать один из 3 фрагментов текста, выбранных поисковыми моделями, где по мнению модели есть ответ. Если ответа нет, модель нам об этом сообщает
|
43 |
+
choose_answer_template = '1. {text_1}\n\n2. {text_2}\n\n3. {text_3}\n\nЗадание: выбери из перечисленных выше отрывков тот, где есть ответ на вопрос: "{query}". В качестве ответа напиши только номер 1, 2 или 3 и все. Если в данных отрывках нет ответа, то напиши "Нет ответа".'
|
44 |
+
choose_answer_prompt = PromptTemplate(template=choose_answer_template, input_variables=['text_1', 'text_2', 'text_3', 'query'])
|
45 |
+
|
46 |
+
# просим LLM ответить на вопрос, опираясь на найденный фрагмент, и в нужном формате, или сообщить, что ответа все-таки нет
|
47 |
+
answer_template = '{text}\n\nЗадание: ответь на вопрос по тексту: "{query}". {answer_type} Если в данном тексте нет ответа, то напиши "Нет ответа".'
|
48 |
+
answer_prompt = PromptTemplate(template=answer_template, input_variables=['text', 'query', 'answer_type'])
|
49 |
+
|
50 |
+
client = chromadb.PersistentClient(path='db')
|
51 |
+
collection = client.get_collection(name="administrative_codex")
|
52 |
+
|
53 |
+
retriever_checkpoint = 'sentence-transformers/LaBSE'
|
54 |
+
retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_checkpoint)
|
55 |
+
retriever_model = AutoModel.from_pretrained(retriever_checkpoint)
|
56 |
+
|
57 |
+
cross_encoder_checkpoint = 'jeffwan/mmarco-mMiniLMv2-L12-H384-v1'
|
58 |
+
cross_encoder_model = AutoModelForSequenceClassification.from_pretrained(cross_encoder_checkpoint)
|
59 |
+
cross_encoder_tokenizer = AutoTokenizer.from_pretrained(cross_encoder_checkpoint)
|
60 |
+
cross_encoder = pipeline('text-classification', model=cross_encoder_model, tokenizer=cross_encoder_tokenizer)
|
61 |
+
|
62 |
+
|
63 |
+
def encode(docs):
|
64 |
+
if type(docs) == str:
|
65 |
+
docs = [docs]
|
66 |
+
|
67 |
+
encoded_input = retriever_tokenizer(
|
68 |
+
docs,
|
69 |
+
padding=True,
|
70 |
+
truncation=True,
|
71 |
+
max_length=512,
|
72 |
+
return_tensors='pt'
|
73 |
+
)
|
74 |
+
|
75 |
+
with torch.no_grad():
|
76 |
+
model_output = retriever_model(**encoded_input)
|
77 |
+
|
78 |
+
embeddings = model_output.pooler_output
|
79 |
+
embeddings = torch.nn.functional.normalize(embeddings)
|
80 |
+
return embeddings.detach().cpu().tolist()
|
81 |
+
|
82 |
+
|
83 |
+
def re_rank(sentence, docs):
|
84 |
+
return [res['score'] for res in cross_encoder([{'text': sentence, 'text_pair': doc} for doc in docs], max_length=512, truncation=True)]
|
85 |
+
|
86 |
+
|
87 |
+
def update_query_with_llm(query, llm_type, use_llm_for_retriever):
|
88 |
+
if llm_type == 'Без LLM' or not use_llm_for_retriever:
|
89 |
+
return query
|
90 |
+
|
91 |
+
llm_chain = LLMChain(prompt=query_prompt, llm=llms[llm_type])
|
92 |
+
return f'{query} {llm_chain.run(query).strip()}'
|
93 |
+
|
94 |
+
|
95 |
+
def answer_with_llm(query, re_ranked_res, llm_type, llm_answer_type):
|
96 |
+
if llm_type == 'Без LLM':
|
97 |
+
answer, metadata, re_ranker_score = re_ranked_res[0]
|
98 |
+
else:
|
99 |
+
llm_chain = LLMChain(prompt=choose_answer_prompt, llm=llms[llm_type])
|
100 |
+
llm_chain_dict = {f'text_{i}': res[0] for i, res in enumerate(re_ranked_res, start=1)}
|
101 |
+
llm_chain_dict['query'] = query
|
102 |
+
|
103 |
+
llm_res = llm_chain.run(llm_chain_dict).strip()
|
104 |
+
|
105 |
+
if 'нет ответа' in llm_res.lower() or not llm_res[0].isnumeric():
|
106 |
+
return 'Нет ответа', '', ''
|
107 |
+
|
108 |
+
most_suitable_text, metadata, re_ranker_score = re_ranked_res[int(llm_res[0]) - 1]
|
109 |
+
|
110 |
+
llm_chain = LLMChain(prompt=answer_prompt, llm=llms[llm_type])
|
111 |
+
answer = llm_chain.run({'text': most_suitable_text, 'query': query, 'answer_type': llm_answer_type}).strip()
|
112 |
+
|
113 |
+
if 'нет ответа' in answer.lower():
|
114 |
+
answer = 'Нет ответа'
|
115 |
+
|
116 |
+
# если LLM сначала выбрала фрагмент, где есть ответ, а потом не смогла ответить на вопрос (что бывает редко), то все равно порекомендуем пользователю обратиться к норме
|
117 |
+
law_norm = f"{'Попробуйте обратиться к этому источнику: ' if answer == 'Нет ответа' else ''}{metadata['article']} {metadata['point']} {metadata['doc']}"
|
118 |
+
return answer, law_norm, re_ranker_score
|
119 |
+
|
120 |
+
|
121 |
+
def check_request_validity(func):
|
122 |
+
def wrapper(
|
123 |
+
query,
|
124 |
+
llm_type,
|
125 |
+
llm_answer_type,
|
126 |
+
use_llm_for_retriever,
|
127 |
+
use_llm_for_request_validation
|
128 |
+
):
|
129 |
+
query = query.strip()
|
130 |
+
|
131 |
+
if not query:
|
132 |
+
return 'Невалидный запрос', '', ''
|
133 |
+
|
134 |
+
if llm_type == 'Без LLM' or not use_llm_for_request_validation:
|
135 |
+
return func(query, llm_type, llm_answer_type, use_llm_for_retriever)
|
136 |
+
|
137 |
+
llm_chain = LLMChain(prompt=validity_prompt, llm=llms[llm_type])
|
138 |
+
|
139 |
+
if 'нет' in llm_chain.run(query).lower():
|
140 |
+
return 'Невалидный запрос', '', ''
|
141 |
+
|
142 |
+
return func(query, llm_type, llm_answer_type, use_llm_for_retriever)
|
143 |
+
|
144 |
+
return wrapper
|
145 |
+
|
146 |
+
|
147 |
+
@check_request_validity
|
148 |
+
def fn(
|
149 |
+
query,
|
150 |
+
llm_type,
|
151 |
+
llm_answer_type,
|
152 |
+
use_llm_for_retriever
|
153 |
+
):
|
154 |
+
# обогатим запрос с помощью LLM, чтобы поисковым моделям было проще найти нужный фрагмент с ответом
|
155 |
+
retriever_ranker_query = update_query_with_llm(query, llm_type, use_llm_for_retriever)
|
156 |
+
|
157 |
+
# Retriever-поиск по базе данных
|
158 |
+
retriever_res = collection.query(
|
159 |
+
query_embeddings=encode(retriever_ranker_query),
|
160 |
+
n_results=10,
|
161 |
+
)
|
162 |
+
|
163 |
+
top_k_docs = retriever_res['documents'][0]
|
164 |
+
|
165 |
+
# re-ranking с помощью Cross-Encoder'а и отбор лучших кандидатов
|
166 |
+
re_rank_scores = re_rank(retriever_ranker_query, top_k_docs)
|
167 |
+
re_ranked_res = sorted(
|
168 |
+
[[doc, meta, score] for doc, meta, score in zip(retriever_res['documents'][0], retriever_res['metadatas'][0], re_rank_scores)],
|
169 |
+
key=lambda x: x[-1],
|
170 |
+
reverse=True,
|
171 |
+
)[:3]
|
172 |
+
|
173 |
+
# поиск ответа и нормы с помощью LLM
|
174 |
+
return answer_with_llm(query, re_ranked_res, llm_type, llm_answer_type)
|
175 |
+
|
176 |
+
|
177 |
+
demo = gr.Interface(
|
178 |
+
fn=fn,
|
179 |
+
inputs=[
|
180 |
+
gr.Textbox(lines=3, label='Запрос', placeholder='Введите запрос'),
|
181 |
+
gr.Dropdown(label='Тип LLM', choices=['ChatGPT', 'GigaChat', 'Без LLM'], value='ChatGPT'),
|
182 |
+
gr.Dropdown(label='Тип итогового ответа LLM', choices=['Только цифры штрафа', 'Развернутый ответ'], value='Только цифры штрафа'),
|
183 |
+
gr.Checkbox(label="Использовать LLM для Retriever'а", value=True, info="При использовании LLM для Retriever'а ко входному запросу будет добавляться промежуточный ответ LLM на запрос. Это способствует повышению качества поиска ответа."),
|
184 |
+
gr.Checkbox(label="Использовать LLM для проверки валидности запроса", value=False)
|
185 |
+
],
|
186 |
+
outputs=[
|
187 |
+
gr.Textbox(label='Ответ'),
|
188 |
+
gr.Textbox(label='Норма'),
|
189 |
+
gr.Textbox(label="Уверенность Cross-Encoder'а"),
|
190 |
+
],
|
191 |
+
)
|
192 |
+
|
193 |
+
demo.launch()
|