text2sql / app.py
from llama_cpp import Llama
from typing import Dict
from huggingface_hub import hf_hub_download
import gradio as gr

# Download the model from Hugging Face
model_path = hf_hub_download(
    repo_id="omeryentur/phi-3-sql",
    filename="phi-3-sql.Q4_K_M.gguf",
    use_auth_token=True,
)
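# Note: use_auth_token=True reuses a locally saved Hugging Face token (from
# `huggingface-cli login` or the HF_TOKEN environment variable); newer
# huggingface_hub releases prefer the equivalent `token=True`, and the argument
# can be dropped entirely if the model repo is public.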

# Initialize the Llama model
llm = Llama(
    model_path=model_path,
    n_ctx=512,
    n_threads=1,
)
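# Note: n_ctx=512 caps the combined prompt + completion length, so very long
# schemas may be truncated; n_threads=1 is a conservative choice (assumed to
# suit a small CPU-only Space) and can be raised on machines with more cores.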

def generate_sql_query(text_input_schema: str, text_input_question: str) -> Dict[str, str]:
    try:
        # Construct the prompt for the model
        prompt = f"""
<|system|>
{text_input_schema}
<|user|>
{text_input_question}
<|sql|>"""
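        # Assumption: this <|system|>/<|user|>/<|sql|> template and the
        # "<end_of_turn>" stop string used below follow the format the
        # phi-3-sql fine-tune expects; adjust if the model card specifies
        # a different prompt format.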

        # Generate SQL query
        completion = llm(
            prompt,
            max_tokens=512,
            temperature=0,
            stop=["<end_of_turn>"],
        )
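        # Note: temperature=0 selects the most likely token at each step (greedy
        # decoding), so the same schema/question pair should yield the same SQL.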

        # Extract and return the generated SQL query
        generated = completion['choices'][0]['text'].strip()
        return {"sql_query": generated}
    except Exception as e:
        return {"error": str(e)}
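
# Example (hypothetical schema and question, for illustration only; the exact
# SQL the model returns may differ):
#
#   generate_sql_query(
#       "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT);",
#       "How many users are there?",
#   )
#   # -> {"sql_query": "SELECT COUNT(*) FROM users;"}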

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# SQL Query")
    with gr.Row():
        with gr.Column():
            text_input_schema = gr.TextArea(label="Schema")
            text_input_question = gr.Textbox(label="Question")
            generate_btn = gr.Button("Create SQL Query")
    with gr.Row():
        with gr.Column():
            output = gr.JSON(label="SQL Query")

    generate_btn.click(
        fn=generate_sql_query,
        inputs=[text_input_schema, text_input_question],
        outputs=[output],
    )

if __name__ == "__main__":
    demo.launch()
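
# To run locally (assumed dependencies): pip install llama-cpp-python gradio huggingface_hub,
# then `python app.py`; Gradio serves the UI at http://127.0.0.1:7860 by default.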