from embeddings import KorRobertaEmbeddings

import streamlit as st
from streamlit import session_state as sst
from langchain_core.runnables import (
    RunnablePassthrough,
    RunnableParallel,
)

PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"]


def create_or_get_pinecone_index(index_name: str, dimension: int = 768):
    """Return the Pinecone index named `index_name`, creating it if absent."""
    from pinecone import Pinecone, ServerlessSpec

    client = Pinecone(api_key=PINECONE_API_KEY)
    if index_name in [index["name"] for index in client.list_indexes()]:
        pc_index = client.Index(index_name)
        print("☑️ Got the existing Pinecone index")
    else:
        client.create_index(
            name=index_name,
            dimension=dimension,
            metric="cosine",
            spec=ServerlessSpec(cloud="aws", region="us-west-2"),
        )
        pc_index = client.Index(index_name)
        print("☑️ Created a new Pinecone index")

    print(pc_index.describe_index_stats())
    return pc_index


def get_pinecone_vectorstore(
    index_name: str,
    embedding_fn=None,
    dimension: int = 768,
    namespace: str | None = None,
):
    """Wrap the Pinecone index in a LangChain vector store."""
    # `PineconeVectorStore` replaces the deprecated `Pinecone` class
    # in langchain_pinecone.
    from langchain_pinecone import PineconeVectorStore

    # Avoid instantiating the embedding model as a default argument
    # (defaults are evaluated once, at import time).
    if embedding_fn is None:
        embedding_fn = KorRobertaEmbeddings()

    index = create_or_get_pinecone_index(index_name, dimension)
    vs = PineconeVectorStore(
        index=index,
        embedding=embedding_fn,
        namespace=namespace,
    )
    print(vs)
    return vs


def build_pinecone_retrieval_chain(vectorstore):
    """Build a chain that returns both the retrieved context and the question."""
    retriever = vectorstore.as_retriever()
    rag_chain_with_source = RunnableParallel(
        {"context": retriever, "question": RunnablePassthrough()}
    )
    return rag_chain_with_source


@st.cache_resource
def get_pinecone_retrieval_chain(collection_name):
    """Create the retrieval chain once and cache it across reruns."""
    print("☑️ Building a new Pinecone retrieval chain...")
    embed_fn = KorRobertaEmbeddings()
    pinecone_vectorstore = get_pinecone_vectorstore(
        index_name=collection_name,
        embedding_fn=embed_fn,
        dimension=768,
        namespace="0221",
    )
    chain = build_pinecone_retrieval_chain(pinecone_vectorstore)
    return chain


def rerun():
    st.rerun()


st.title("Innocean Demo")

with st.spinner("Setting up the environment"):
    sst.retrieval_chain = get_pinecone_retrieval_chain(
        collection_name="innocean",
    )

if prompt := st.chat_input("Search for information"):
    # Display the user message in a chat message container
    with st.chat_message("human"):
        st.markdown(prompt)

    # Get the assistant response
    outputs = sst.retrieval_chain.invoke(prompt)
    print(outputs)
    retrieval_docs = outputs["context"]

    # Display the assistant response in a chat message container
    with st.chat_message("assistant"):
        st.markdown(retrieval_docs[0].metadata["answer"])
        with st.expander("View sources", expanded=True):
            st.info(f"Source page: {retrieval_docs[0].metadata['page']}")
            st.markdown(retrieval_docs[0].metadata["source_passage"])

    # tabs = st.tabs([f"doc{i}" for i in range(len(retrieval_docs))])
    # for i in range(len(retrieval_docs)):
    #     tabs[i].write(retrieval_docs[i].page_content)
    #     tabs[i].write(retrieval_docs[i].metadata)
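
# ---------------------------------------------------------------------------
# Note: `KorRobertaEmbeddings` comes from a project-local `embeddings` module
# that is not shown in this file. The commented-out sketch below illustrates
# what such a class plausibly looks like: a LangChain `Embeddings` wrapper
# around a 768-dimensional Korean RoBERTa sentence encoder. The checkpoint
# name "jhgan/ko-sroberta-multitask" is an illustrative assumption, not the
# project's confirmed model.
# ---------------------------------------------------------------------------
# from typing import List
#
# from langchain_core.embeddings import Embeddings
# from sentence_transformers import SentenceTransformer
#
#
# class KorRobertaEmbeddings(Embeddings):
#     """Embed Korean text with a sentence-transformers RoBERTa model."""
#
#     def __init__(self, model_name: str = "jhgan/ko-sroberta-multitask"):
#         # Load the encoder once; encode() returns 768-dim vectors,
#         # matching the `dimension=768` used for the Pinecone index above.
#         self.model = SentenceTransformer(model_name)
#
#     def embed_documents(self, texts: List[str]) -> List[List[float]]:
#         return self.model.encode(texts).tolist()
#
#     def embed_query(self, text: str) -> List[float]:
#         return self.model.encode([text]).tolist()[0]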