VladimirVorobev commited on
Commit
96b116a
·
1 Parent(s): ed0720f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -0
app.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __import__('pysqlite3')
2
+ import sys
3
+ sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
4
+ import chromadb
5
+ import torch
6
+ import gradio as gr
7
+ from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline
8
+
9
+ from langchain.llms import OpenAI, GigaChat
10
+ from langchain.chains import LLMChain
11
+ from langchain.prompts import PromptTemplate
12
+
13
+ chatgpt = OpenAI(
14
+ api_key='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',
15
+ )
16
+
17
+ gigachat = GigaChat(
18
+ credentials='Y2Y4Yjk5ODUtNThmMC00ODdjLTk5ODItNDdmYzhmNDdmNzE0OjQ5Y2RjNTVkLWFmMGQtNGJlYy04OGNiLTI1Yzc3MmJkMzYwYw==',
19
+ scope='GIGACHAT_API_PERS',
20
+ verify_ssl_certs=False
21
+ )
22
+
23
+ llms = {
24
+ 'ChatGPT': chatgpt,
25
+ 'GigaChat': gigachat,
26
+ }
27
+
28
+ # задаем формат вывода модели
29
+ answer_task_types = {
30
+ 'Развернутый ответ': 'Ответь достаточно подробно, но не используй ничего лишнего.',
31
+ 'Только цифры штрафа': 'Ответь в виде <количество> рублей или <диапазон> рублей, и больше ничего не пиши.'
32
+ }
33
+
34
+ # проверяем с помощью LLM валидность запроса, исключая обработку бессмысленного входа
35
+ validity_template = '{query}\n\nЭто валидный запрос? Ответь да или нет, больше ничего не пиши.'
36
+ validity_prompt = PromptTemplate(template=validity_template, input_variables=['query'])
37
+
38
+ # получаем ответ модели на запрос, используем его для более качественного поиска Retriever'ом и Cross-Encoder'ом
39
+ query_template = '{query} Ответь текстом, похожим на закон, не пиши ничего лишнего. Не используй в ответе слово КоАП РФ. Не используй слово "Россия".'
40
+ query_prompt = PromptTemplate(template=query_template, input_variables=['query'])
41
+
42
+ # просим LLM выбрать один из 3 фрагментов текста, выбранных поисковыми моделями, где по мнению модели есть ответ. Если ответа нет, модель нам об этом сообщает
43
+ choose_answer_template = '1. {text_1}\n\n2. {text_2}\n\n3. {text_3}\n\nЗадание: выбери из перечисленных выше отрывков тот, где есть ответ на вопрос: "{query}". В качестве ответа напиши только номер 1, 2 или 3 и все. Если в данных отрывках нет ответа, то напиши "Нет ответа".'
44
+ choose_answer_prompt = PromptTemplate(template=choose_answer_template, input_variables=['text_1', 'text_2', 'text_3', 'query'])
45
+
46
+ # просим LLM ответить на вопрос, опираясь на найденный фрагмент, и в нужном формате, или сообщить, что ответа все-таки нет
47
+ answer_template = '{text}\n\nЗадание: ответь на вопрос по тексту: "{query}". {answer_type} Если в данном тексте нет ответа, то напиши "Нет ответа".'
48
+ answer_prompt = PromptTemplate(template=answer_template, input_variables=['text', 'query', 'answer_type'])
49
+
50
+ client = chromadb.PersistentClient(path='db')
51
+ collection = client.get_collection(name="administrative_codex")
52
+
53
+ retriever_checkpoint = 'sentence-transformers/LaBSE'
54
+ retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_checkpoint)
55
+ retriever_model = AutoModel.from_pretrained(retriever_checkpoint)
56
+
57
+ cross_encoder_checkpoint = 'jeffwan/mmarco-mMiniLMv2-L12-H384-v1'
58
+ cross_encoder_model = AutoModelForSequenceClassification.from_pretrained(cross_encoder_checkpoint)
59
+ cross_encoder_tokenizer = AutoTokenizer.from_pretrained(cross_encoder_checkpoint)
60
+ cross_encoder = pipeline('text-classification', model=cross_encoder_model, tokenizer=cross_encoder_tokenizer)
61
+
62
+
63
+ def encode(docs):
64
+ if type(docs) == str:
65
+ docs = [docs]
66
+
67
+ encoded_input = retriever_tokenizer(
68
+ docs,
69
+ padding=True,
70
+ truncation=True,
71
+ max_length=512,
72
+ return_tensors='pt'
73
+ )
74
+
75
+ with torch.no_grad():
76
+ model_output = retriever_model(**encoded_input)
77
+
78
+ embeddings = model_output.pooler_output
79
+ embeddings = torch.nn.functional.normalize(embeddings)
80
+ return embeddings.detach().cpu().tolist()
81
+
82
+
83
+ def re_rank(sentence, docs):
84
+ return [res['score'] for res in cross_encoder([{'text': sentence, 'text_pair': doc} for doc in docs], max_length=512, truncation=True)]
85
+
86
+
87
+ def update_query_with_llm(query, llm_type, use_llm_for_retriever):
88
+ if llm_type == 'Без LLM' or not use_llm_for_retriever:
89
+ return query
90
+
91
+ llm_chain = LLMChain(prompt=query_prompt, llm=llms[llm_type])
92
+ return f'{query} {llm_chain.run(query).strip()}'
93
+
94
+
95
+ def answer_with_llm(query, re_ranked_res, llm_type, llm_answer_type):
96
+ if llm_type == 'Без LLM':
97
+ answer, metadata, re_ranker_score = re_ranked_res[0]
98
+ else:
99
+ llm_chain = LLMChain(prompt=choose_answer_prompt, llm=llms[llm_type])
100
+ llm_chain_dict = {f'text_{i}': res[0] for i, res in enumerate(re_ranked_res, start=1)}
101
+ llm_chain_dict['query'] = query
102
+
103
+ llm_res = llm_chain.run(llm_chain_dict).strip()
104
+
105
+ if 'нет ответа' in llm_res.lower() or not llm_res[0].isnumeric():
106
+ return 'Нет ответа', '', ''
107
+
108
+ most_suitable_text, metadata, re_ranker_score = re_ranked_res[int(llm_res[0]) - 1]
109
+
110
+ llm_chain = LLMChain(prompt=answer_prompt, llm=llms[llm_type])
111
+ answer = llm_chain.run({'text': most_suitable_text, 'query': query, 'answer_type': llm_answer_type}).strip()
112
+
113
+ if 'нет ответа' in answer.lower():
114
+ answer = 'Нет ответа'
115
+
116
+ # если LLM сначала выбрала фрагмент, где есть ответ, а потом не смогла ответить на вопрос (что бывает редко), то все равно порекомендуем пользователю обратиться к норме
117
+ law_norm = f"{'Попробуйте обратиться к этому источнику: ' if answer == 'Нет ответа' else ''}{metadata['article']} {metadata['point']} {metadata['doc']}"
118
+ return answer, law_norm, re_ranker_score
119
+
120
+
121
+ def check_request_validity(func):
122
+ def wrapper(
123
+ query,
124
+ llm_type,
125
+ llm_answer_type,
126
+ use_llm_for_retriever,
127
+ use_llm_for_request_validation
128
+ ):
129
+ query = query.strip()
130
+
131
+ if not query:
132
+ return 'Невалидный запрос', '', ''
133
+
134
+ if llm_type == 'Без LLM' or not use_llm_for_request_validation:
135
+ return func(query, llm_type, llm_answer_type, use_llm_for_retriever)
136
+
137
+ llm_chain = LLMChain(prompt=validity_prompt, llm=llms[llm_type])
138
+
139
+ if 'нет' in llm_chain.run(query).lower():
140
+ return 'Невалидный запрос', '', ''
141
+
142
+ return func(query, llm_type, llm_answer_type, use_llm_for_retriever)
143
+
144
+ return wrapper
145
+
146
+
147
+ @check_request_validity
148
+ def fn(
149
+ query,
150
+ llm_type,
151
+ llm_answer_type,
152
+ use_llm_for_retriever
153
+ ):
154
+ # обогатим запрос с помощью LLM, чтобы поисковым моделям было проще найти нужный фрагмент с ответом
155
+ retriever_ranker_query = update_query_with_llm(query, llm_type, use_llm_for_retriever)
156
+
157
+ # Retriever-поиск по базе данных
158
+ retriever_res = collection.query(
159
+ query_embeddings=encode(retriever_ranker_query),
160
+ n_results=10,
161
+ )
162
+
163
+ top_k_docs = retriever_res['documents'][0]
164
+
165
+ # re-ranking с помощью Cross-Encoder'а и отбор лучших кандидатов
166
+ re_rank_scores = re_rank(retriever_ranker_query, top_k_docs)
167
+ re_ranked_res = sorted(
168
+ [[doc, meta, score] for doc, meta, score in zip(retriever_res['documents'][0], retriever_res['metadatas'][0], re_rank_scores)],
169
+ key=lambda x: x[-1],
170
+ reverse=True,
171
+ )[:3]
172
+
173
+ # поиск ответа и нормы с помощью LLM
174
+ return answer_with_llm(query, re_ranked_res, llm_type, llm_answer_type)
175
+
176
+
177
+ demo = gr.Interface(
178
+ fn=fn,
179
+ inputs=[
180
+ gr.Textbox(lines=3, label='Запрос', placeholder='Введите запрос'),
181
+ gr.Dropdown(label='Тип LLM', choices=['ChatGPT', 'GigaChat', 'Без LLM'], value='ChatGPT'),
182
+ gr.Dropdown(label='Тип итогового ответа LLM', choices=['Только цифры штрафа', 'Развернутый ответ'], value='Только цифры штрафа'),
183
+ gr.Checkbox(label="Использовать LLM для Retriever'а", value=True, info="При использовании LLM для Retriever'а ко входному запросу будет добавляться промежуточный ответ LLM на запрос. Это способствует повышению качества поиска ответа."),
184
+ gr.Checkbox(label="Использовать LLM для проверки валидности запроса", value=False)
185
+ ],
186
+ outputs=[
187
+ gr.Textbox(label='Ответ'),
188
+ gr.Textbox(label='Норма'),
189
+ gr.Textbox(label="Уверенность Cross-Encoder'а"),
190
+ ],
191
+ )
192
+
193
+ demo.launch()