import os
import torch
import gradio as gr
import pandas as pd
from dotenv import load_dotenv
from langchain.callbacks.base import BaseCallbackHandler
from langchain.embeddings import CacheBackedEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.storage import LocalFileStore
from langchain_anthropic import ChatAnthropic
from langchain_community.chat_models import ChatOllama
from langchain_community.document_loaders import (
    NotebookLoader,
    TextLoader,
    DataFrameLoader,
)
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers.language.language_parser import (
    LanguageParser,
)
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS, Chroma
from langchain_core.callbacks.manager import CallbackManager
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import (
    ConfigurableField,
    RunnablePassthrough,
    RunnableLambda,
)
from langchain_google_genai import GoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
from langchain_cohere import CohereRerank
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_community.document_transformers import LongContextReorder

# Load environment variables
load_dotenv()
# Repository directories
repo_root_dir = "./docs/langchain"
repo_dirs = [
    "libs/core/langchain_core",
    "libs/community/langchain_community",
    "libs/experimental/langchain_experimental",
    "libs/partners",
    "libs/cookbook",
]
repo_dirs = [os.path.join(repo_root_dir, repo) for repo in repo_dirs]
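# NOTE: the paths above assume a local checkout of the langchain monorepo under
# ./docs/langchain (libs/core, libs/community, ...); adjust repo_root_dir if your
# clone lives elsewhere.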
# Load Python documents
py_documents = []
for path in repo_dirs:
    py_loader = GenericLoader.from_filesystem(
        path,
        glob="**/*",
        suffixes=[".py"],
        parser=LanguageParser(language=Language.PYTHON, parser_threshold=30),
    )
    py_documents.extend(py_loader.load())
print(f"Number of .py files: {len(py_documents)}")
# Load Markdown documents
mdx_documents = []
for dirpath, _, filenames in os.walk(repo_root_dir):
    for file in filenames:
        # skip files inside virtual-environment directories
        if file.endswith(".mdx") and "venv" not in dirpath:
            try:
                mdx_loader = TextLoader(os.path.join(dirpath, file), encoding="utf-8")
                mdx_documents.extend(mdx_loader.load())
            except Exception:
                pass
print(f"Number of .mdx files: {len(mdx_documents)}")

# Load Jupyter Notebook documents
ipynb_documents = []
for dirpath, _, filenames in os.walk(repo_root_dir):
    for file in filenames:
        if file.endswith(".ipynb") and "venv" not in dirpath:
            try:
                ipynb_loader = NotebookLoader(
                    os.path.join(dirpath, file),
                    include_outputs=True,
                    max_output_length=20,
                    remove_newline=True,
                )
                ipynb_documents.extend(ipynb_loader.load())
            except Exception:
                pass
print(f"Number of .ipynb files: {len(ipynb_documents)}")
# Load Wikidocs documents
df = pd.read_parquet("./docs/wikidocs_14314.parquet")
loader = DataFrameLoader(df, page_content_column="content")
wiki_documents = loader.load()
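# NOTE: DataFrameLoader turns each row of the parquet file into a Document, using the
# "content" column as page_content and the remaining columns as metadata.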
# Split documents into chunks
def split_documents(documents, language, chunk_size=2000, chunk_overlap=200):
    splitter = RecursiveCharacterTextSplitter.from_language(
        language=language, chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    return splitter.split_documents(documents)
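# NOTE: from_language() picks language-aware separators (e.g. "def "/"class " for
# Python, heading markers for Markdown) so chunks tend to break on semantic boundaries.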
py_docs = split_documents(py_documents, Language.PYTHON)
mdx_docs = split_documents(mdx_documents, Language.MARKDOWN)
ipynb_docs = split_documents(ipynb_documents, Language.PYTHON)
wiki_docs = split_documents(wiki_documents, Language.MARKDOWN)

print(f"Number of split .py documents: {len(py_docs)}")
print(f"Number of split .mdx documents: {len(mdx_docs)}")
print(f"Number of split .ipynb documents: {len(ipynb_docs)}")
print(f"Number of split wiki documents: {len(wiki_docs)}")

combined_documents = py_docs + mdx_docs + ipynb_docs + wiki_docs
print(f"Total number of documents: {len(combined_documents)}")
# Define the device setting function
def get_device():
    if torch.cuda.is_available():
        return "cuda:0"
    elif torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"

# Use the function to set the device in model_kwargs
device = get_device()
# Initialize embeddings and cache
# expanduser so the cache lands in the real home directory instead of a literal "~" folder
store = LocalFileStore(os.path.expanduser("~/.cache/embedding"))
embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-m3",
    model_kwargs={"device": device},
    encode_kwargs={"normalize_embeddings": True},
)
cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
    embeddings, store, namespace=embeddings.model_name
)
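# NOTE: CacheBackedEmbeddings writes each chunk's embedding to disk keyed by content
# hash and namespace, so re-running the script skips re-embedding unchanged documents.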
# Create and save FAISS index
FAISS_DB_INDEX = "./langchain_faiss"
if not os.path.exists(FAISS_DB_INDEX):
    faiss_db = FAISS.from_documents(
        documents=combined_documents,
        embedding=cached_embeddings,
    )
    faiss_db.save_local(folder_path=FAISS_DB_INDEX)

# Create and save Chroma index
CHROMA_DB_INDEX = "./langchain_chroma"
if not os.path.exists(CHROMA_DB_INDEX):
    chroma_db = Chroma.from_documents(
        documents=combined_documents,
        embedding=cached_embeddings,
        persist_directory=CHROMA_DB_INDEX,
    )
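# NOTE: both indexes are built only when their directories are missing; delete
# ./langchain_faiss and ./langchain_chroma to force a rebuild after the corpus changes.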
# Load vectorstores
faiss_db = FAISS.load_local(
    FAISS_DB_INDEX, cached_embeddings, allow_dangerous_deserialization=True
)
chroma_db = Chroma(
    embedding_function=cached_embeddings,
    persist_directory=CHROMA_DB_INDEX,
)
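# NOTE: allow_dangerous_deserialization is required because FAISS.load_local unpickles
# the docstore; it is acceptable here since the index was written by this same script.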
# Create retrievers
faiss_retriever = faiss_db.as_retriever(search_type="mmr", search_kwargs={"k": 10})
chroma_retriever = chroma_db.as_retriever(
    search_type="similarity", search_kwargs={"k": 10}
)
bm25_retriever = BM25Retriever.from_documents(combined_documents)
bm25_retriever.k = 10
ensemble_retriever = EnsembleRetriever(
    retrievers=[
        bm25_retriever.with_config(run_name="bm25"),
        faiss_retriever.with_config(run_name="faiss"),
        chroma_retriever.with_config(run_name="chroma"),
    ],
    weights=[0.4, 0.3, 0.3],
)
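# NOTE: EnsembleRetriever fuses the three result lists with weighted Reciprocal Rank
# Fusion; keyword search (BM25) gets slightly more weight (0.4) than either dense retriever (0.3).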
compressor = CohereRerank(model="rerank-multilingual-v3.0", top_n=10)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor,
    base_retriever=ensemble_retriever,
)
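# NOTE: CohereRerank re-scores the fused candidates with Cohere's multilingual reranker
# and keeps the top_n=10; this requires a Cohere API key (COHERE_API_KEY) in the environment.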
# Create prompt template
prompt = PromptTemplate.from_template(
    """You are an AI developer with 20 years of experience. Your task is to answer the given question using the provided documents as much as possible.
The documents contain information about Python code, so include Python code snippets and concrete explanations in your answer.
Write the answer as thoroughly and clearly as possible, in easy-to-understand Korean.
If the answer cannot be found in the provided documents, reply: "It is hard to give an accurate answer from the question alone. If you provide more information, I can help you better. Please feel free to ask anytime!"
Always include the source(s) for each answer.
# Reference documents:
{context}
# Question:
{question}
# Answer:
Sources:
- source1
- source2
- ...
"""
)
# Define callback handler for streaming
class StreamCallback(BaseCallbackHandler):
    def on_llm_new_token(self, token: str, **kwargs):
        print(token, end="", flush=True)

streaming = os.getenv("STREAMING", "true") == "true"
print("STREAMING", streaming)

# Initialize LLMs with configuration
llm = ChatOpenAI(
    model="gpt-4o",
    temperature=0,
    streaming=streaming,
    callbacks=[StreamCallback()],
).configurable_alternatives(
    ConfigurableField(id="llm"),
    default_key="gpt4",
    claude=ChatAnthropic(
        model="claude-3-opus-20240229",
        temperature=0,
        streaming=True,
        callbacks=[StreamCallback()],
    ),
    gpt3=ChatOpenAI(
        model="gpt-3.5-turbo",
        temperature=0,
        streaming=True,
        callbacks=[StreamCallback()],
    ),
    gemini=GoogleGenerativeAI(
        model="gemini-1.5-flash",
        temperature=0,
        streaming=True,
        callbacks=[StreamCallback()],
    ),
    llama3=ChatGroq(
        model_name="llama3-70b-8192",
        temperature=0,
        streaming=True,
        callbacks=[StreamCallback()],
    ),
    ollama=ChatOllama(
        model="EEVE-Korean-10.8B:long",
        callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
    ),
)
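# NOTE: configurable_alternatives exposes the model choice as a runtime config field
# ("llm"), so the chain below can switch between gpt4/claude/gpt3/gemini/llama3/ollama
# via .with_config(configurable={"llm": ...}); each backend needs its own API key in .env.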
# Create retrieval-augmented generation chain
rag_chain = (
    {
        "context": compression_retriever
        | RunnableLambda(LongContextReorder().transform_documents),
        "question": RunnablePassthrough(),
    }
    | prompt
    | llm
    | StrOutputParser()
)
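# NOTE: LongContextReorder places the highest-scoring documents at the start and end of
# the context and weaker ones in the middle, countering the "lost in the middle" effect.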
model_key = os.getenv("MODEL_KEY", "gemini")
print("MODEL_KEY", model_key)

def respond_stream(
    message,
    history: list[tuple[str, str]],
):
    response = ""
    for chunk in rag_chain.with_config(configurable={"llm": model_key}).stream(message):
        response += chunk
        yield response

def respond(
    message,
    history: list[tuple[str, str]],
):
    return rag_chain.with_config(configurable={"llm": model_key}).invoke(message)
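# NOTE: respond_stream yields the growing answer so Gradio can render tokens
# incrementally; respond returns the full answer in a single call. Which one is used
# depends on the STREAMING environment variable below.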
""" | |
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface | |
""" | |
demo = gr.ChatInterface(
    respond_stream if streaming else respond,
    title="Ask me anything about LangChain!",
    description="Hello! I'm an AI QA bot for LangChain, with in-depth knowledge of the framework. If you need help with LangChain development, feel free to ask anytime!",
)
if __name__ == "__main__":
    demo.launch()