# BOE / app.py
# JoThanos: "add sentencepiece" (d6a952e)
import os
import torch
import gradio as gr
from textwrap import fill
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain import PromptTemplate
from langchain import HuggingFacePipeline
from langchain.vectorstores import Chroma
from langchain.schema import AIMessage, HumanMessage
from langchain.memory import ConversationBufferMemory
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import UnstructuredMarkdownLoader, UnstructuredURLLoader
from langchain.chains import LLMChain, SimpleSequentialChain, RetrievalQA, ConversationalRetrievalChain
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, pipeline
import warnings
from huggingface_hub import login
# Suppress all Python warnings for the rest of the process.
# NOTE(review): this also hides deprecation warnings from langchain/transformers.
warnings.filterwarnings('ignore')

# Authenticate with the Hugging Face Hub using the `huggingface_token` env var.
huggingface_token = os.getenv('huggingface_token')
if huggingface_token:
    login(token=huggingface_token)
else:
    # login(None) would fall back to an interactive prompt, which cannot be
    # answered in a headless deployment. Skip it and let model loading report
    # any access error (e.g. for gated models) instead.
    print("huggingface_token is not set; skipping Hugging Face Hub login.")
# Generator LLM and the multilingual embedding model used for retrieval
# (the source document and expected questions are in Spanish).
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
EMBEDDING_MODEL = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'

# Load the LLM weights in 4-bit NF4 with double quantization to reduce memory
# use; computation runs in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
# Reuse EOS as the padding token (presumably because the tokenizer defines none).
tokenizer.pad_token = tokenizer.eos_token

# Mistral is implemented natively in transformers, so no remote code is needed;
# trust_remote_code=True would execute arbitrary Python shipped in the Hub repo.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
    quantization_config=quantization_config,
)

# Start from the model's published generation defaults, then override them.
generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
generation_config.max_new_tokens = 1024
# A near-zero temperature with sampling enabled makes decoding effectively greedy.
generation_config.temperature = 0.0001
generation_config.top_p = 0.95
generation_config.do_sample = True
generation_config.repetition_penalty = 1.15
# HuggingFacePipeline needs an instantiated transformers Pipeline; passing the
# `pipeline` factory function itself gives LangChain nothing that can generate.
text_generation_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    # Raw output includes the prompt; LangChain's HuggingFacePipeline is expected
    # to strip it. NOTE(review): confirm against the installed langchain version.
    return_full_text=True,
    generation_config=generation_config,
)
llm = HuggingFacePipeline(pipeline=text_generation_pipeline)

# Embeds both the document chunks (index time) and the queries (retrieval time).
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
# Knowledge base: the BOE (Boletín Oficial del Estado) entry(ies) to answer from.
urls = [
    "https://www.boe.es/diario_boe/txt.php?id=BOE-A-2024-9523",
]
loader = UnstructuredURLLoader(urls=urls)
documents = loader.load()

# Split into ~1000-character chunks with 200 characters of overlap so a passage
# cut at a chunk boundary still appears whole in one of the neighbours.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
texts_chunks = text_splitter.split_documents(documents)
if not texts_chunks:
    # The URL loader may skip pages it fails to fetch; fail fast with a clear
    # message instead of an obscure error from the vector store.
    raise RuntimeError(f"No text could be extracted from {urls}; cannot build the vector store.")

# Kept in memory: the index is rebuilt from the URL on every start, and adding
# the same chunks again to a persisted collection would accumulate duplicates
# across restarts (so top-k retrieval could return the same passage twice).
db = Chroma.from_documents(texts_chunks, embeddings)
# Condense-question prompt for ConversationalRetrievalChain: it merges the chat
# history and the follow-up message into one standalone question. That output
# replaces the user's question for both retrieval and answering, so this prompt
# must only rephrase. Asking the model to *answer* here would substitute an
# answer (or an apology) for the question.
template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question. Keep the standalone question in Spanish.
Chat History
{chat_history}
Follow Up Input: {question}
Standalone question:
"""
CUSTOM_QUESTION_PROMPT = PromptTemplate.from_template(template)
# Prompt for the answering step; the retrieved chunks are injected as {context}.
QA_PROMPT = PromptTemplate.from_template(
    """Act as a lawyer assistant expert. Use the following pieces of context to answer the question at the end.
You must always answer in Spanish. If you do not know the answer, reply with 'I am sorry, I dont have enough information'.

{context}

Question: {question}
Answer:"""
)

# One stateless chain shared by all requests. No LangChain memory is attached,
# because Gradio passes each session's conversation to `querying`. This keeps
# follow-up questions in context without mixing histories across users.
llm_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=db.as_retriever(search_kwargs={"k": 2}),
    condense_question_prompt=CUSTOM_QUESTION_PROMPT,
    combine_docs_chain_kwargs={"prompt": QA_PROMPT},
)


def querying(query, history):
    """Answer ``query`` with retrieval-augmented generation over the BOE text.

    Args:
        query: The user's latest chat message.
        history: Earlier turns as supplied by ``gr.ChatInterface``. Assumed to
            be ``[user_message, bot_message]`` pairs (Gradio's "tuples" format).

    Returns:
        The chain's answer with surrounding whitespace stripped.
    """
    # LangChain expects (human, ai) tuples; skip incomplete/empty turns.
    chat_history = [(user, bot) for user, bot in (history or []) if user and bot]
    result = llm_chain({"question": query, "chat_history": chat_history})
    return result["answer"].strip()
# Chat UI. With cache_examples=True, Gradio runs `querying` on every example at
# startup and stores the replies, so clicking an example shows the cached answer.
iface = gr.ChatInterface(
    fn=querying,
    chatbot=gr.Chatbot(height=600),
    textbox=gr.Textbox(placeholder="¿Cuántos segmentos hay y en qué consisten?", container=False, scale=7),
    title="LawyerBot",
    theme="soft",
    examples=[
        "¿Cuántos segmentos hay?",
        "¿Qué importe del bono digital corresponde a cada uno de los 5 segmentos?",
        "¿Cuál es el importe de la ayuda para el segmento III en cuanto a dispositivo hardware?",
        "Si tengo una microempresa de 2 empleados, ¿qué importe del bono digital me corresponde?",
        "¿Qué nuevos segmentos de beneficiarios se han introducido?",
    ],
    cache_examples=True,
    # Spanish labels for the built-in buttons (Gradio 4 ChatInterface arguments).
    retry_btn="Repetir",
    undo_btn="Deshacer",
    clear_btn="Borrar",
    submit_btn="Enviar",
)
# share=True requests a public share link when running locally.
iface.launch(share=True)