__import__('pysqlite3') import os import sys sys.modules['sqlite3'] = sys.modules.pop('pysqlite3') import chromadb import torch import gradio as gr from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline from langchain.llms import OpenAI, GigaChat from langchain.chains import LLMChain from langchain.prompts import PromptTemplate chatgpt = OpenAI( api_key=os.getenv('OPEN_AI_API_KEY'), ) gigachat = GigaChat( credentials=os.getenv('GIGACHAT_API_KEY'), scope='GIGACHAT_API_PERS', verify_ssl_certs=False ) llms = { 'ChatGPT': chatgpt, 'GigaChat': gigachat, } # задаем формат вывода модели answer_task_types = { 'Развернутый ответ': 'Ответь достаточно подробно, но не используй ничего лишнего.', 'Только цифры штрафа': 'Ответь в виде <количество> рублей или <диапазон> рублей, и больше ничего не пиши.' } # проверяем с помощью LLM валидность запроса, исключая обработку бессмысленного входа validity_template = '{query}\n\nЭто валидный запрос? Ответь да или нет, больше ничего не пиши.' validity_prompt = PromptTemplate(template=validity_template, input_variables=['query']) # получаем ответ модели на запрос, используем его для более качественного поиска Retriever'ом и Cross-Encoder'ом query_template = '{query} Ответь текстом, похожим на закон, не пиши ничего лишнего. Не используй в ответе слово КоАП РФ. Не используй слово "Россия".' query_prompt = PromptTemplate(template=query_template, input_variables=['query']) # просим LLM выбрать один из 3 фрагментов текста, выбранных поисковыми моделями, где по мнению модели есть ответ. Если ответа нет, модель нам об этом сообщает choose_answer_template = '1. {text_1}\n\n2. {text_2}\n\n3. {text_3}\n\nЗадание: выбери из перечисленных выше отрывков тот, где есть ответ на вопрос: "{query}". В качестве ответа напиши только номер 1, 2 или 3 и все. Если в данных отрывках нет ответа, то напиши "Нет ответа".' choose_answer_prompt = PromptTemplate(template=choose_answer_template, input_variables=['text_1', 'text_2', 'text_3', 'query']) # просим LLM ответить на вопрос, опираясь на найденный фрагмент, и в нужном формате, или сообщить, что ответа все-таки нет answer_template = '{text}\n\nЗадание: ответь на вопрос по тексту: "{query}". {answer_type} Если в данном тексте нет ответа, то напиши "Нет ответа".' answer_prompt = PromptTemplate(template=answer_template, input_variables=['text', 'query', 'answer_type']) client = chromadb.PersistentClient(path='db') collection = client.get_collection(name="administrative_codex") retriever_checkpoint = 'sentence-transformers/LaBSE' retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_checkpoint) retriever_model = AutoModel.from_pretrained(retriever_checkpoint) cross_encoder_checkpoint = 'jeffwan/mmarco-mMiniLMv2-L12-H384-v1' cross_encoder_model = AutoModelForSequenceClassification.from_pretrained(cross_encoder_checkpoint) cross_encoder_tokenizer = AutoTokenizer.from_pretrained(cross_encoder_checkpoint) cross_encoder = pipeline('text-classification', model=cross_encoder_model, tokenizer=cross_encoder_tokenizer) def update_llm_api_key(api_key, llm_type): if llm_type == 'ChatGPT': chatgpt = OpenAI( api_key=api_key, ) llms['ChatGPT'] = chatgpt else: gigachat = GigaChat( credentials=api_key, scope='GIGACHAT_API_PERS', verify_ssl_certs=False ) llms['GigaChat'] = gigachat def encode(docs): if type(docs) == str: docs = [docs] encoded_input = retriever_tokenizer( docs, padding=True, truncation=True, max_length=512, return_tensors='pt' ) with torch.no_grad(): model_output = retriever_model(**encoded_input) embeddings = model_output.pooler_output embeddings = torch.nn.functional.normalize(embeddings) return embeddings.detach().cpu().tolist() def re_rank(sentence, docs): return [res['score'] for res in cross_encoder([{'text': sentence, 'text_pair': doc} for doc in docs], max_length=512, truncation=True)] def update_query_with_llm(query, llm_type, use_llm_for_retriever): if llm_type == 'Без LLM' or not use_llm_for_retriever: return query llm_chain = LLMChain(prompt=query_prompt, llm=llms[llm_type]) return f'{query} {llm_chain.run(query).strip()}' def answer_with_llm(query, re_ranked_res, llm_type, llm_answer_type): if llm_type == 'Без LLM': answer, metadata, re_ranker_score = re_ranked_res[0] else: llm_chain = LLMChain(prompt=choose_answer_prompt, llm=llms[llm_type]) llm_chain_dict = {f'text_{i}': res[0] for i, res in enumerate(re_ranked_res, start=1)} llm_chain_dict['query'] = query llm_res = llm_chain.run(llm_chain_dict).strip() if 'нет ответа' in llm_res.lower() or not llm_res[0].isnumeric(): return 'Нет ответа', '', '' most_suitable_text, metadata, re_ranker_score = re_ranked_res[int(llm_res[0]) - 1] llm_chain = LLMChain(prompt=answer_prompt, llm=llms[llm_type]) answer = llm_chain.run({'text': most_suitable_text, 'query': query, 'answer_type': llm_answer_type}).strip() if 'нет ответа' in answer.lower(): answer = 'Нет ответа' # если LLM сначала выбрала фрагмент, где есть ответ, а потом не смогла ответить на вопрос (что бывает редко), то все равно порекомендуем пользователю обратиться к норме law_norm = f"{'Попробуйте обратиться к этому источнику: ' if answer == 'Нет ответа' else ''}{metadata['article']} {metadata['point']} {metadata['doc']}" return answer, law_norm, re_ranker_score def check_request_validity(func): def wrapper( chatgpt_key, gigachat_key, query, llm_type, llm_answer_type, use_llm_for_retriever, use_llm_for_request_validation ): query = query.strip() if not query: return 'Невалидный запрос', '', '' if llm_type == 'Без LLM' or not use_llm_for_request_validation: return func(chatgpt_key, gigachat_key, query, llm_type, llm_answer_type, use_llm_for_retriever) llm_chain = LLMChain(prompt=validity_prompt, llm=llms[llm_type]) if 'нет' in llm_chain.run(query).lower(): return 'Невалидный запрос', '', '' return func(chatgpt_key, gigachat_key, query, llm_type, llm_answer_type, use_llm_for_retriever) return wrapper @check_request_validity def fn( chatgpt_key, gigachat_key, query, llm_type, llm_answer_type, use_llm_for_retriever ): chatgpt_key = chatgpt_key.strip() gigachat_key = gigachat_key.strip() if chatgpt_key: update_llm_api_key(chatgpt_key, 'ChatGPT') if gigachat_key: update_llm_api_key(gigachat_key, 'GigaChat') # обогатим запрос с помощью LLM, чтобы поисковым моделям было проще найти нужный фрагмент с ответом retriever_ranker_query = update_query_with_llm(query, llm_type, use_llm_for_retriever) # Retriever-поиск по базе данных retriever_res = collection.query( query_embeddings=encode(retriever_ranker_query), n_results=10, ) top_k_docs = retriever_res['documents'][0] # re-ranking с помощью Cross-Encoder'а и отбор лучших кандидатов re_rank_scores = re_rank(retriever_ranker_query, top_k_docs) re_ranked_res = sorted( [[doc, meta, score] for doc, meta, score in zip(retriever_res['documents'][0], retriever_res['metadatas'][0], re_rank_scores)], key=lambda x: x[-1], reverse=True, )[:3] # поиск ответа и нормы с помощью LLM return answer_with_llm(query, re_ranked_res, llm_type, llm_answer_type) demo = gr.Interface( fn=fn, inputs=[ gr.Textbox(lines=1, label='Ключ Open AI API', placeholder='Введите ключ'), gr.Textbox(lines=1, label='Ключ GigaChat API', placeholder='Введите ключ'), gr.Textbox(lines=3, label='Запрос', placeholder='Введите запрос'), gr.Dropdown(label='Тип LLM', choices=['ChatGPT', 'GigaChat', 'Без LLM'], value='ChatGPT'), gr.Dropdown(label='Тип итогового ответа LLM', choices=['Только цифры штрафа', 'Развернутый ответ'], value='Только цифры штрафа'), gr.Checkbox(label="Использовать LLM для Retriever'а", value=True, info="При использовании LLM для Retriever'а ко входному запросу будет добавляться промежуточный ответ LLM на запрос. Это способствует повышению качества поиска ответа."), gr.Checkbox(label="Использовать LLM для проверки валидности запроса", value=False) ], outputs=[ gr.Textbox(label='Ответ'), gr.Textbox(label='Норма'), gr.Textbox(label="Уверенность Cross-Encoder'а"), ], ) demo.launch()