File size: 3,933 Bytes
53e0a4a
dea6e8d
 
 
bfc91a3
dea6e8d
bfc91a3
11cb4f2
 
 
 
 
a82575a
f51ae5a
a82575a
 
 
21403be
4f4d5be
 
 
 
 
 
 
 
 
 
 
 
 
0561f54
 
 
 
 
 
 
 
a6db4fa
0561f54
 
70e3b22
0561f54
a6db4fa
0561f54
 
 
 
 
 
a6db4fa
0561f54
41c21fe
0561f54
a6db4fa
0561f54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f4d5be
0561f54
 
 
 
 
4f4d5be
a6db4fa
4f4d5be
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import streamlit as st
import getpass
import os

os.environ["OPENAI_API_KEY"] = st.secrets['OPENAI_API_KEY']  # agregada en la config de hugginface
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_API_KEY"] = st.secrets['OPENAI_API_KEY']

from langchain.prompts import PromptTemplate
from langchain.chains.llm import LLMChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import ChatOpenAI
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import HuggingFaceDatasetLoader
from langchain_community.embeddings import HuggingFaceEmbeddings

# Initialization
if 'chain' not in st.session_state:
    st.session_state['chain'] = 'dummy'

def get_data():
    return st.session_state["chain"]


def add_data(chain):
    st.session_state["chain"]= chain

chain = get_data()
if chain == 'dummy':    
    #Carga de DATASET
    dataset_name = "Waflon/FAQ"
    page_content_column = "respuestas"
    loader = HuggingFaceDatasetLoader(dataset_name, page_content_column)
    data = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    #Transformado a tipo de dato especifico para esto
    docs = text_splitter.split_documents(data)

    #Modelo QA sentence similarity
    modelPath = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2' #español
    model_kwargs = {'device':'cpu'} # cuda or cpu
    encode_kwargs = {'normalize_embeddings': False}

    #Embeddings que transforman a vectores densos multidimensionales las preguntas del SII
    embeddings = HuggingFaceEmbeddings(
        model_name=modelPath,     # Ruta a modelo Pre entrenado
        model_kwargs=model_kwargs, # Opciones de configuracion del modelo
        encode_kwargs=encode_kwargs # Opciones de Encoding
    )

    #DB y retriever
    db = FAISS.from_documents(docs, embeddings)  # Create a retriever object from the 'db' with a search configuration where it retrieves up to 4 relevant splits/documents.
    retriever = db.as_retriever(search_kwargs={"k": 3})

    prompt_template = """Usa los siguientes fragmentos de contextos para responder una pregunta al final. Por favor sigue las siguientes reglas:
    1. Si la pregunta requiere vinculos, por favor retornar solamente las vinculos de los vinculos sin respuesta
    2. Si no sabes la respuesta, no inventes una respuesta. Solamente di **No pude encontrar la respuesta definitiva, pero tal vez quieras ver los siguientes vinculos** y agregalos a la lista de vinculos.
    3. Si encuentras la respuesta, escribe una respuesta concisa y agrega la lista de vinculos que sean usadas **directamente** para derivar la respuesta. Excluye los vinculos que sean irrelevantes al final de la respuesta
    
    {contexto}
    
    Pregunta: {question}
    Respuesta Util:"""
    
    
    QA_CHAIN_PROMPT = PromptTemplate.from_template(prompt_template) # prompt_template defined above
    llm_chain = LLMChain(llm=ChatOpenAI(), prompt=QA_CHAIN_PROMPT, callbacks=None, verbose=True)
    document_prompt = PromptTemplate(
        input_variables=["page_content", "url"],
        template="Contexto:\n{page_content}\nVinculo: {url}",
    )
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=llm_chain,
        document_variable_name="contexto",
        document_prompt=document_prompt,
        callbacks=None,
    )
    chain = RetrievalQA(
        combine_documents_chain=combine_documents_chain,
        callbacks=None,
        verbose=True,
        retriever=retriever,
    )
    add_data(chain)

pregunta = st.text_area('Ingresa algun texto:', value="Que es un APA?")
tmp_button = st.button("CLICK")
if tmp_button: #Esperar al boton
    out = chain.invoke(pregunta)
    st.write(out)
    #st.rerun() #Restart app
else:
    st.stop()