VladimirVorobev's picture
Create app.py
96b116a
raw
history blame
9.62 kB
__import__('pysqlite3')
import sys
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
import chromadb
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline
from langchain.llms import OpenAI, GigaChat
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
chatgpt = OpenAI(
api_key='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',
)
gigachat = GigaChat(
credentials='Y2Y4Yjk5ODUtNThmMC00ODdjLTk5ODItNDdmYzhmNDdmNzE0OjQ5Y2RjNTVkLWFmMGQtNGJlYy04OGNiLTI1Yzc3MmJkMzYwYw==',
scope='GIGACHAT_API_PERS',
verify_ssl_certs=False
)
llms = {
'ChatGPT': chatgpt,
'GigaChat': gigachat,
}
# задаем формат вывода модели
answer_task_types = {
'Развернутый ответ': 'Ответь достаточно подробно, но не используй ничего лишнего.',
'Только цифры штрафа': 'Ответь в виде <количество> рублей или <диапазон> рублей, и больше ничего не пиши.'
}
# проверяем с помощью LLM валидность запроса, исключая обработку бессмысленного входа
validity_template = '{query}\n\nЭто валидный запрос? Ответь да или нет, больше ничего не пиши.'
validity_prompt = PromptTemplate(template=validity_template, input_variables=['query'])
# получаем ответ модели на запрос, используем его для более качественного поиска Retriever'ом и Cross-Encoder'ом
query_template = '{query} Ответь текстом, похожим на закон, не пиши ничего лишнего. Не используй в ответе слово КоАП РФ. Не используй слово "Россия".'
query_prompt = PromptTemplate(template=query_template, input_variables=['query'])
# просим LLM выбрать один из 3 фрагментов текста, выбранных поисковыми моделями, где по мнению модели есть ответ. Если ответа нет, модель нам об этом сообщает
choose_answer_template = '1. {text_1}\n\n2. {text_2}\n\n3. {text_3}\n\nЗадание: выбери из перечисленных выше отрывков тот, где есть ответ на вопрос: "{query}". В качестве ответа напиши только номер 1, 2 или 3 и все. Если в данных отрывках нет ответа, то напиши "Нет ответа".'
choose_answer_prompt = PromptTemplate(template=choose_answer_template, input_variables=['text_1', 'text_2', 'text_3', 'query'])
# просим LLM ответить на вопрос, опираясь на найденный фрагмент, и в нужном формате, или сообщить, что ответа все-таки нет
answer_template = '{text}\n\nЗадание: ответь на вопрос по тексту: "{query}". {answer_type} Если в данном тексте нет ответа, то напиши "Нет ответа".'
answer_prompt = PromptTemplate(template=answer_template, input_variables=['text', 'query', 'answer_type'])
client = chromadb.PersistentClient(path='db')
collection = client.get_collection(name="administrative_codex")
retriever_checkpoint = 'sentence-transformers/LaBSE'
retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_checkpoint)
retriever_model = AutoModel.from_pretrained(retriever_checkpoint)
cross_encoder_checkpoint = 'jeffwan/mmarco-mMiniLMv2-L12-H384-v1'
cross_encoder_model = AutoModelForSequenceClassification.from_pretrained(cross_encoder_checkpoint)
cross_encoder_tokenizer = AutoTokenizer.from_pretrained(cross_encoder_checkpoint)
cross_encoder = pipeline('text-classification', model=cross_encoder_model, tokenizer=cross_encoder_tokenizer)
def encode(docs):
if type(docs) == str:
docs = [docs]
encoded_input = retriever_tokenizer(
docs,
padding=True,
truncation=True,
max_length=512,
return_tensors='pt'
)
with torch.no_grad():
model_output = retriever_model(**encoded_input)
embeddings = model_output.pooler_output
embeddings = torch.nn.functional.normalize(embeddings)
return embeddings.detach().cpu().tolist()
def re_rank(sentence, docs):
return [res['score'] for res in cross_encoder([{'text': sentence, 'text_pair': doc} for doc in docs], max_length=512, truncation=True)]
def update_query_with_llm(query, llm_type, use_llm_for_retriever):
if llm_type == 'Без LLM' or not use_llm_for_retriever:
return query
llm_chain = LLMChain(prompt=query_prompt, llm=llms[llm_type])
return f'{query} {llm_chain.run(query).strip()}'
def answer_with_llm(query, re_ranked_res, llm_type, llm_answer_type):
if llm_type == 'Без LLM':
answer, metadata, re_ranker_score = re_ranked_res[0]
else:
llm_chain = LLMChain(prompt=choose_answer_prompt, llm=llms[llm_type])
llm_chain_dict = {f'text_{i}': res[0] for i, res in enumerate(re_ranked_res, start=1)}
llm_chain_dict['query'] = query
llm_res = llm_chain.run(llm_chain_dict).strip()
if 'нет ответа' in llm_res.lower() or not llm_res[0].isnumeric():
return 'Нет ответа', '', ''
most_suitable_text, metadata, re_ranker_score = re_ranked_res[int(llm_res[0]) - 1]
llm_chain = LLMChain(prompt=answer_prompt, llm=llms[llm_type])
answer = llm_chain.run({'text': most_suitable_text, 'query': query, 'answer_type': llm_answer_type}).strip()
if 'нет ответа' in answer.lower():
answer = 'Нет ответа'
# если LLM сначала выбрала фрагмент, где есть ответ, а потом не смогла ответить на вопрос (что бывает редко), то все равно порекомендуем пользователю обратиться к норме
law_norm = f"{'Попробуйте обратиться к этому источнику: ' if answer == 'Нет ответа' else ''}{metadata['article']} {metadata['point']} {metadata['doc']}"
return answer, law_norm, re_ranker_score
def check_request_validity(func):
def wrapper(
query,
llm_type,
llm_answer_type,
use_llm_for_retriever,
use_llm_for_request_validation
):
query = query.strip()
if not query:
return 'Невалидный запрос', '', ''
if llm_type == 'Без LLM' or not use_llm_for_request_validation:
return func(query, llm_type, llm_answer_type, use_llm_for_retriever)
llm_chain = LLMChain(prompt=validity_prompt, llm=llms[llm_type])
if 'нет' in llm_chain.run(query).lower():
return 'Невалидный запрос', '', ''
return func(query, llm_type, llm_answer_type, use_llm_for_retriever)
return wrapper
@check_request_validity
def fn(
query,
llm_type,
llm_answer_type,
use_llm_for_retriever
):
# обогатим запрос с помощью LLM, чтобы поисковым моделям было проще найти нужный фрагмент с ответом
retriever_ranker_query = update_query_with_llm(query, llm_type, use_llm_for_retriever)
# Retriever-поиск по базе данных
retriever_res = collection.query(
query_embeddings=encode(retriever_ranker_query),
n_results=10,
)
top_k_docs = retriever_res['documents'][0]
# re-ranking с помощью Cross-Encoder'а и отбор лучших кандидатов
re_rank_scores = re_rank(retriever_ranker_query, top_k_docs)
re_ranked_res = sorted(
[[doc, meta, score] for doc, meta, score in zip(retriever_res['documents'][0], retriever_res['metadatas'][0], re_rank_scores)],
key=lambda x: x[-1],
reverse=True,
)[:3]
# поиск ответа и нормы с помощью LLM
return answer_with_llm(query, re_ranked_res, llm_type, llm_answer_type)
demo = gr.Interface(
fn=fn,
inputs=[
gr.Textbox(lines=3, label='Запрос', placeholder='Введите запрос'),
gr.Dropdown(label='Тип LLM', choices=['ChatGPT', 'GigaChat', 'Без LLM'], value='ChatGPT'),
gr.Dropdown(label='Тип итогового ответа LLM', choices=['Только цифры штрафа', 'Развернутый ответ'], value='Только цифры штрафа'),
gr.Checkbox(label="Использовать LLM для Retriever'а", value=True, info="При использовании LLM для Retriever'а ко входному запросу будет добавляться промежуточный ответ LLM на запрос. Это способствует повышению качества поиска ответа."),
gr.Checkbox(label="Использовать LLM для проверки валидности запроса", value=False)
],
outputs=[
gr.Textbox(label='Ответ'),
gr.Textbox(label='Норма'),
gr.Textbox(label="Уверенность Cross-Encoder'а"),
],
)
demo.launch()