{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc1b8947-4397-4fb8-b3ae-310fdb44c056",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pysqlite3 must replace the stdlib sqlite3 module before chromadb is imported\n",
    "# (the usual workaround when the system SQLite is too old for Chroma).\n",
    "__import__('pysqlite3')\n",
    "import sys\n",
    "import os\n",
    "sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')\n",
    "# Set before the client is created; client.reset() below relies on it.\n",
    "os.environ['ALLOW_RESET'] = 'True'\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "import chromadb"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0af096a-f302-4df9-9f26-63213ad44b8f",
   "metadata": {},
   "source": [
    "### Подготавливаем базу данных"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5a2886c-586a-49f7-8833-0972407bf1fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "client = chromadb.PersistentClient(path='db')\n",
    "# Destructive: wipes everything stored under 'db' so the index is rebuilt from scratch.\n",
    "client.reset()\n",
    "\n",
    "collection = client.create_collection(\n",
    "    name=\"administrative_codex\",\n",
    "    metadata={\"hnsw:space\": \"cosine\"}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "329c4f86-c514-4039-9494-621cb7042f77",
   "metadata": {},
   "source": [
    "### Открываем и предобрабатываем КоАП"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f057b53d-5a50-4a92-bbda-4c8396aca107",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Blocks are separated by blank lines; the parser below expects an article header\n",
    "# block to be immediately followed by the block holding that article's body.\n",
    "with open('docs/КоАП РФ.txt', encoding='utf-8') as r:\n",
    "    raw_text = r.read().split('\\n\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5589508c-5be9-403d-ba68-7a0e60658136",
   "metadata": {},
   "source": [
    "### Делим документ по частям статей, исключаем лишнее"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7b87b89-5cc2-4c3b-8cee-8ea59b4778a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "MARKER = 'КонсультантПлюс'    # editorial insert: the marked line and the line after it are skipped\n",
    "NOTE_PREFIX = 'Примечание. '\n",
    "\n",
    "\n",
    "def starts_new_point(line):\n",
    "    \"\"\"Return True if the non-empty `line` opens a numbered part ('1.', '1.1.') or a note.\"\"\"\n",
    "    return line.split()[0].replace('.', '').isnumeric() or line.startswith(NOTE_PREFIX)\n",
    "\n",
    "\n",
    "def append_point(records, lines, article):\n",
    "    \"\"\"Append the point built from `lines` to `records` as [text, article, point label].\n",
    "\n",
    "    Notes keep their full text and get the label 'Примечание.'; numbered parts lose\n",
    "    their leading number, which moves into the label ('Часть 1.'); anything else gets\n",
    "    an empty label. Points without any text are skipped.\n",
    "    \"\"\"\n",
    "    text = ' '.join(lines)\n",
    "    if not text:\n",
    "        return\n",
    "    if text.startswith(NOTE_PREFIX):\n",
    "        record = [text, article, 'Примечание.']\n",
    "    elif text[0].isnumeric():\n",
    "        number, *words = text.split()\n",
    "        record = [' '.join(words), article, f'Часть {number}']\n",
    "    else:\n",
    "        record = [text, article, '']\n",
    "    if record[0]:\n",
    "        records.append(record)\n",
    "\n",
    "\n",
    "def parse_article(header, body):\n",
    "    \"\"\"Split one article into per-part records.\n",
    "\n",
    "    header: block starting with 'Статья'; its first two tokens become the article label.\n",
    "    body:   block with the article text, one source line per '\\\\n'.\n",
    "    Returns a list of [text, article, point label] lists.\n",
    "    \"\"\"\n",
    "    article = ' '.join(header.strip().split()[:2])\n",
    "    records = []\n",
    "    current = []     # lines of the point being accumulated\n",
    "    previous = ''    # previous raw line; '' for the first line, so nothing wraps around to the last one\n",
    "\n",
    "    for raw_line in body.split('\\n'):\n",
    "        line = raw_line.strip()\n",
    "        in_editorial_insert = MARKER in raw_line or MARKER in previous\n",
    "        previous = raw_line\n",
    "        # Blank lines carry no text and would break the first-token checks below.\n",
    "        if in_editorial_insert or not line:\n",
    "            continue\n",
    "\n",
    "        if starts_new_point(line):\n",
    "            append_point(records, current, article)\n",
    "            current = [line]\n",
    "        # Lines that open with '(' or close with ')' and 'утратил силу' lines are left out.\n",
    "        elif line[0] != '(' and line[-1] != ')' and 'утратил силу' not in line[:20].lower():\n",
    "            current.append(line)\n",
    "\n",
    "    # The last point has no successor to trigger a flush, so emit it explicitly.\n",
    "    append_point(records, current, article)\n",
    "    return records\n",
    "\n",
    "\n",
    "paragraphs = []\n",
    "index = 0\n",
    "\n",
    "# '<' together with the bounds check keeps the two-step jump from running past the end.\n",
    "while index < len(raw_text):\n",
    "    block = raw_text[index]\n",
    "    if block.startswith('Статья') and index + 1 < len(raw_text):\n",
    "        paragraphs.extend(parse_article(block, raw_text[index + 1]))\n",
    "        index += 2\n",
    "    else:\n",
    "        index += 1\n",
    "\n",
    "assert paragraphs, 'no article parts were extracted - check the input file format'\n",
    "len(paragraphs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e574c65-8100-45e3-9d8b-e7a15bfec242",
   "metadata": {},
   "source": [
    "### Получаем эмбеддинги из извлеченных фрагментов и сохраняем их в базу данных"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35204261-c2a0-43bc-b167-931154cfb77f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the device once; fall back to CPU so the notebook still runs without a GPU.\n",
    "DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "checkpoint = 'sentence-transformers/LaBSE'\n",
    "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
    "model = AutoModel.from_pretrained(checkpoint).to(DEVICE).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7aad08e2-fca0-4cf8-b568-17e5c2b4ffff",
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode(docs):\n",
    "    \"\"\"Embed one string or a sequence of strings with the loaded model.\n",
    "\n",
    "    Inputs are truncated to 512 tokens. Returns a list with one L2-normalised\n",
    "    vector (a list of floats, taken from the model's pooler output) per input.\n",
    "    \"\"\"\n",
    "    if isinstance(docs, str):\n",
    "        docs = [docs]\n",
    "    else:\n",
    "        docs = list(docs)\n",
    "\n",
    "    encoded_input = tokenizer(\n",
    "        docs,\n",
    "        padding=True,\n",
    "        truncation=True,\n",
    "        max_length=512,\n",
    "        return_tensors='pt'\n",
    "    ).to(model.device)  # follow the model instead of hard-coding a device\n",
    "\n",
    "    with torch.no_grad():\n",
    "        model_output = model(**encoded_input)\n",
    "\n",
    "    embeddings = torch.nn.functional.normalize(model_output.pooler_output, p=2, dim=1)\n",
    "    return embeddings.cpu().tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f118299e-c57a-40ce-a060-5a86b452cfec",
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 128\n",
    "# Each batch comes out of the default collate as [texts, articles, points].\n",
    "loader = DataLoader(paragraphs, batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95dafd67-f6f8-41e1-a209-6325a23e066f",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i, (texts, articles, points) in enumerate(tqdm(loader)):\n",
    "    # Convert to a plain list before handing the batch to the tokenizer and Chroma.\n",
    "    texts = list(texts)\n",
    "    collection.add(\n",
    "        documents=texts,\n",
    "        metadatas=[{'doc': 'КоАП РФ', 'article': a, 'point': p} for a, p in zip(articles, points)],\n",
    "        embeddings=encode(texts),\n",
    "        # Offset by the batch start so ids stay unique across batches.\n",
    "        ids=[f'id{i * BATCH_SIZE + j}' for j in range(len(texts))],\n",
    "    )\n",
    "\n",
    "assert collection.count() == len(paragraphs), 'not every fragment was stored'"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}