import logging

from pydantic import BaseModel, Field
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.runnables import RunnablePassthrough
from langchain_chroma import Chroma
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import JSONLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

from prompt import prompt
from utils import execute_query_and_return_df
from constant import (
    GEMINI_MODEL,
    GOOGLE_API_KEY,
    PATH_SCHEMA,
    PATH_DB,
    EMBEDDING_MODEL
)
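
# For reference: `prompt` imported above must be a prompt template with "context"
# and "question" input variables, since the chain in `run` pipes exactly those
# keys into it. A minimal sketch of what `prompt.py` might contain (the template
# wording is an assumption, not the actual file):
#
#   from langchain_core.prompts import ChatPromptTemplate
#
#   prompt = ChatPromptTemplate.from_template(
#       "You are a Text-to-SQL assistant. Using this schema context:\n"
#       "{context}\n\n"
#       "Write a SQL query that answers: {question}"
#   )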


class SQLOutput(BaseModel):
    query: str = Field(description="The SQL query to run.")
    reasoning: str = Field(description="The reasoning behind the SQL query.")
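
# With `with_structured_output(SQLOutput)`, responses arrive as parsed instances,
# e.g. (illustrative values only):
#   SQLOutput(query="SELECT COUNT(*) FROM users;",
#             reasoning="The question asks for the total number of users.")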


class Text2SQLRAG:
    """
    A class for generating SQL queries from natural-language questions.
    """

    def __init__(self,
                 path_schema: str = PATH_SCHEMA,
                 path_db: str = PATH_DB,
                 model: str = GEMINI_MODEL,
                 api_key: str = GOOGLE_API_KEY,
                 embedding_model: str = EMBEDDING_MODEL
                 ):
        self.logger = logging.getLogger(__name__)
        self.logger.info('Initializing Text2SQLRAG')
        self.schema = path_schema
        self.db = path_db
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
        # Gemini sampling settings, passed directly rather than via a
        # `model_kwargs` dict (note: some Gemini models cap top_k well below
        # 250, in which case the API may reject or clamp this value).
        self.model = ChatGoogleGenerativeAI(
            model=model,
            api_key=api_key,
            temperature=0.2,
            top_k=250,
            top_p=1,
            max_output_tokens=512
        )
        # Ask the model to return a parsed SQLOutput instead of raw text.
        self.llm = self.model.with_structured_output(SQLOutput)
        self.retriever = self._index_vectorstore()

    def _index_vectorstore(self):
        """
        Index the database schema in a vector store for efficient retrieval.

        Loads the schema from a JSON file, splits it into chunks, embeds the
        chunks with the configured embedding model, and stores them in a
        Chroma vector store.

        Returns:
            retriever: A retriever that returns the most relevant schema
                chunks for a given question.
        """
        self.logger.info('Indexing schema')
        db_schema_loader = JSONLoader(
            file_path=self.schema,
            jq_schema='.',
            text_content=False
        )
        # The schema file is expected to contain the literal token "separator"
        # between table definitions; chunks are split on that token first.
        text_splitter = RecursiveCharacterTextSplitter(
            separators=["separator"],
            chunk_size=10000,
            chunk_overlap=100
        )
        docs = text_splitter.split_documents(db_schema_loader.load())
        vectorstore = Chroma.from_documents(documents=docs,
                                            embedding=self.embeddings)
        # Retrieve the top 2 most similar schema chunks per question.
        retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
        self.logger.info('Finished indexing schema')
        return retriever

    def run(self, question: str) -> SQLOutput:
        """Generate a SQL query (with reasoning) for a natural-language question."""
        self.logger.info(f'Running Text2SQLRAG for question: {question}')
        # Feed retrieved schema chunks in as `context`, pass the question
        # through unchanged, format the prompt, and parse the structured reply.
        rag_chain = (
            {"context": self.retriever, "question": RunnablePassthrough()}
            | prompt
            | self.llm
        )
        response = rag_chain.invoke(question)
        return response
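

# Minimal usage sketch. `execute_query_and_return_df` is imported from the local
# `utils` module above; its exact signature is an assumption here (taking a SQL
# string plus the database path and returning a pandas DataFrame).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    text2sql = Text2SQLRAG()
    result = text2sql.run("How many orders were placed last month?")
    print(result.reasoning)
    print(result.query)

    # Execute the generated query against the configured database
    # (assumed call signature, see note above).
    df = execute_query_and_return_df(result.query, text2sql.db)
    print(df)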