# Author: anpigon · commit c2c8656 · 8.07 kB
# Commit message: "chore: Update device for HuggingFaceBgeEmbeddings to dynamic device selection"
import os
import torch
import gradio as gr
from dotenv import load_dotenv
from langchain.callbacks.base import BaseCallbackHandler
from langchain.embeddings import CacheBackedEmbeddings
from langchain_community.retrievers import BM25Retriever, EnsembleRetriever
from langchain.storage import LocalFileStore
from langchain_anthropic import ChatAnthropic
from langchain_community.chat_models import ChatOllama
from langchain_community.document_loaders import NotebookLoader, TextLoader
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers.language.language_parser import (
LanguageParser,
)
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.callbacks.manager import CallbackManager
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import ConfigurableField, RunnablePassthrough
from langchain_google_genai import GoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
# Load environment variables from a local .env file (python-dotenv), so that
# settings such as LLM_MODEL — and presumably the provider API keys — are
# visible through os.environ.
load_dotenv()

# Repository directories: the local checkout root plus the sub-packages whose
# .py files get indexed (paths follow the LangChain monorepo "libs/..." layout).
repo_root_dir = "./docs/langchain"
_repo_subdirs = (
    "libs/core/langchain_core",
    "libs/community/langchain_community",
    "libs/experimental/langchain_experimental",
    "libs/partners",
    # NOTE(review): confirm this directory exists in the checkout; if it does
    # not, it presumably just contributes no documents.
    "libs/cookbook",
)
repo_dirs = [os.path.join(repo_root_dir, subdir) for subdir in _repo_subdirs]
# Load Python source documents from every repository sub-package.
py_documents = []
for repo_path in repo_dirs:
    # Collect all .py files below repo_path and parse them with LangChain's
    # Python language parser (parser_threshold=30 is passed straight through).
    source_loader = GenericLoader.from_filesystem(
        repo_path,
        glob="**/*",
        suffixes=[".py"],
        parser=LanguageParser(language=Language.PYTHON, parser_threshold=30),
    )
    py_documents += source_loader.load()
# Prints "number of .py files" (actually the number of loaded documents).
print(f".py 파일의 개수: {len(py_documents)}")
def _iter_files_with_suffix(root, suffix):
    """Yield the path of every file under *root* whose name ends with *suffix*.

    Virtual-environment directories (any directory whose name ends in "venv",
    e.g. ``venv`` or ``.venv``) are pruned from the walk and never visited.

    Args:
        root: Directory to walk recursively.
        suffix: File-name ending to match, e.g. ".mdx".

    Yields:
        ``os.path.join(dirpath, filename)`` for each matching file, in
        ``os.walk`` top-down order.
    """
    for dirpath, dirnames, filenames in os.walk(root):
        # Pruning dirnames in place stops os.walk from descending into them.
        dirnames[:] = [name for name in dirnames if not name.endswith("venv")]
        for filename in filenames:
            if filename.endswith(suffix):
                yield os.path.join(dirpath, filename)


# Load Markdown (.mdx) documents. Loading is best-effort: a file that fails to
# load is reported and skipped instead of aborting start-up.
mdx_documents = []
for mdx_path in _iter_files_with_suffix(repo_root_dir, ".mdx"):
    try:
        mdx_documents.extend(TextLoader(mdx_path, encoding="utf-8").load())
    except Exception as exc:
        print(f"[skip] {mdx_path}: {exc!r}")
# Prints "number of .mdx files".
print(f".mdx 파일의 개수: {len(mdx_documents)}")

# Load Jupyter Notebook documents (cell outputs included, each output capped by
# max_output_length=20), with the same best-effort policy.
ipynb_documents = []
for ipynb_path in _iter_files_with_suffix(repo_root_dir, ".ipynb"):
    try:
        ipynb_loader = NotebookLoader(
            ipynb_path,
            include_outputs=True,
            max_output_length=20,
            remove_newline=True,
        )
        ipynb_documents.extend(ipynb_loader.load())
    except Exception as exc:
        print(f"[skip] {ipynb_path}: {exc!r}")
# Prints "number of .ipynb files".
print(f".ipynb 파일의 개수: {len(ipynb_documents)}")
# Split documents into chunks
def split_documents(documents, language, chunk_size=2000, chunk_overlap=200):
    """Chunk *documents* using separators tailored to *language*.

    Args:
        documents: Documents to split.
        language: ``Language`` member that selects the splitter's separators.
        chunk_size: Target maximum chunk size passed to the splitter.
        chunk_overlap: Amount of overlap between consecutive chunks.

    Returns:
        The chunked documents produced by ``RecursiveCharacterTextSplitter``.
    """
    language_splitter = RecursiveCharacterTextSplitter.from_language(
        language=language,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    chunks = language_splitter.split_documents(documents)
    return chunks
# Split each corpus with language-aware separators; notebook text is split with
# the Python separators as well.
py_docs = split_documents(py_documents, Language.PYTHON)
mdx_docs = split_documents(mdx_documents, Language.MARKDOWN)
ipynb_docs = split_documents(ipynb_documents, Language.PYTHON)
# Report chunk counts ("number of split .py / .mdx / .ipynb files").
print(f"분할된 .py 파일의 개수: {len(py_docs)}")
print(f"분할된 .mdx 파일의 개수: {len(mdx_docs)}")
print(f"분할된 .ipynb 파일의 개수: {len(ipynb_docs)}")
# One combined corpus; the BM25 retriever below is built from it.
combined_documents = py_docs + mdx_docs + ipynb_docs
# "Total number of documents".
print(f"총 도큐먼트 개수: {len(combined_documents)}")
# Define the device setting function
def get_device():
    """Return the torch device string the embedding model should run on.

    Probes are tried in order and the first available backend wins:
    CUDA (first GPU, ``"cuda:0"``), then ``"mps"``; otherwise ``"cpu"``.
    Each probe is wrapped in a lambda so it is evaluated only when reached.
    """
    probes = {
        "cuda:0": lambda: torch.cuda.is_available(),
        "mps": lambda: torch.backends.mps.is_available(),
    }
    return next((name for name, probe in probes.items() if probe()), "cpu")
# Use the function to set the device in model_kwargs
device = get_device()

# Initialize embeddings and a disk-backed embedding cache.
# LocalFileStore resolves its root with pathlib and does not expand "~", so a
# bare "~/.cache/embedding" would create a literal "~" directory under the
# current working directory; expand it to the user's home explicitly.
store = LocalFileStore(os.path.expanduser("~/.cache/embedding"))
embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-m3",
    model_kwargs={"device": device},
    # sentence-transformers encode option: return unit-length vectors.
    encode_kwargs={"normalize_embeddings": True},
)
# Namespace cache keys by model name so vectors from different models never mix.
# NOTE(review): depending on the LangChain version, CacheBackedEmbeddings may
# cache only embed_documents (not query embeddings) — confirm if query caching
# is expected here.
cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
    embeddings, store, namespace=embeddings.model_name
)
# Load the FAISS index, building and saving it first if it does not exist yet.
FAISS_DB_INDEX = "./langchain_faiss"
# FAISS.save_local writes "<index_name>.faiss" / ".pkl"; index_name defaults to "index".
if os.path.exists(os.path.join(FAISS_DB_INDEX, "index.faiss")):
    # SECURITY: load_local unpickles the stored docstore. Only allow this for
    # an index produced by this application, never for one from an untrusted source.
    db = FAISS.load_local(
        FAISS_DB_INDEX, cached_embeddings, allow_dangerous_deserialization=True
    )
else:
    # First run: embed every chunk (slow) and persist the index for later runs.
    db = FAISS.from_documents(combined_documents, cached_embeddings)
    db.save_local(folder_path=FAISS_DB_INDEX)
# Create retrievers: a dense FAISS retriever using maximal-marginal-relevance
# search and a sparse BM25 keyword retriever, each returning 10 documents.
faiss_retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 10})
bm25_retriever = BM25Retriever.from_documents(combined_documents)
bm25_retriever.k = 10
# Fuse both result lists with equal weights. EnsembleRetriever has no
# search_type field; MMR is configured on faiss_retriever above.
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
)
# Create prompt template. {context} receives the retrieved documents and
# {question} the user's message (see rag_chain).
# English gist: "You are an AI developer with 20 years of experience. Answer the
# question using the documents' information as much as possible. The documents
# are about Python code, so include detailed code snippets. Answer in as much
# detail as possible, in Korean. If the documents do not contain the answer,
# reply '문서에 답변이 없습니다.' ('The documents have no answer.'). Always cite
# the sources."
prompt = PromptTemplate.from_template(
    """당신은 20년차 AI 개발자입니다. 당신의 임무는 주어진 질문에 대하여 최대한 문서의 정보를 활용하여 답변하는 것입니다.
문서는 Python 코드에 대한 정보를 담고 있습니다. 따라서, 답변을 작성할 때에는 Python 코드에 대한 상세한 code snippet을 포함하여 작성해주세요.
최대한 자세하게 답변하고, 한글로 답변해 주세요. 주어진 문서에서 답변을 찾을 수 없는 경우, "문서에 답변이 없습니다."라고 답변해 주세요.
답변은 출처(source)를 반드시 표기해 주세요.
#참고문서:
{context}
#질문:
{question}
#답변:
출처:
- source1
- source2
- ...
"""
)
# Define callback handler for streaming
class StreamCallback(BaseCallbackHandler):
    """Callback that echoes every newly generated LLM token to stdout."""

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Print *token* with no trailing newline and flush right away.

        Args:
            token: Text fragment just produced by the model.
            **kwargs: Additional callback metadata; ignored.
        """
        print(token, flush=True, end="")
# Initialize LLMs with configuration.
# GPT-4o is the default ("gpt4"); each keyword below registers an alternative
# selectable per call via ``.with_config(configurable={"llm": <key>})``.
# Every model writes its streamed tokens to stdout through StreamCallback.
llm = ChatOpenAI(
    model="gpt-4o",
    temperature=0,
    streaming=True,
    callbacks=[StreamCallback()],
).configurable_alternatives(
    ConfigurableField(id="llm"),
    default_key="gpt4",
    claude=ChatAnthropic(
        model="claude-3-opus-20240229",
        temperature=0,
        streaming=True,
        callbacks=[StreamCallback()],
    ),
    gpt3=ChatOpenAI(
        model="gpt-3.5-turbo",
        temperature=0,
        streaming=True,
        callbacks=[StreamCallback()],
    ),
    gemini=GoogleGenerativeAI(
        model="gemini-1.5-flash",
        temperature=0,
        streaming=True,
        callbacks=[StreamCallback()],
    ),
    llama3=ChatGroq(
        model_name="llama3-70b-8192",
        temperature=0,
        streaming=True,
        callbacks=[StreamCallback()],
    ),
    # Local model served by Ollama. LangChain deprecates ``callback_manager``
    # in favour of ``callbacks``; StreamCallback writes each token to stdout,
    # the same effect as StreamingStdOutCallbackHandler.
    ollama=ChatOllama(
        model="EEVE-Korean-10.8B:long",
        callbacks=[StreamCallback()],
    ),
)
# Create the retrieval-augmented generation chain. The input question is fanned
# out into {"context": retrieved docs, "question": the question unchanged},
# rendered by the prompt, answered by the selected LLM, and reduced to a string.
_rag_inputs = {"context": ensemble_retriever, "question": RunnablePassthrough()}
rag_chain = _rag_inputs | prompt | llm | StrOutputParser()
# Which ``llm`` alternative to use (gpt4, claude, gpt3, gemini, llama3, ollama);
# override with the LLM_MODEL environment variable.
model_key = os.environ.get("LLM_MODEL", "gpt4")
print("model", model_key)
def respond(
    message,
    history: list[tuple[str, str]],
):
    """Stream the RAG chain's answer to *message* for the Gradio chat UI.

    Args:
        message: The user's latest chat message; used as the chain input.
        history: Earlier conversation turns passed in by ChatInterface; unused.

    Yields:
        The answer accumulated so far, growing with each streamed chunk, so
        the UI re-renders the partial reply.
    """
    configured_chain = rag_chain.with_config(configurable={"llm": model_key})
    partial_answer = ""
    for piece in configured_chain.stream(message):
        partial_answer += piece
        yield partial_answer


# For information on how to customize the ChatInterface, peruse the gradio
# docs: https://www.gradio.app/docs/chatinterface
# Gradio chat front-end wired to the streaming respond() generator.
demo = gr.ChatInterface(
    respond,
    # "Ask about LangChain!"
    title="랭체인에 대해서 물어보세요!",
    # "Hello! I'm an AI QA bot about LangChain with deep knowledge of it. Ask me
    # anytime if you need help with LangChain development!"
    description="안녕하세요!\n저는 랭체인에 대한 인공지능 QA봇입니다. 랭체인에 대해 깊은 지식을 가지고 있어요. 랭체인 개발에 관한 도움이 필요하시면 언제든지 질문해주세요!",
)

if __name__ == "__main__":
    demo.launch()