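"""Text-to-SQL RAG pipeline: index the database schema in a vector store,
retrieve the chunks most relevant to a question, and prompt Gemini to
return a structured SQL query with reasoning."""
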
import logging

from pydantic import BaseModel, Field
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.runnables import RunnablePassthrough
from langchain_chroma import Chroma
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import JSONLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

from prompt import prompt
from utils import execute_query_and_return_df
from constant import (
    GEMINI_MODEL,
    GOOGLE_API_KEY,
    PATH_SCHEMA,
    PATH_DB,
    EMBEDDING_MODEL,
)

class SQLOutput(BaseModel):
    """Structured output: the generated SQL plus the model's reasoning."""

    query: str = Field(description="The SQL query to run.")
    reasoning: str = Field(description="Step-by-step reasoning behind the SQL query.")

class Text2SQLRAG:
    """Generates SQL queries from natural-language questions by retrieving
    relevant schema context and prompting Gemini for structured output."""

    def __init__(self,
                 path_schema: str = PATH_SCHEMA,
                 path_db: str = PATH_DB,
                 model: str = GEMINI_MODEL,
                 api_key: str = GOOGLE_API_KEY,
                 embedding_model: str = EMBEDDING_MODEL
                 ):
        self.logger = logging.getLogger(__name__)
        self.logger.info('Initializing Text2SQLRAG')
        self.schema = path_schema
        self.db = path_db
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
        # Generation settings for Gemini, passed as first-class arguments;
        # Anthropic-style settings (stop_sequences=["\n\nHuman:"], top_k=250)
        # do not apply to Gemini and are omitted.
        self.model = ChatGoogleGenerativeAI(
            model=model,
            api_key=api_key,
            max_output_tokens=512,
            temperature=0.2,
            top_p=1,
        )
        # Constrain responses to the SQLOutput schema (query + reasoning).
        self.llm = self.model.with_structured_output(SQLOutput)
        self.retriever = self._index_schema()
    def _index_schema(self):
        """
        Indexes the database schema in a vector store for efficient retrieval.

        Loads the schema from a JSON file, splits it into chunks, embeds the
        chunks with the configured embedding model, and stores them in a
        Chroma vector store.

        Returns:
            retriever: A retriever that returns the top-k most relevant
                schema chunks for a given query.
        """
        self.logger.info('Indexing schema')
        db_schema_loader = JSONLoader(
            file_path=self.schema,
            jq_schema='.',
            text_content=False
        )
        # Splits on the literal token "separator", which the schema JSON is
        # assumed to contain between table definitions; the large chunk size
        # keeps each table's schema intact.
        text_splitter = RecursiveCharacterTextSplitter(
            separators=["separator"],
            chunk_size=10000,
            chunk_overlap=100
        )
        docs = text_splitter.split_documents(db_schema_loader.load())
        vectorstore = Chroma.from_documents(
            documents=docs,
            embedding=self.embeddings
        )
        # Retrieve the two most relevant schema chunks per question.
        retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
        self.logger.info('Finished indexing schema')
        return retriever
    def run(self, question: str):
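        """Generate a structured SQL query for a natural-language question."""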
        self.logger.info(f'Running Text2SQLRAG for question: {question}')
        # Retrieve schema context, fill the prompt template, and parse the
        # model's response into a SQLOutput.
        rag_chain = (
            {"context": self.retriever, "question": RunnablePassthrough()}
            | prompt
            | self.llm
        )
        response = rag_chain.invoke(question)
        return response
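

# Minimal usage sketch: assumes constant.py provides valid paths and an API
# key, and that execute_query_and_return_df accepts (sql, db_path) and
# returns a DataFrame — a hypothetical signature; adjust to utils.py.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    rag = Text2SQLRAG()
    result = rag.run("How many customers placed an order last month?")
    print(result.reasoning)
    print(result.query)
    df = execute_query_and_return_df(result.query, PATH_DB)
    print(df.head())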