import os
import torch
import gradio as gr
from dotenv import load_dotenv
from langchain.callbacks.base import BaseCallbackHandler
from langchain.embeddings import CacheBackedEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.storage import LocalFileStore
from langchain_anthropic import ChatAnthropic
from langchain_community.chat_models import ChatOllama
from langchain_community.document_loaders import NotebookLoader, TextLoader
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers.language.language_parser import (
    LanguageParser,
)
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS, Chroma
from langchain_core.callbacks.manager import CallbackManager
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import ConfigurableField, RunnablePassthrough
from langchain_google_genai import GoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
from langchain_cohere import CohereRerank
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
# Load environment variables
load_dotenv()
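# NOTE: depending on which LLM alternative is selected below, the corresponding
# API keys (e.g. OPENAI_API_KEY, ANTHROPIC_API_KEY, GOOGLE_API_KEY, GROQ_API_KEY,
# COHERE_API_KEY) are expected to be available in the environment / .env file.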
# Repository directories
repo_root_dir = "./docs/langchain"
repo_dirs = [
    "libs/core/langchain_core",
    "libs/community/langchain_community",
    "libs/experimental/langchain_experimental",
    "libs/partners",
    "libs/cookbook",
]
repo_dirs = [os.path.join(repo_root_dir, repo) for repo in repo_dirs]
# Load Python documents
# LanguageParser splits source files into per-definition chunks; parser_threshold=30
# keeps files shorter than 30 lines as a single document.
py_documents = []
for path in repo_dirs:
    py_loader = GenericLoader.from_filesystem(
        path,
        glob="**/*",
        suffixes=[".py"],
        parser=LanguageParser(language=Language.PYTHON, parser_threshold=30),
    )
    py_documents.extend(py_loader.load())
print(f"Number of .py documents: {len(py_documents)}")
# Load Markdown documents (skipping virtual environment directories)
mdx_documents = []
for dirpath, _, filenames in os.walk(repo_root_dir):
    for file in filenames:
        if file.endswith(".mdx") and "venv" not in dirpath:
            try:
                mdx_loader = TextLoader(os.path.join(dirpath, file), encoding="utf-8")
                mdx_documents.extend(mdx_loader.load())
            except Exception:
                pass
print(f"Number of .mdx documents: {len(mdx_documents)}")
# Load Jupyter Notebook documents (skipping virtual environment directories)
ipynb_documents = []
for dirpath, _, filenames in os.walk(repo_root_dir):
    for file in filenames:
        if file.endswith(".ipynb") and "venv" not in dirpath:
            try:
                ipynb_loader = NotebookLoader(
                    os.path.join(dirpath, file),
                    include_outputs=True,
                    max_output_length=20,
                    remove_newline=True,
                )
                ipynb_documents.extend(ipynb_loader.load())
            except Exception:
                pass
print(f"Number of .ipynb documents: {len(ipynb_documents)}")
# Split documents into chunks
def split_documents(documents, language, chunk_size=2000, chunk_overlap=200):
    splitter = RecursiveCharacterTextSplitter.from_language(
        language=language, chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    return splitter.split_documents(documents)


py_docs = split_documents(py_documents, Language.PYTHON)
mdx_docs = split_documents(mdx_documents, Language.MARKDOWN)
ipynb_docs = split_documents(ipynb_documents, Language.PYTHON)
print(f"Number of split .py documents: {len(py_docs)}")
print(f"Number of split .mdx documents: {len(mdx_docs)}")
print(f"Number of split .ipynb documents: {len(ipynb_docs)}")

combined_documents = py_docs + mdx_docs + ipynb_docs
print(f"Total number of documents: {len(combined_documents)}")
# Define the device setting function
def get_device():
    if torch.cuda.is_available():
        return "cuda:0"
    elif torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"


# Use the function to set the device in model_kwargs
device = get_device()
# Initialize embeddings and cache
# expanduser so the cache lands in the user's home directory rather than a literal "~" folder
store = LocalFileStore(os.path.expanduser("~/.cache/embedding"))
embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-m3",
    model_kwargs={"device": device},
    encode_kwargs={"normalize_embeddings": True},
)
cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
    embeddings, store, namespace=embeddings.model_name
)
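# CacheBackedEmbeddings stores computed vectors in the local file store, keyed by
# text and namespaced by the model name, so embedding the same documents again
# does not re-run the embedding model.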
# Create and save FAISS index
FAISS_DB_INDEX = "./langchain_faiss"
# faiss_db = FAISS.from_documents(
#     documents=combined_documents,
#     embedding=cached_embeddings,
# )
# faiss_db.save_local(folder_path=FAISS_DB_INDEX)

# Create and save Chroma index
CHROMA_DB_INDEX = "./langchain_chroma"
# chroma_db = Chroma.from_documents(
#     documents=combined_documents,
#     embedding=cached_embeddings,
#     persist_directory=CHROMA_DB_INDEX,
# )
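# The one-off index builds above are left commented out; the indexes are presumably
# built once and then loaded from disk below. Uncomment them to rebuild when the
# document set changes.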
# Load vector stores from disk
faiss_db = FAISS.load_local(
    FAISS_DB_INDEX, cached_embeddings, allow_dangerous_deserialization=True
)
chroma_db = Chroma(
    embedding_function=cached_embeddings,
    persist_directory=CHROMA_DB_INDEX,
)
# Create retrievers
faiss_retriever = faiss_db.as_retriever(search_type="mmr", search_kwargs={"k": 10})
chroma_retriever = chroma_db.as_retriever(
    search_type="similarity", search_kwargs={"k": 10}
)
bm25_retriever = BM25Retriever.from_documents(combined_documents)
bm25_retriever.k = 10
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, faiss_retriever, chroma_retriever],
    weights=[0.4, 0.3, 0.3],
)
compressor = CohereRerank(model="rerank-multilingual-v3.0", top_n=10)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor,
    base_retriever=ensemble_retriever,
)
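# Two-stage retrieval: the EnsembleRetriever merges keyword results (BM25, weight 0.4)
# with vector results (FAISS MMR / Chroma similarity, weight 0.3 each) via weighted
# reciprocal rank fusion, and Cohere's multilingual reranker then re-scores the
# merged candidates, keeping the top_n=10 most relevant chunks.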
# Create prompt template
prompt = PromptTemplate.from_template(
    """You are an AI developer with 20 years of experience. Your task is to answer the given question using the information in the provided documents as much as possible.
The documents contain information about Python code, so include detailed code snippets in your answer.
Answer in as much detail as possible, and answer in Korean. If the answer cannot be found in the provided documents, reply "The documents do not contain the answer."
Always cite the source(s) of your answer.

#Reference documents:
{context}

#Question:
{question}

#Answer:

Source:
- source1
- source2
- ...
"""
)
# Define callback handler for streaming
class StreamCallback(BaseCallbackHandler):
    def on_llm_new_token(self, token: str, **kwargs):
        print(token, end="", flush=True)


streaming = os.getenv("STREAMING", "true") == "true"
print("STREAMING", streaming)
# Initialize LLMs with configuration
llm = ChatOpenAI(
    model="gpt-4o",
    temperature=0,
    streaming=streaming,
    callbacks=[StreamCallback()],
).configurable_alternatives(
    ConfigurableField(id="llm"),
    default_key="gpt4",
    claude=ChatAnthropic(
        model="claude-3-opus-20240229",
        temperature=0,
        streaming=True,
        callbacks=[StreamCallback()],
    ),
    gpt3=ChatOpenAI(
        model="gpt-3.5-turbo",
        temperature=0,
        streaming=True,
        callbacks=[StreamCallback()],
    ),
    gemini=GoogleGenerativeAI(
        model="gemini-1.5-flash",
        temperature=0,
        streaming=True,
        callbacks=[StreamCallback()],
    ),
    llama3=ChatGroq(
        model_name="llama3-70b-8192",
        temperature=0,
        streaming=True,
        callbacks=[StreamCallback()],
    ),
    ollama=ChatOllama(
        model="EEVE-Korean-10.8B:long",
        callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
    ),
)
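# The active model is selected at call time via
# rag_chain.with_config(configurable={"llm": "<key>"}), where <key> is one of
# "gpt4" (default), "claude", "gpt3", "gemini", "llama3", or "ollama".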
# Create retrieval-augmented generation chain
rag_chain = (
    {"context": compression_retriever, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
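# The incoming question is routed two ways: compression_retriever fills {context}
# with reranked document chunks, while RunnablePassthrough() forwards the raw
# question into {question} before the prompt is sent to the selected LLM.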
model_key = os.getenv("MODEL_KEY", "gemini")
print("MODEL_KEY", model_key)
def respond_stream(
    message,
    history: list[tuple[str, str]],
):
    response = ""
    for chunk in rag_chain.with_config(configurable={"llm": model_key}).stream(message):
        response += chunk
        yield response


def respond(
    message,
    history: list[tuple[str, str]],
):
    return rag_chain.with_config(configurable={"llm": model_key}).invoke(message)
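# respond_stream yields the accumulated answer after every chunk, which is what
# gr.ChatInterface expects for token-by-token streaming; respond returns the full
# answer in one call when streaming is disabled.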
""" | |
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface | |
""" | |
demo = gr.ChatInterface( | |
respond_stream if streaming else respond, | |
title="λ체μΈμ λν΄μ λ¬Όμ΄λ³΄μΈμ!", | |
description="μλ νμΈμ!\nμ λ λ체μΈμ λν μΈκ³΅μ§λ₯ QAλ΄μ λλ€. λ체μΈμ λν΄ κΉμ μ§μμ κ°μ§κ³ μμ΄μ. λμ²΄μΈ κ°λ°μ κ΄ν λμμ΄ νμνμλ©΄ μΈμ λ μ§ μ§λ¬Έν΄μ£ΌμΈμ!", | |
) | |
if __name__ == "__main__":
    demo.launch()