VladimirVorobev commited on
Commit
ed0720f
·
1 Parent(s): e25e0f3

Upload 43 files

Browse files
Files changed (44) hide show
  1. .gitattributes +1 -0
  2. app.ipynb +274 -0
  3. db/97372335-8399-4edb-90ed-bcaf83a48591/data_level0.bin +3 -0
  4. db/97372335-8399-4edb-90ed-bcaf83a48591/header.bin +3 -0
  5. db/97372335-8399-4edb-90ed-bcaf83a48591/index_metadata.pickle +3 -0
  6. db/97372335-8399-4edb-90ed-bcaf83a48591/length.bin +3 -0
  7. db/97372335-8399-4edb-90ed-bcaf83a48591/link_lists.bin +3 -0
  8. db/9dc5e8eb-9345-462f-a1fa-938aec17a9a0/data_level0.bin +3 -0
  9. db/9dc5e8eb-9345-462f-a1fa-938aec17a9a0/header.bin +3 -0
  10. db/9dc5e8eb-9345-462f-a1fa-938aec17a9a0/index_metadata.pickle +3 -0
  11. db/9dc5e8eb-9345-462f-a1fa-938aec17a9a0/length.bin +3 -0
  12. db/9dc5e8eb-9345-462f-a1fa-938aec17a9a0/link_lists.bin +3 -0
  13. db/aa5bd886-4fdf-4d66-965d-18ffddafee21/data_level0.bin +3 -0
  14. db/aa5bd886-4fdf-4d66-965d-18ffddafee21/header.bin +3 -0
  15. db/aa5bd886-4fdf-4d66-965d-18ffddafee21/index_metadata.pickle +3 -0
  16. db/aa5bd886-4fdf-4d66-965d-18ffddafee21/length.bin +3 -0
  17. db/aa5bd886-4fdf-4d66-965d-18ffddafee21/link_lists.bin +3 -0
  18. db/b402ad0c-87fe-4eff-b65e-58777a7face5/data_level0.bin +3 -0
  19. db/b402ad0c-87fe-4eff-b65e-58777a7face5/header.bin +3 -0
  20. db/b402ad0c-87fe-4eff-b65e-58777a7face5/index_metadata.pickle +3 -0
  21. db/b402ad0c-87fe-4eff-b65e-58777a7face5/length.bin +3 -0
  22. db/b402ad0c-87fe-4eff-b65e-58777a7face5/link_lists.bin +3 -0
  23. db/chroma.sqlite3 +3 -0
  24. db/e90dea15-53e7-4b41-b502-0ab854ae08d8/data_level0.bin +3 -0
  25. db/e90dea15-53e7-4b41-b502-0ab854ae08d8/header.bin +3 -0
  26. db/e90dea15-53e7-4b41-b502-0ab854ae08d8/index_metadata.pickle +3 -0
  27. db/e90dea15-53e7-4b41-b502-0ab854ae08d8/length.bin +3 -0
  28. db/e90dea15-53e7-4b41-b502-0ab854ae08d8/link_lists.bin +3 -0
  29. docs/КоАП РФ.txt +0 -0
  30. examples/1.png +0 -0
  31. examples/10.png +0 -0
  32. examples/11.png +0 -0
  33. examples/2.png +0 -0
  34. examples/3.png +0 -0
  35. examples/4.png +0 -0
  36. examples/5.png +0 -0
  37. examples/6.png +0 -0
  38. examples/7.png +0 -0
  39. examples/8.png +0 -0
  40. examples/9.png +0 -0
  41. generated_additional_llm_answer.csv +0 -0
  42. generated_questions.csv +0 -0
  43. preprocess_doc.ipynb +220 -0
  44. question_generation.ipynb +179 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ db/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
app.ipynb ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "276ff8dc-703d-4966-918b-983c592e7938",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "__import__('pysqlite3')\n",
11
+ "import sys\n",
12
+ "sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')\n",
13
+ "import chromadb\n",
14
+ "import torch\n",
15
+ "import gradio as gr\n",
16
+ "from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline\n",
17
+ "\n",
18
+ "from langchain.llms import OpenAI, GigaChat\n",
19
+ "from langchain.chains import LLMChain\n",
20
+ "from langchain.prompts import PromptTemplate\n",
21
+ "\n",
22
+ "chatgpt = OpenAI(\n",
23
+ " api_key='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',\n",
24
+ ")\n",
25
+ "\n",
26
+ "gigachat = GigaChat(\n",
27
+ " credentials='Y2Y4Yjk5ODUtNThmMC00ODdjLTk5ODItNDdmYzhmNDdmNzE0OjQ5Y2RjNTVkLWFmMGQtNGJlYy04OGNiLTI1Yzc3MmJkMzYwYw==',\n",
28
+ " scope='GIGACHAT_API_PERS',\n",
29
+ " verify_ssl_certs=False\n",
30
+ ")\n",
31
+ "\n",
32
+ "llms = {\n",
33
+ " 'ChatGPT': chatgpt,\n",
34
+ " 'GigaChat': gigachat,\n",
35
+ "}\n",
36
+ "\n",
37
+ "# задаем формат вывода модели\n",
38
+ "answer_task_types = {\n",
39
+ " 'Развернутый ответ': 'Ответь достаточно подробно, но не используй ничего лишнего.',\n",
40
+ " 'Только цифры штрафа': 'Ответь в виде <количество> рублей или <диапазон> рублей, и больше ничего не пиши.'\n",
41
+ "}\n",
42
+ "\n",
43
+ "# проверяем с помощью LLM валидность запроса, исключая обработку бессмысленного входа\n",
44
+ "validity_template = '{query}\\n\\nЭто валидный запрос? Ответь да или нет, больше ничего не пиши.'\n",
45
+ "validity_prompt = PromptTemplate(template=validity_template, input_variables=['query'])\n",
46
+ "\n",
47
+ "# получаем ответ модели на запрос, используем его для более качественного поиска Retriever'ом и Cross-Encoder'ом\n",
48
+ "query_template = '{query} Ответь текстом, похожим на закон, не пиши ничего лишнего. Не используй в ответе слово КоАП РФ. Не используй слово \"Россия\".'\n",
49
+ "query_prompt = PromptTemplate(template=query_template, input_variables=['query'])\n",
50
+ "\n",
51
+ "# просим LLM выбрать один из 3 фрагментов текста, выбранных поисковыми моделями, где по мнению модели есть ответ. Если ответа нет, модель нам об этом сообщает\n",
52
+ "choose_answer_template = '1. {text_1}\\n\\n2. {text_2}\\n\\n3. {text_3}\\n\\nЗадание: выбери из перечисленных выше отрывков тот, где есть ответ на вопрос: \"{query}\". В качестве ответа напиши только номер 1, 2 или 3 и все. Если в данных отрывках нет ответа, то напиши \"Нет ответа\".'\n",
53
+ "choose_answer_prompt = PromptTemplate(template=choose_answer_template, input_variables=['text_1', 'text_2', 'text_3', 'query'])\n",
54
+ "\n",
55
+ "# просим LLM ответить на вопрос, опираясь на найденный фрагмент, и в нужном формате, или сообщить, что ответа все-таки нет\n",
56
+ "answer_template = '{text}\\n\\nЗадание: ответь на вопрос по тексту: \"{query}\". {answer_type} Если в данном тексте нет ответа, то напиши \"Нет ответа\".'\n",
57
+ "answer_prompt = PromptTemplate(template=answer_template, input_variables=['text', 'query', 'answer_type'])\n",
58
+ "\n",
59
+ "client = chromadb.PersistentClient(path='db')\n",
60
+ "collection = client.get_collection(name=\"administrative_codex\")\n",
61
+ "\n",
62
+ "retriever_checkpoint = 'sentence-transformers/LaBSE'\n",
63
+ "retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_checkpoint)\n",
64
+ "retriever_model = AutoModel.from_pretrained(retriever_checkpoint)\n",
65
+ "\n",
66
+ "cross_encoder_checkpoint = 'jeffwan/mmarco-mMiniLMv2-L12-H384-v1'\n",
67
+ "cross_encoder_model = AutoModelForSequenceClassification.from_pretrained(cross_encoder_checkpoint)\n",
68
+ "cross_encoder_tokenizer = AutoTokenizer.from_pretrained(cross_encoder_checkpoint)\n",
69
+ "cross_encoder = pipeline('text-classification', model=cross_encoder_model, tokenizer=cross_encoder_tokenizer)"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": 6,
75
+ "id": "ca60b616-97f4-4d01-b0f5-4a994d11e216",
76
+ "metadata": {
77
+ "scrolled": true
78
+ },
79
+ "outputs": [
80
+ {
81
+ "name": "stdout",
82
+ "output_type": "stream",
83
+ "text": [
84
+ "Running on local URL: http://127.0.0.1:7864\n",
85
+ "\n",
86
+ "To create a public link, set `share=True` in `launch()`.\n"
87
+ ]
88
+ },
89
+ {
90
+ "data": {
91
+ "text/html": [
92
+ "<div><iframe src=\"http://127.0.0.1:7864/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
93
+ ],
94
+ "text/plain": [
95
+ "<IPython.core.display.HTML object>"
96
+ ]
97
+ },
98
+ "metadata": {},
99
+ "output_type": "display_data"
100
+ },
101
+ {
102
+ "data": {
103
+ "text/plain": []
104
+ },
105
+ "execution_count": 6,
106
+ "metadata": {},
107
+ "output_type": "execute_result"
108
+ }
109
+ ],
110
+ "source": [
111
+ "def encode(docs):\n",
112
+ " if type(docs) == str:\n",
113
+ " docs = [docs]\n",
114
+ "\n",
115
+ " encoded_input = retriever_tokenizer(\n",
116
+ " docs,\n",
117
+ " padding=True,\n",
118
+ " truncation=True,\n",
119
+ " max_length=512,\n",
120
+ " return_tensors='pt'\n",
121
+ " )\n",
122
+ " \n",
123
+ " with torch.no_grad():\n",
124
+ " model_output = retriever_model(**encoded_input)\n",
125
+ " \n",
126
+ " embeddings = model_output.pooler_output\n",
127
+ " embeddings = torch.nn.functional.normalize(embeddings)\n",
128
+ " return embeddings.detach().cpu().tolist()\n",
129
+ "\n",
130
+ "\n",
131
+ "def re_rank(sentence, docs):\n",
132
+ " return [res['score'] for res in cross_encoder([{'text': sentence, 'text_pair': doc} for doc in docs], max_length=512, truncation=True)]\n",
133
+ "\n",
134
+ "\n",
135
+ "def update_query_with_llm(query, llm_type, use_llm_for_retriever):\n",
136
+ " if llm_type == 'Без LLM' or not use_llm_for_retriever:\n",
137
+ " return query\n",
138
+ " \n",
139
+ " llm_chain = LLMChain(prompt=query_prompt, llm=llms[llm_type])\n",
140
+ " return f'{query} {llm_chain.run(query).strip()}'\n",
141
+ "\n",
142
+ "\n",
143
+ "def answer_with_llm(query, re_ranked_res, llm_type, llm_answer_type):\n",
144
+ " if llm_type == 'Без LLM':\n",
145
+ " answer, metadata, re_ranker_score = re_ranked_res[0]\n",
146
+ " else:\n",
147
+ " llm_chain = LLMChain(prompt=choose_answer_prompt, llm=llms[llm_type])\n",
148
+ " llm_chain_dict = {f'text_{i}': res[0] for i, res in enumerate(re_ranked_res, start=1)}\n",
149
+ " llm_chain_dict['query'] = query\n",
150
+ "\n",
151
+ " llm_res = llm_chain.run(llm_chain_dict).strip()\n",
152
+ "\n",
153
+ " if 'нет ответа' in llm_res.lower() or not llm_res[0].isnumeric():\n",
154
+ " return 'Нет ответа', '', ''\n",
155
+ " \n",
156
+ " most_suitable_text, metadata, re_ranker_score = re_ranked_res[int(llm_res[0]) - 1]\n",
157
+ "\n",
158
+ " llm_chain = LLMChain(prompt=answer_prompt, llm=llms[llm_type])\n",
159
+ " answer = llm_chain.run({'text': most_suitable_text, 'query': query, 'answer_type': llm_answer_type}).strip()\n",
160
+ "\n",
161
+ " if 'нет ответа' in answer.lower():\n",
162
+ " answer = 'Нет ответа'\n",
163
+ "\n",
164
+ " # если LLM сначала выбрала фрагмент, где есть ответ, а потом не смогла ответить на вопрос (что бывает редко), то все равно порекомендуем пользователю обратиться к норме\n",
165
+ " law_norm = f\"{'Попробуйте обратиться к этому источнику: ' if answer == 'Нет ответа' else ''}{metadata['article']} {metadata['point']} {metadata['doc']}\"\n",
166
+ " return answer, law_norm, re_ranker_score\n",
167
+ "\n",
168
+ "\n",
169
+ "def check_request_validity(func):\n",
170
+ " def wrapper(\n",
171
+ " query,\n",
172
+ " llm_type,\n",
173
+ " llm_answer_type,\n",
174
+ " use_llm_for_retriever,\n",
175
+ " use_llm_for_request_validation\n",
176
+ " ):\n",
177
+ " query = query.strip()\n",
178
+ " \n",
179
+ " if not query:\n",
180
+ " return 'Невалидный запрос', '', ''\n",
181
+ " \n",
182
+ " if llm_type == 'Без LLM' or not use_llm_for_request_validation:\n",
183
+ " return func(query, llm_type, llm_answer_type, use_llm_for_retriever)\n",
184
+ " \n",
185
+ " llm_chain = LLMChain(prompt=validity_prompt, llm=llms[llm_type])\n",
186
+ "\n",
187
+ " if 'нет' in llm_chain.run(query).lower():\n",
188
+ " return 'Невалидный запрос', '', ''\n",
189
+ " \n",
190
+ " return func(query, llm_type, llm_answer_type, use_llm_for_retriever)\n",
191
+ " \n",
192
+ " return wrapper\n",
193
+ "\n",
194
+ "\n",
195
+ "@check_request_validity\n",
196
+ "def fn(\n",
197
+ " query,\n",
198
+ " llm_type,\n",
199
+ " llm_answer_type,\n",
200
+ " use_llm_for_retriever\n",
201
+ "):\n",
202
+ " # обогатим запрос с помощью LLM, чтобы поис��овым моделям было проще найти нужный фрагмент с ответом\n",
203
+ " retriever_ranker_query = update_query_with_llm(query, llm_type, use_llm_for_retriever)\n",
204
+ "\n",
205
+ " # Retriever-поиск по базе данных\n",
206
+ " retriever_res = collection.query(\n",
207
+ " query_embeddings=encode(retriever_ranker_query),\n",
208
+ " n_results=10,\n",
209
+ " )\n",
210
+ "\n",
211
+ " top_k_docs = retriever_res['documents'][0]\n",
212
+ "\n",
213
+ " # re-ranking с помощью Cross-Encoder'а и отбор лучших кандидатов\n",
214
+ " re_rank_scores = re_rank(retriever_ranker_query, top_k_docs)\n",
215
+ " re_ranked_res = sorted(\n",
216
+ " [[doc, meta, score] for doc, meta, score in zip(retriever_res['documents'][0], retriever_res['metadatas'][0], re_rank_scores)],\n",
217
+ " key=lambda x: x[-1],\n",
218
+ " reverse=True,\n",
219
+ " )[:3]\n",
220
+ "\n",
221
+ " # поиск ответа и нормы с помощью LLM\n",
222
+ " return answer_with_llm(query, re_ranked_res, llm_type, llm_answer_type)\n",
223
+ "\n",
224
+ "\n",
225
+ "demo = gr.Interface(\n",
226
+ " fn=fn,\n",
227
+ " inputs=[\n",
228
+ " gr.Textbox(lines=3, label='Запрос', placeholder='Введите запрос'),\n",
229
+ " gr.Dropdown(label='Тип LLM', choices=['ChatGPT', 'GigaChat', 'Без LLM'], value='ChatGPT'),\n",
230
+ " gr.Dropdown(label='Тип итогового ответа LLM', choices=['Только цифры штрафа', 'Развернутый ответ'], value='Только цифры штрафа'),\n",
231
+ " gr.Checkbox(label=\"Использовать LLM для Retriever'а\", value=True, info=\"При использовании LLM для Retriever'а ко входному запросу будет добавляться промежуточный ответ LLM на запрос. Это способствует повышению качества поиска ответа.\"),\n",
232
+ " gr.Checkbox(label=\"Использовать LLM для проверки валидности запроса\", value=False)\n",
233
+ " ],\n",
234
+ " outputs=[\n",
235
+ " gr.Textbox(label='Ответ'),\n",
236
+ " gr.Textbox(label='Норма'),\n",
237
+ " gr.Textbox(label=\"Уверенность Cross-Encoder'а\"),\n",
238
+ " ],\n",
239
+ ")\n",
240
+ "\n",
241
+ "demo.launch()"
242
+ ]
243
+ },
244
+ {
245
+ "cell_type": "code",
246
+ "execution_count": null,
247
+ "id": "17ebff84-9312-46ae-9972-cd859fad36f9",
248
+ "metadata": {},
249
+ "outputs": [],
250
+ "source": []
251
+ }
252
+ ],
253
+ "metadata": {
254
+ "kernelspec": {
255
+ "display_name": "Python 3 (ipykernel)",
256
+ "language": "python",
257
+ "name": "python3"
258
+ },
259
+ "language_info": {
260
+ "codemirror_mode": {
261
+ "name": "ipython",
262
+ "version": 3
263
+ },
264
+ "file_extension": ".py",
265
+ "mimetype": "text/x-python",
266
+ "name": "python",
267
+ "nbconvert_exporter": "python",
268
+ "pygments_lexer": "ipython3",
269
+ "version": "3.8.10"
270
+ }
271
+ },
272
+ "nbformat": 4,
273
+ "nbformat_minor": 5
274
+ }
db/97372335-8399-4edb-90ed-bcaf83a48591/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e861bdd7817a442cd756c01729685476c6d03bf7320edfa45a9d31217dc60f89
3
+ size 6424000
db/97372335-8399-4edb-90ed-bcaf83a48591/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:995285888fae05329c795e15bef4387c9003f7ed65af6bd46fc6a1220b9e2265
3
+ size 100
db/97372335-8399-4edb-90ed-bcaf83a48591/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:659f72af1aa9a4dc9bba946c9c62147869d7b1eaef779684a67821dd6ba565bf
3
+ size 52870
db/97372335-8399-4edb-90ed-bcaf83a48591/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:135cb2913a715740c00acafbd2a392317f77955f782a1e2974ded7c7f0d769e6
3
+ size 8000
db/97372335-8399-4edb-90ed-bcaf83a48591/link_lists.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c44a130f362f4e7bd70389814b9d8d8889d0d1262ca68d6c5fd7ae2fb11e707
3
+ size 16976
db/9dc5e8eb-9345-462f-a1fa-938aec17a9a0/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aef8dfed4c15fb14fd74ce9f26f2aecf739df73b7ae473f2e4a36baa830770d0
3
+ size 6424000
db/9dc5e8eb-9345-462f-a1fa-938aec17a9a0/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:91a34e115568431e60861fa63dee2fbc5eba7750376a27655abbd690383acc8a
3
+ size 100
db/9dc5e8eb-9345-462f-a1fa-938aec17a9a0/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e015a6c87c458effc1fb3105fbdb92b3681df738b60f2d64624af959b206afe0
3
+ size 52870
db/9dc5e8eb-9345-462f-a1fa-938aec17a9a0/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd2b7f4046e646fb88c6d837cb3d2067ba3b0d82f256c7f3afe4a5fd4f26a655
3
+ size 8000
db/9dc5e8eb-9345-462f-a1fa-938aec17a9a0/link_lists.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6209ce61c94144576d5851341ac04d85cf6b558c02da410e983b3f9290e411bc
3
+ size 16976
db/aa5bd886-4fdf-4d66-965d-18ffddafee21/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d1e1753e5819ae5dffa99158015f31ba066c12ee0e4f657e64dc64660dd7b77f
3
+ size 6424000
db/aa5bd886-4fdf-4d66-965d-18ffddafee21/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:995285888fae05329c795e15bef4387c9003f7ed65af6bd46fc6a1220b9e2265
3
+ size 100
db/aa5bd886-4fdf-4d66-965d-18ffddafee21/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0086f16ab6934ed3c2a5b99a1e9595e150623259c8396e0faf2f26023f20c757
3
+ size 52870
db/aa5bd886-4fdf-4d66-965d-18ffddafee21/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:38c55df27ec0a8adb0ac1e2d55377a84ddbb939c3f3f9704a150952b067108e5
3
+ size 8000
db/aa5bd886-4fdf-4d66-965d-18ffddafee21/link_lists.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6de5ffb84a9d65860195295c241ed10a8618676c290dfa8379158bef5c308f5e
3
+ size 16976
db/b402ad0c-87fe-4eff-b65e-58777a7face5/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b72c643cdc6cdf14f5bf99407c6df138da0ce720c6b2127ee36aea4fe57a3b50
3
+ size 6424000
db/b402ad0c-87fe-4eff-b65e-58777a7face5/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:995285888fae05329c795e15bef4387c9003f7ed65af6bd46fc6a1220b9e2265
3
+ size 100
db/b402ad0c-87fe-4eff-b65e-58777a7face5/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ddf8414b24f52bb480445ed59c48b82cd02324d2668f74f8582e8d16d1558ab
3
+ size 52870
db/b402ad0c-87fe-4eff-b65e-58777a7face5/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:287895f1d9f79674f38d9748c5e8a38edef18d74567d51b7790eb3094d959605
3
+ size 8000
db/b402ad0c-87fe-4eff-b65e-58777a7face5/link_lists.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c4b1a82fceed2e2b272bffdc76f6dd30062dfa5fcdf3ca1292fd739a764cfdb
3
+ size 16976
db/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c29da25687cacaa124ee55421db14f74f971e5e185a029e7113f3d047bac697f
3
+ size 29159424
db/e90dea15-53e7-4b41-b502-0ab854ae08d8/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:02d382a284aa159c52564b0fc03080a00b4ecbbbbe383781d81f004a21c7e61a
3
+ size 6424000
db/e90dea15-53e7-4b41-b502-0ab854ae08d8/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:995285888fae05329c795e15bef4387c9003f7ed65af6bd46fc6a1220b9e2265
3
+ size 100
db/e90dea15-53e7-4b41-b502-0ab854ae08d8/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccf1f97bd21b051c83e0e8a19d1e4f7de0972d3095b7b1807f94dfeafc99db89
3
+ size 52870
db/e90dea15-53e7-4b41-b502-0ab854ae08d8/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc538b666ab7415eda93fabfb42aa70ec82d57f19509da8eb899e10faa21a153
3
+ size 8000
db/e90dea15-53e7-4b41-b502-0ab854ae08d8/link_lists.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94b06e796bcc2b6f1959f2575e260cc9ed17950526af440de41f1bb7f148ed0f
3
+ size 16976
docs/КоАП РФ.txt ADDED
The diff for this file is too large to render. See raw diff
 
examples/1.png ADDED
examples/10.png ADDED
examples/11.png ADDED
examples/2.png ADDED
examples/3.png ADDED
examples/4.png ADDED
examples/5.png ADDED
examples/6.png ADDED
examples/7.png ADDED
examples/8.png ADDED
examples/9.png ADDED
generated_additional_llm_answer.csv ADDED
The diff for this file is too large to render. See raw diff
 
generated_questions.csv ADDED
The diff for this file is too large to render. See raw diff
 
preprocess_doc.ipynb ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "bc1b8947-4397-4fb8-b3ae-310fdb44c056",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "__import__('pysqlite3')\n",
11
+ "import sys\n",
12
+ "import os\n",
13
+ "sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')\n",
14
+ "os.environ['ALLOW_RESET'] = 'True'\n",
15
+ "\n",
16
+ "import torch\n",
17
+ "from torch.utils.data import DataLoader\n",
18
+ "from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline\n",
19
+ "import numpy as np\n",
20
+ "from tqdm import tqdm\n",
21
+ "\n",
22
+ "import chromadb"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "markdown",
27
+ "id": "c0af096a-f302-4df9-9f26-63213ad44b8f",
28
+ "metadata": {},
29
+ "source": [
30
+ "### Подготавливаем базу данных"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": 12,
36
+ "id": "e5a2886c-586a-49f7-8833-0972407bf1fa",
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "client = chromadb.PersistentClient(path='db')\n",
41
+ "client.reset()\n",
42
+ "\n",
43
+ "collection = client.create_collection(\n",
44
+ " name=\"administrative_codex\",\n",
45
+ " metadata={\"hnsw:space\": \"cosine\"}\n",
46
+ ")"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "markdown",
51
+ "id": "329c4f86-c514-4039-9494-621cb7042f77",
52
+ "metadata": {},
53
+ "source": [
54
+ "### Открываем и предобрабатываем КоАП"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 3,
60
+ "id": "f057b53d-5a50-4a92-bbda-4c8396aca107",
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "with open('docs/КоАП РФ.txt', encoding='utf-8') as r:\n",
65
+ " raw_text = r.read().split('\\n\\n')"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "markdown",
70
+ "id": "5589508c-5be9-403d-ba68-7a0e60658136",
71
+ "metadata": {},
72
+ "source": [
73
+ "### Делим документ по частям статей, исключаем лишнее"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": 9,
79
+ "id": "d7b87b89-5cc2-4c3b-8cee-8ea59b4778a5",
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "paragraphs = []\n",
84
+ "index = 0\n",
85
+ "\n",
86
+ "while index != len(raw_text):\n",
87
+ " if raw_text[index].startswith('Статья'):\n",
88
+ " article = ' '.join(raw_text[index].strip().split()[:2])\n",
89
+ " article_points = raw_text[index + 1].split('\\n')\n",
90
+ "\n",
91
+ " cur_point = ''\n",
92
+ " for i in range(len(article_points)):\n",
93
+ " cur_point_part = article_points[i].strip()\n",
94
+ " \n",
95
+ " if 'КонсультантПлюс' in article_points[i] + article_points[i - 1]:\n",
96
+ " continue\n",
97
+ " elif cur_point_part.split()[0].strip().replace('.', '').isnumeric() or cur_point_part.startswith('Примечание. '):\n",
98
+ " if cur_point:\n",
99
+ " if cur_point.startswith('Примечание. '):\n",
100
+ " paragraphs.append([cur_point, article, 'Примечание.'])\n",
101
+ " elif cur_point[0].isnumeric():\n",
102
+ " paragraphs.append([' '.join(cur_point.split()[1:]), article, f'Часть {cur_point.split()[0]}'])\n",
103
+ " else:\n",
104
+ " paragraphs.append([cur_point, article, ''])\n",
105
+ " \n",
106
+ " cur_point = cur_point_part\n",
107
+ " elif cur_point_part[0] != '(' and cur_point_part[-1] != ')' and 'утратил силу' not in cur_point_part[:20].lower():\n",
108
+ " cur_point += ' ' + cur_point_part\n",
109
+ " \n",
110
+ " index += 2\n",
111
+ " else:\n",
112
+ " index += 1"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "markdown",
117
+ "id": "5e574c65-8100-45e3-9d8b-e7a15bfec242",
118
+ "metadata": {},
119
+ "source": [
120
+ "### Получаем эмбеддинги из извлеченных фрагментов и сохраняем их в базу данных"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": 5,
126
+ "id": "35204261-c2a0-43bc-b167-931154cfb77f",
127
+ "metadata": {},
128
+ "outputs": [],
129
+ "source": [
130
+ "checkpoint = 'sentence-transformers/LaBSE'\n",
131
+ "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
132
+ "model = AutoModel.from_pretrained(checkpoint, device_map='cuda:0')"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "execution_count": 6,
138
+ "id": "7aad08e2-fca0-4cf8-b568-17e5c2b4ffff",
139
+ "metadata": {},
140
+ "outputs": [],
141
+ "source": [
142
+ "def encode(docs):\n",
143
+ " if type(docs) == str:\n",
144
+ " docs = [docs]\n",
145
+ "\n",
146
+ " encoded_input = tokenizer(\n",
147
+ " docs,\n",
148
+ " padding=True,\n",
149
+ " truncation=True,\n",
150
+ " max_length=512,\n",
151
+ " return_tensors='pt'\n",
152
+ " )\n",
153
+ " \n",
154
+ " with torch.no_grad():\n",
155
+ " model_output = model(**encoded_input.to('cuda'))\n",
156
+ " \n",
157
+ " embeddings = model_output.pooler_output\n",
158
+ " embeddings = torch.nn.functional.normalize(embeddings)\n",
159
+ " return embeddings.detach().cpu().tolist()"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": 10,
165
+ "id": "f118299e-c57a-40ce-a060-5a86b452cfec",
166
+ "metadata": {},
167
+ "outputs": [],
168
+ "source": [
169
+ "BATCH_SIZE = 128\n",
170
+ "loader = DataLoader(paragraphs, batch_size=BATCH_SIZE)"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "code",
175
+ "execution_count": 13,
176
+ "id": "95dafd67-f6f8-41e1-a209-6325a23e066f",
177
+ "metadata": {},
178
+ "outputs": [
179
+ {
180
+ "name": "stderr",
181
+ "output_type": "stream",
182
+ "text": [
183
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:11<00:00, 1.48it/s]\n"
184
+ ]
185
+ }
186
+ ],
187
+ "source": [
188
+ "for i, docs in enumerate(tqdm(loader)):\n",
189
+ " embeddings = encode(docs[0])\n",
190
+ " collection.add(\n",
191
+ " documents=docs[0],\n",
192
+ " metadatas=[{'doc': 'КоАП РФ', 'article': a, 'point': p} for a, p in zip(docs[1], docs[2])],\n",
193
+ " embeddings=embeddings,\n",
194
+ " ids=[f'id{i * BATCH_SIZE + j}' for j in range(len(docs[0]))],\n",
195
+ " )"
196
+ ]
197
+ }
198
+ ],
199
+ "metadata": {
200
+ "kernelspec": {
201
+ "display_name": "Python 3 (ipykernel)",
202
+ "language": "python",
203
+ "name": "python3"
204
+ },
205
+ "language_info": {
206
+ "codemirror_mode": {
207
+ "name": "ipython",
208
+ "version": 3
209
+ },
210
+ "file_extension": ".py",
211
+ "mimetype": "text/x-python",
212
+ "name": "python",
213
+ "nbconvert_exporter": "python",
214
+ "pygments_lexer": "ipython3",
215
+ "version": "3.8.10"
216
+ }
217
+ },
218
+ "nbformat": 4,
219
+ "nbformat_minor": 5
220
+ }
question_generation.ipynb ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 14,
6
+ "id": "d024645c",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "__import__('pysqlite3')\n",
11
+ "import sys\n",
12
+ "import os\n",
13
+ "sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')\n",
14
+ "os.environ['ALLOW_RESET'] = 'True'\n",
15
+ "\n",
16
+ "import pandas as pd\n",
17
+ "from tqdm import tqdm\n",
18
+ "import time\n",
19
+ "\n",
20
+ "import chromadb\n",
21
+ "from gigachat import GigaChat\n",
22
+ "\n",
23
+ "client = chromadb.PersistentClient(path='db')\n",
24
+ "collection = client.get_collection(name=\"administrative_codex\")"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": 25,
30
+ "id": "17dae6a5",
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "docs = collection.get()['documents']\n",
35
+ "prompt = 'Задание: напиши в виде нумерованного списка 3 конкретных независимых друг от друга вопроса, ответ на которые можно найти в приведенном тексте. Не упоминай федеральные законы. Не упоминай КоАП.'"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "markdown",
40
+ "id": "91549726-3c7a-44ef-8519-c1afc3adde0f",
41
+ "metadata": {},
42
+ "source": [
43
+ "### Генерируем вопросы к каждому фрагменту текста"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": 29,
49
+ "id": "06f82948",
50
+ "metadata": {},
51
+ "outputs": [
52
+ {
53
+ "name": "stderr",
54
+ "output_type": "stream",
55
+ "text": [
56
+ "100%|███████████████████████████████████████████████████████████████████████████████| 2130/2130 [54:44<00:00, 1.54s/it]\n"
57
+ ]
58
+ }
59
+ ],
60
+ "source": [
61
+ "for doc in tqdm(docs[1:]):\n",
62
+ " question_ready = False\n",
63
+ " \n",
64
+ " while not question_ready:\n",
65
+ " try:\n",
66
+ " text = f'{doc}\\n\\n{prompt}'\n",
67
+ " \n",
68
+ " with GigaChat(credentials='N2ZiNDIxZTgtM2Y4Yy00MGJjLWI4OTgtN2M5NGM5MTYzZTNiOmFmYjJmZTUwLTc1OWItNGQ5MC1iMGVmLTMwYTNlODU3YzVmZg==', scope='GIGACHAT_API_PERS', verify_ssl_certs=False) as giga:\n",
69
+ " questions = giga.chat(text).choices[0].message.content\n",
70
+ " \n",
71
+ " question_ready = True\n",
72
+ " except:\n",
73
+ " time.sleep(5)\n",
74
+ "\n",
75
+ " df = pd.read_csv('generated_questions.csv')\n",
76
+ " new_df = pd.DataFrame({'text': [doc], 'questions': [questions]})\n",
77
+ " pd.concat([df, new_df], ignore_index=True).to_csv('generated_questions.csv', index=False)"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": 31,
83
+ "id": "2f44eac2-7ce0-4d26-9f4a-41f5bfe0fa44",
84
+ "metadata": {},
85
+ "outputs": [],
86
+ "source": [
87
+ "generated_questions_df = pd.read_csv('generated_questions.csv')\n",
88
+ "docs = generated_questions_df['text'].tolist()\n",
89
+ "generated_questions = generated_questions_df['questions'].tolist()\n",
90
+ "\n",
91
+ "prompt = 'В России. Дай подробный ответ текстом, похожим на закон, не пиши ничего лишнего.'"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "markdown",
96
+ "id": "90b543f8-0b94-4c0c-8a69-9574b7c54db9",
97
+ "metadata": {},
98
+ "source": [
99
+ "### Генерируем ответы к вопросам, в которых есть слово штраф"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "execution_count": 39,
105
+ "id": "a7937078-6200-44ba-b43f-4867e947b750",
106
+ "metadata": {},
107
+ "outputs": [
108
+ {
109
+ "name": "stderr",
110
+ "output_type": "stream",
111
+ "text": [
112
+ "100%|███████████████████████████████████████████████████████████████████████████████| 1978/1978 [50:30<00:00, 1.53s/it]\n"
113
+ ]
114
+ }
115
+ ],
116
+ "source": [
117
+ "for doc, g_questions in zip(tqdm(docs[153:]), generated_questions[153:]):\n",
118
+ " llm_answer_ready = False\n",
119
+ " fine_question = ''\n",
120
+ "\n",
121
+ " for question in g_questions.split('\\n'):\n",
122
+ " question = ' '.join(question.split()[1:])\n",
123
+ " \n",
124
+ " if 'штраф' in question:\n",
125
+ " fine_question = question\n",
126
+ " break\n",
127
+ "\n",
128
+ " if not fine_question:\n",
129
+ " continue\n",
130
+ " \n",
131
+ " while not llm_answer_ready:\n",
132
+ " try:\n",
133
+ " text = f'Помоги, пожалуйста. {fine_question} {prompt}'\n",
134
+ " \n",
135
+ " with GigaChat(credentials='MmU3OTdhNmItMTQzYy00NGQzLWEyYTctZjcxOWJmYThiMWE5OmE1ZDdhNDkxLWI5ZTEtNGFkZS04N2JjLTExZjE5MTYwNGQ5Yg==', scope='GIGACHAT_API_PERS', verify_ssl_certs=False) as giga:\n",
136
+ " llm_answer = giga.chat(text).choices[0].message.content.split('\\n')[0]\n",
137
+ " \n",
138
+ " llm_answer_ready = True\n",
139
+ " except:\n",
140
+ " time.sleep(5)\n",
141
+ "\n",
142
+ " \n",
143
+ " if len(llm_answer) > 100:\n",
144
+ " df = pd.read_csv('generated_additional_llm_answer.csv')\n",
145
+ " new_df = pd.DataFrame({'text': [doc], 'question': [fine_question], 'llm_answer': [llm_answer]})\n",
146
+ " pd.concat([df, new_df], ignore_index=True).to_csv('generated_additional_llm_answer.csv', index=False)"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": null,
152
+ "id": "5fd64855-01b5-4c66-a425-b6d91b355a22",
153
+ "metadata": {},
154
+ "outputs": [],
155
+ "source": []
156
+ }
157
+ ],
158
+ "metadata": {
159
+ "kernelspec": {
160
+ "display_name": "Python 3 (ipykernel)",
161
+ "language": "python",
162
+ "name": "python3"
163
+ },
164
+ "language_info": {
165
+ "codemirror_mode": {
166
+ "name": "ipython",
167
+ "version": 3
168
+ },
169
+ "file_extension": ".py",
170
+ "mimetype": "text/x-python",
171
+ "name": "python",
172
+ "nbconvert_exporter": "python",
173
+ "pygments_lexer": "ipython3",
174
+ "version": "3.8.10"
175
+ }
176
+ },
177
+ "nbformat": 4,
178
+ "nbformat_minor": 5
179
+ }