text2sql-rag / lib.py
fahmiaziz's picture
Upload 9 files
af733da verified
raw
history blame
3.48 kB
import logging
from pydantic import BaseModel, Field
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.runnables import RunnablePassthrough
from langchain_chroma import Chroma
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import JSONLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from prompt import prompt
from utils import execute_query_and_return_df
from constant import (
GEMINI_MODEL,
GOOGLE_API_KEY,
PATH_SCHEMA,
PATH_DB,
EMBEDDING_MODEL
)
# Structured-output schema enforced on the LLM via `with_structured_output`:
# the model must return both the SQL query and its reasoning.
class SQLOutput(BaseModel):
    # NOTE: the Field descriptions below are runtime behavior — they are
    # serialized into the schema sent to the LLM — so they must not be edited
    # cosmetically.
    query: str = Field(description="The SQL query to run.")
    reasoning: str = Field(description="Reasoning to understand the SQL query.")
class Text2SQLRAG:
    """Generate SQL queries from natural-language questions via RAG.

    The database schema (a JSON file) is chunked, embedded with a
    HuggingFace model, and indexed in an in-memory Chroma vector store.
    At query time the top-k relevant schema chunks are retrieved and fed,
    together with the question, through `prompt` into a Gemini model that
    returns a structured ``SQLOutput``.
    """

    def __init__(self,
                 path_schema: str = PATH_SCHEMA,
                 path_db: str = PATH_DB,
                 model: str = GEMINI_MODEL,
                 api_key: str = GOOGLE_API_KEY,
                 embedding_model: str = EMBEDDING_MODEL
                 ):
        """
        Args:
            path_schema: Path to the JSON file describing the DB schema.
            path_db: Path to the database file (stored for use by callers;
                not read directly in this class).
            model: Gemini model name.
            api_key: Google API key.
            embedding_model: HuggingFace embedding model name.
        """
        self.logger = logging.getLogger(__name__)
        self.logger.info('Initializing Text2SQLRAG')
        self.schema = path_schema
        self.db = path_db
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
        # BUG FIX: the original passed Anthropic/Claude-style options through
        # `model_kwargs` ({"max_tokens": 512, "stop_sequences": ["\n\nHuman:"]}),
        # which ChatGoogleGenerativeAI does not understand. Gemini takes these
        # as first-class constructor parameters; the Claude-specific
        # "\n\nHuman:" stop sequence is meaningless for Gemini and is dropped.
        self.model = ChatGoogleGenerativeAI(
            model=model,
            api_key=api_key,
            max_output_tokens=512,
            temperature=0.2,
            top_k=250,  # NOTE(review): carried over from the original config;
                        # confirm the Gemini API accepts a top_k this large.
            top_p=1,
        )
        # Constrain the LLM to emit SQLOutput (query + reasoning).
        self.llm = self.model.with_structured_output(SQLOutput)
        self.retriever = self._index_schema()

    def _index_schema(self):
        """
        Index the database schema in a vector store for retrieval.

        Loads the schema JSON, splits it into chunks, embeds the chunks,
        and stores them in Chroma.

        Returns:
            retriever: A retriever over the schema chunks configured to
            return the top-2 most relevant documents.
        """
        self.logger.info('Indexing schema')
        db_schema_loader = JSONLoader(
            file_path=self.schema,
            jq_schema='.',
            # The schema JSON is structured data, not a text field.
            text_content=False
        )
        # BUG FIX: the original used separators=["separator"], which splits
        # only on the literal string "separator" — i.e. effectively never.
        # Omitting `separators` uses the default hierarchy
        # ("\n\n", "\n", " ", "") so documents are split at natural
        # boundaries within the 10 000-character chunk limit.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=10000,
            chunk_overlap=100
        )
        docs = text_splitter.split_documents(db_schema_loader.load())
        vectorstore = Chroma.from_documents(documents=docs,
                                            embedding=self.embeddings)
        retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
        self.logger.info('Finished indexing schema')
        return retriever

    def run(self, question: str) -> SQLOutput:
        """Generate a SQL query (with reasoning) for *question*.

        Retrieves the most relevant schema chunks, formats them with the
        shared `prompt`, and invokes the structured-output LLM.

        Returns:
            SQLOutput: The generated query and the model's reasoning.
        """
        # Lazy %-style args avoid formatting the message when INFO is disabled.
        self.logger.info('Running Text2SQLRAG for question: %s', question)
        rag_chain = (
            {"context": self.retriever, "question": RunnablePassthrough()}
            | prompt
            | self.llm
        )
        response = rag_chain.invoke(question)
        return response