|
__import__('pysqlite3') |
|
import sys |
|
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3') |
|
import chromadb |
|
import torch |
|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline |
|
|
|
from langchain.llms import OpenAI, GigaChat |
|
from langchain.chains import LLMChain |
|
from langchain.prompts import PromptTemplate |
|
|
|
chatgpt = OpenAI( |
|
api_key='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', |
|
) |
|
|
|
gigachat = GigaChat( |
|
credentials='Y2Y4Yjk5ODUtNThmMC00ODdjLTk5ODItNDdmYzhmNDdmNzE0OjQ5Y2RjNTVkLWFmMGQtNGJlYy04OGNiLTI1Yzc3MmJkMzYwYw==', |
|
scope='GIGACHAT_API_PERS', |
|
verify_ssl_certs=False |
|
) |
|
|
|
llms = { |
|
'ChatGPT': chatgpt, |
|
'GigaChat': gigachat, |
|
} |
|
|
|
|
|
answer_task_types = { |
|
'Развернутый ответ': 'Ответь достаточно подробно, но не используй ничего лишнего.', |
|
'Только цифры штрафа': 'Ответь в виде <количество> рублей или <диапазон> рублей, и больше ничего не пиши.' |
|
} |
|
|
|
|
|
validity_template = '{query}\n\nЭто валидный запрос? Ответь да или нет, больше ничего не пиши.' |
|
validity_prompt = PromptTemplate(template=validity_template, input_variables=['query']) |
|
|
|
|
|
query_template = '{query} Ответь текстом, похожим на закон, не пиши ничего лишнего. Не используй в ответе слово КоАП РФ. Не используй слово "Россия".' |
|
query_prompt = PromptTemplate(template=query_template, input_variables=['query']) |
|
|
|
|
|
choose_answer_template = '1. {text_1}\n\n2. {text_2}\n\n3. {text_3}\n\nЗадание: выбери из перечисленных выше отрывков тот, где есть ответ на вопрос: "{query}". В качестве ответа напиши только номер 1, 2 или 3 и все. Если в данных отрывках нет ответа, то напиши "Нет ответа".' |
|
choose_answer_prompt = PromptTemplate(template=choose_answer_template, input_variables=['text_1', 'text_2', 'text_3', 'query']) |
|
|
|
|
|
answer_template = '{text}\n\nЗадание: ответь на вопрос по тексту: "{query}". {answer_type} Если в данном тексте нет ответа, то напиши "Нет ответа".' |
|
answer_prompt = PromptTemplate(template=answer_template, input_variables=['text', 'query', 'answer_type']) |
|
|
|
client = chromadb.PersistentClient(path='db') |
|
collection = client.get_collection(name="administrative_codex") |
|
|
|
retriever_checkpoint = 'sentence-transformers/LaBSE' |
|
retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_checkpoint) |
|
retriever_model = AutoModel.from_pretrained(retriever_checkpoint) |
|
|
|
cross_encoder_checkpoint = 'jeffwan/mmarco-mMiniLMv2-L12-H384-v1' |
|
cross_encoder_model = AutoModelForSequenceClassification.from_pretrained(cross_encoder_checkpoint) |
|
cross_encoder_tokenizer = AutoTokenizer.from_pretrained(cross_encoder_checkpoint) |
|
cross_encoder = pipeline('text-classification', model=cross_encoder_model, tokenizer=cross_encoder_tokenizer) |
|
|
|
|
|
def encode(docs): |
|
if type(docs) == str: |
|
docs = [docs] |
|
|
|
encoded_input = retriever_tokenizer( |
|
docs, |
|
padding=True, |
|
truncation=True, |
|
max_length=512, |
|
return_tensors='pt' |
|
) |
|
|
|
with torch.no_grad(): |
|
model_output = retriever_model(**encoded_input) |
|
|
|
embeddings = model_output.pooler_output |
|
embeddings = torch.nn.functional.normalize(embeddings) |
|
return embeddings.detach().cpu().tolist() |
|
|
|
|
|
def re_rank(sentence, docs): |
|
return [res['score'] for res in cross_encoder([{'text': sentence, 'text_pair': doc} for doc in docs], max_length=512, truncation=True)] |
|
|
|
|
|
def update_query_with_llm(query, llm_type, use_llm_for_retriever): |
|
if llm_type == 'Без LLM' or not use_llm_for_retriever: |
|
return query |
|
|
|
llm_chain = LLMChain(prompt=query_prompt, llm=llms[llm_type]) |
|
return f'{query} {llm_chain.run(query).strip()}' |
|
|
|
|
|
def answer_with_llm(query, re_ranked_res, llm_type, llm_answer_type): |
|
if llm_type == 'Без LLM': |
|
answer, metadata, re_ranker_score = re_ranked_res[0] |
|
else: |
|
llm_chain = LLMChain(prompt=choose_answer_prompt, llm=llms[llm_type]) |
|
llm_chain_dict = {f'text_{i}': res[0] for i, res in enumerate(re_ranked_res, start=1)} |
|
llm_chain_dict['query'] = query |
|
|
|
llm_res = llm_chain.run(llm_chain_dict).strip() |
|
|
|
if 'нет ответа' in llm_res.lower() or not llm_res[0].isnumeric(): |
|
return 'Нет ответа', '', '' |
|
|
|
most_suitable_text, metadata, re_ranker_score = re_ranked_res[int(llm_res[0]) - 1] |
|
|
|
llm_chain = LLMChain(prompt=answer_prompt, llm=llms[llm_type]) |
|
answer = llm_chain.run({'text': most_suitable_text, 'query': query, 'answer_type': llm_answer_type}).strip() |
|
|
|
if 'нет ответа' in answer.lower(): |
|
answer = 'Нет ответа' |
|
|
|
|
|
law_norm = f"{'Попробуйте обратиться к этому источнику: ' if answer == 'Нет ответа' else ''}{metadata['article']} {metadata['point']} {metadata['doc']}" |
|
return answer, law_norm, re_ranker_score |
|
|
|
|
|
def check_request_validity(func): |
|
def wrapper( |
|
query, |
|
llm_type, |
|
llm_answer_type, |
|
use_llm_for_retriever, |
|
use_llm_for_request_validation |
|
): |
|
query = query.strip() |
|
|
|
if not query: |
|
return 'Невалидный запрос', '', '' |
|
|
|
if llm_type == 'Без LLM' or not use_llm_for_request_validation: |
|
return func(query, llm_type, llm_answer_type, use_llm_for_retriever) |
|
|
|
llm_chain = LLMChain(prompt=validity_prompt, llm=llms[llm_type]) |
|
|
|
if 'нет' in llm_chain.run(query).lower(): |
|
return 'Невалидный запрос', '', '' |
|
|
|
return func(query, llm_type, llm_answer_type, use_llm_for_retriever) |
|
|
|
return wrapper |
|
|
|
|
|
@check_request_validity |
|
def fn( |
|
query, |
|
llm_type, |
|
llm_answer_type, |
|
use_llm_for_retriever |
|
): |
|
|
|
retriever_ranker_query = update_query_with_llm(query, llm_type, use_llm_for_retriever) |
|
|
|
|
|
retriever_res = collection.query( |
|
query_embeddings=encode(retriever_ranker_query), |
|
n_results=10, |
|
) |
|
|
|
top_k_docs = retriever_res['documents'][0] |
|
|
|
|
|
re_rank_scores = re_rank(retriever_ranker_query, top_k_docs) |
|
re_ranked_res = sorted( |
|
[[doc, meta, score] for doc, meta, score in zip(retriever_res['documents'][0], retriever_res['metadatas'][0], re_rank_scores)], |
|
key=lambda x: x[-1], |
|
reverse=True, |
|
)[:3] |
|
|
|
|
|
return answer_with_llm(query, re_ranked_res, llm_type, llm_answer_type) |
|
|
|
|
|
demo = gr.Interface( |
|
fn=fn, |
|
inputs=[ |
|
gr.Textbox(lines=3, label='Запрос', placeholder='Введите запрос'), |
|
gr.Dropdown(label='Тип LLM', choices=['ChatGPT', 'GigaChat', 'Без LLM'], value='ChatGPT'), |
|
gr.Dropdown(label='Тип итогового ответа LLM', choices=['Только цифры штрафа', 'Развернутый ответ'], value='Только цифры штрафа'), |
|
gr.Checkbox(label="Использовать LLM для Retriever'а", value=True, info="При использовании LLM для Retriever'а ко входному запросу будет добавляться промежуточный ответ LLM на запрос. Это способствует повышению качества поиска ответа."), |
|
gr.Checkbox(label="Использовать LLM для проверки валидности запроса", value=False) |
|
], |
|
outputs=[ |
|
gr.Textbox(label='Ответ'), |
|
gr.Textbox(label='Норма'), |
|
gr.Textbox(label="Уверенность Cross-Encoder'а"), |
|
], |
|
) |
|
|
|
demo.launch() |
|
|