Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import time | |
| # insert current directory to sys.path | |
| sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) | |
| import re | |
| import sqlite3 | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| import requests | |
| from googletrans import Translator | |
| from langdetect import detect | |
| from sql_formatter.core import format_sql | |
# Shared googletrans client; translate_question() calls it.
translator = Translator()

# Page-level Streamlit settings: wide layout, page title and page icon.
st.set_page_config(page_title="Text To SQL", page_icon="π", layout="wide")

# Endpoint of the Text-to-SQL service. The TEXT_2_SQL_API environment variable
# takes precedence over the default address given here.
# Alternative endpoint kept for reference:
# TEXT_2_SQL_API = "http://83.219.197.235:40172/api/text2sql/ask"
TEXT_2_SQL_API = os.getenv(
    "TEXT_2_SQL_API",
    "http://213.181.122.2:40057/api/text2sql/ask",
)
# Start from a clean slate: delete any existing SQLite file before
# load_database() recreates it. This is best-effort, so a missing or
# undeletable file is ignored.
try:
    os.remove("resources/ai_app.db")
except OSError:
    # Narrower than a bare ``except`` so KeyboardInterrupt/SystemExit still
    # propagate, while every filesystem failure stays non-fatal.
    pass
def load_database(db_path="resources/ai_app.db", schema_path="resources/schema.sql"):
    """Open a SQLite database and populate it from a SQL schema script.

    Args:
        db_path: Location of the SQLite database (``":memory:"`` also works).
        schema_path: Path of the SQL script executed on the new connection.

    Returns:
        An open ``sqlite3.Connection`` with the schema script applied.

    Raises:
        OSError: If the schema file cannot be read.
        sqlite3.Error: If the database cannot be opened or the script fails.
    """
    # Read the script first so that a missing schema file does not leave an
    # open connection behind. The encoding is explicit so the result does not
    # depend on the platform default.
    with open(schema_path, "r", encoding="utf-8") as f:
        schema_sql = f.read()

    db_conn = sqlite3.connect(db_path)
    try:
        db_conn.executescript(schema_sql)
    except Exception:
        # Do not leak the connection when the schema script is invalid.
        db_conn.close()
        raise
    return db_conn
# Module-level SQLite connection used by execute_sql(); created when the
# script runs.
db_conn = load_database()
def execute_sql(sql_query):
    """Run ``sql_query`` on the module-level ``db_conn`` and return all rows.

    Args:
        sql_query: SQL text to execute.

    Returns:
        The list of rows from ``cursor.fetchall()``, or ``None`` when execution
        fails for any reason. Callers in this file treat ``None`` as the
        signal to retry.
    """
    import logging
    from contextlib import closing

    try:
        # closing() releases the cursor even when execute()/fetchall() raises.
        with closing(db_conn.cursor()) as cursor:
            cursor.execute(sql_query)
            rows = cursor.fetchall()
    except Exception:
        # Deliberately broad: the page must keep rendering whatever the query
        # does. The traceback goes to the log instead of being discarded.
        logging.getLogger(__name__).exception(
            "Failed to execute SQL query: %r", sql_query
        )
        st.info("Database is not supported")
        return None

    # Announce success only once the rows have actually been fetched.
    st.success("SQL query executed successfully!")
    return rows
# @st.cache_data  (caching is currently disabled)
def ask_text2sql(question, context, timeout=120):
    """Ask the Text-to-SQL service for a query that answers ``question``.

    A question that ``langdetect`` does not classify as English is first
    translated to English with ``translate_question``.

    Args:
        question: Natural-language question.
        context: Context text forwarded to the service unchanged.
        timeout: Seconds to wait for the service before giving up.

    Returns:
        The first entry of the ``"answers"`` list in the JSON response.

    Raises:
        requests.RequestException: On connection problems, a timeout, or a
            non-success HTTP status.
    """
    if detect(question) != "en":
        question = translate_question(question)

    response = requests.post(
        TEXT_2_SQL_API,
        json={"context": context, "question": question},
        # Without a timeout, requests waits indefinitely on a stalled server.
        timeout=timeout,
    )
    # Surface HTTP failures as such rather than as a KeyError or JSON error
    # raised while parsing an error payload.
    response.raise_for_status()
    return response.json()["answers"][0]
def translate_question(question):
    """Return ``question`` translated to English by the shared ``translator``."""
    translation = translator.translate(question, dest="en")
    return translation.text
def load_example_df():
    """Read ``resources/examples.csv`` into a pandas DataFrame and return it."""
    return pd.read_csv("resources/examples.csv")
def introduction():
    """Render the introduction page: a title followed by text paragraphs."""
    st.title("π Introduction")

    # Each entry is written with its own st.write() call, in this order.
    paragraphs = (
        "π Welcome to the Text to SQL app!",
        "π This app allows you to explore the ability of Text to SQL model. The model is CodeLlama-13b finetuned using QLoRA on NSText2SQL dataset.",
        "π The NSText2SQL dataset contains more than 290.000 training samples. Then, the model is evaluated on Spider and vMLP datasets.",
        "π The other pages in this app include:",
        " - π EDA Page: This page includes several visualizations to help you understand the two dataset: Spider and vMLP.",
        " - π° Text2SQL Page: This page allows you to generate SQL query from a given question and context.",
        " - π§βπ» About Page: This page provides information about the app and its creators.",
        " - π Reference Page: This page lists the references used in building this app.",
    )
    for paragraph in paragraphs:
        st.write(paragraph)
# Define a function for the EDA page
def eda():
    """Render the dataset-exploration page; at present it shows only a title.

    NOTE(review): everything after the title is commented-out plotting code
    for a Tesla stock-price dataset. It refers to ``df``, ``go`` and ``px``,
    none of which are defined or imported in the visible part of this file.
    Presumably a leftover from a template; confirm before reviving or
    removing it. The sidebar menu in this file offers no "EDA" entry, so this
    page is not reachable from the menu as written.
    """
    st.title("π Dataset Exploration")
    # st.subheader("Candlestick Chart")
    # fig = go.Figure(
    #     data=[
    #         go.Candlestick(
    #             x=df["date"],
    #             open=df["open"],
    #             high=df["high"],
    #             low=df["low"],
    #             close=df["close"],
    #             increasing_line_color="green",
    #             decreasing_line_color="red",
    #         )
    #     ],
    #     layout=go.Layout(
    #         title="Tesla Stock Price",
    #         xaxis_title="Date",
    #         yaxis_title="Price (USD)",
    #         xaxis_rangeslider_visible=True,
    #     ),
    # )
    # st.plotly_chart(fig)
    # st.subheader("Line Chart")
    # # Plot the closing price over time
    # plot_column = st.selectbox(
    #     "Select a column", ["open", "close", "low", "high"], index=0
    # )
    # fig = px.line(
    #     df, x="date", y=plot_column, title=f"Tesla {plot_column} Price Over Time"
    # )
    # st.plotly_chart(fig)
    # st.subheader("Distribution of Closing Price")
    # # Plot the distribution of the closing price
    # closing_price_hist = px.histogram(
    #     df, x="close", nbins=30, title="Distribution of Tesla Closing Price"
    # )
    # st.plotly_chart(closing_price_hist)
    # st.subheader("Raw Data")
    # st.write("You can see the raw data below.")
    # # Display the dataset
    # st.dataframe(df)
def preprocess_context(context):
    """Flatten ``context`` onto one line and squeeze repeated spaces.

    Newlines, tabs and carriage returns each become a single space, after
    which every run of spaces is reduced to one space. Leading and trailing
    spaces are not stripped.
    """
    whitespace_to_space = str.maketrans({"\n": " ", "\t": " ", "\r": " "})
    flattened = context.translate(whitespace_to_space)
    # Only runs of two or more spaces need rewriting; a lone space is unchanged.
    return re.sub(r" {2,}", " ", flattened)
# Default schema shown in the interactive demo's context box.
_DEMO_CONTEXT = "\n".join(
    (
        "CREATE TABLE customer (id number, name text, gender text, age number, district_id number);",
        "CREATE TABLE registration (customer_id number, product_id number);",
        "CREATE TABLE district (id number, name text, prefix text, province_id number);",
        "CREATE TABLE province (id number, name text, code text);",
        "CREATE TABLE product (id number, category text, name text, description text, price number, duration number, data_amount number, voice_amount number, sms_amount number);",
    )
)


def _sidebar_options():
    """Render the shared sidebar controls; return ``(execute_sql_query, num_tries)``."""
    with st.sidebar:
        # create a blank space
        st.write("")
        st.write("")
        st.write("")
        execute_sql_query = st.checkbox("Execute SQL query")
        num_tries = st.number_input(
            "Number of tries",
            value=3,
            min_value=1,
            max_value=10,
            step=1,
        )
    return execute_sql_query, num_tries


def _show_generated_query(question, context):
    """Ask the model for a query, display it with its latency, and return it."""
    start_time = time.time()
    query = ask_text2sql(question, context)
    elapsed = time.time() - start_time
    st.markdown(f"The SQL query generated by the model in **{elapsed:.2f}s** is:")
    # Display the SQL query in a code block
    st.code(format_sql(query), language="sql")
    return query


def _generate_and_run(question, context, execute_sql_query, num_tries):
    """Generate a query and, when requested, execute it with retries.

    With ``execute_sql_query`` off, a single query is generated and shown.
    Otherwise up to ``num_tries`` queries are generated, stopping at the first
    one for which ``execute_sql`` returns something other than ``None``.
    """
    with st.spinner("Generating SQL query..."):
        if not execute_sql_query:
            _show_generated_query(question, context)
            return

        for _ in range(int(num_tries)):
            query = _show_generated_query(question, context)
            result = execute_sql(query)
            if result is None:
                # Execution failed; do not render an empty result table, just
                # ask the model for another query.
                continue
            st.write("Executing the SQL query yields the following result:")
            st.dataframe(pd.DataFrame(result), hide_index=True)
            break


def examples():
    """Render the examples page: one tab per row of the examples table.

    Each tab shows the row's context and question plus a button that
    generates (and optionally executes) the SQL query for that row.
    """
    st.title("Examples")
    st.write(
        "This page uses CodeLlama-13b finetuned using QLoRA on NSText2SQL dataset to generate SQL query from a given question and context.\nThe examples are listed below"
    )
    example_df = load_example_df()
    if example_df.empty:
        # Nothing to render; st.tabs rejects an empty label list.
        st.info("No examples available.")
        return

    example_tabs = st.tabs([f"Example {i + 1}" for i in range(len(example_df))])
    execute_sql_query, num_tries = _sidebar_options()

    # Pair tabs with rows by position so a non-default DataFrame index cannot
    # select the wrong tab.
    rows = (row for _, row in example_df.iterrows())
    for position, (tab, row) in enumerate(zip(example_tabs, rows)):
        with tab:
            st.markdown("##### Context:")
            st.code(row["context"], language="sql")
            st.markdown("##### Question:")
            st.text(row["question"])
            if st.button("Generate SQL query", key=f"exp-btn-{position}"):
                st.markdown("##### SQL query:")
                _generate_and_run(
                    row["question"], row["context"], execute_sql_query, num_tries
                )


# Define a function for the Interactive Demo page
def interactive_demo():
    """Render the interactive page: free-form context and question inputs.

    Pressing the button generates (and optionally executes) the SQL query for
    the text currently in the two input widgets.
    """
    st.title("Text to SQL using CodeLlama-13b")
    st.write(
        "This page uses CodeLlama-13b finetuned using QLoRA on NSText2SQL dataset to generate SQL query from a given question and context."
    )
    st.subheader("Input")
    context_placeholder = st.empty()
    question_placeholder = st.empty()
    context = context_placeholder.text_area(
        "##### Context",
        _DEMO_CONTEXT,
        key="context",
        height=150,
    )
    question = question_placeholder.text_input(
        "##### Question",
        "Số lượng khách hàng có độ tuổi từ 30 đến 45 tuổi?",
        key="question",
    )
    get_sql_button = st.button("Generate SQL query")
    execute_sql_query, num_tries = _sidebar_options()

    if get_sql_button:
        st.markdown("##### Output")
        _generate_and_run(question, context, execute_sql_query, num_tries)
# Define a function for the About page
def about():
    """Render the About page: app summary, author, data sources, thanks, license."""
    st.title("π§βπ» About")
    # The summary describes this Text-to-SQL app; the earlier wording about
    # stock prices and an LSTM model did not match the rest of the file.
    st.write(
        "This Streamlit app allows you to generate SQL queries from natural-language questions using a fine-tuned CodeLlama-13b model."
    )
    st.header("Author")
    # NOTE(review): the contact address reads "[email protected]" in this copy
    # of the source; it looks redacted. Confirm the real address.
    st.write(
        "This app was developed by Minh Nam. You can contact the author at [email protected]."
    )
    st.header("Data Sources")
    st.markdown(
        "The Spider dataset was sourced from [Spider](https://yale-lily.github.io/spider)."
    )
    st.markdown("The vMLP dataset is a private dataset from Viettel.")
    st.header("Acknowledgments")
    st.write(
        "The author would like to thank Dr. Nguyen Van Nam for his proper guidance, Mr. Nguyen Chi Dong for his support."
    )
    st.header("License")
    st.write(
        # "This app is licensed under the MIT License. See LICENSE.txt for more information."
        "N/A"
    )
def references():
    """Render the References page as headed sections of titled entries.

    NOTE(review): the entries under the first heading describe LSTM stock
    price material rather than Text-to-SQL; they look like leftovers from a
    template. Confirm with the author.
    """
    st.title("π References")

    # (section heading, ((entry title, (paragraph, ...)), ...)) in display order.
    sections = (
        (
            "References for Text to SQL project using foundation model - CodeLlama-13b",
            (
                (
                    "1. 'Project for time-series data' by AI VIET NAM, et al.",
                    (
                        "This organization provides a tutorial on how to build a stock price prediction model using LSTM in the AIO2022 course.",
                        "Link: https://www.facebook.com/aivietnam.edu.vn",
                    ),
                ),
                (
                    "2. 'PyTorch LSTMs for time series forecasting of Indian Stocks' by Vinayak Nayak",
                    (
                        "This blog post describes how to build a stock price prediction model using LSTM, RNN and CNN-sliding window model.",
                        "Link: https://medium.com/analytics-vidhya/pytorch-lstms-for-time-series-forecasting-of-indian-stocks-8a49157da8b9#b052",
                    ),
                ),
            ),
        ),
        (
            "References for Streamlit",
            (
                (
                    "1. Streamlit Documentation",
                    (
                        "The official documentation for Streamlit provides detailed information about how to use the library and build Streamlit apps.",
                        "Link: https://docs.streamlit.io/",
                    ),
                ),
                (
                    "2. Streamlit Community",
                    (
                        "The Streamlit community includes a forum and a GitHub repository with examples and resources for building Streamlit apps.",
                        "Link: https://discuss.streamlit.io/ and https://github.com/streamlit/streamlit/",
                    ),
                ),
            ),
        ),
    )

    for heading, entries in sections:
        st.header(heading)
        for entry_title, paragraphs in entries:
            st.subheader(entry_title)
            for paragraph in paragraphs:
                st.write(paragraph)
# Create the sidebar
st.sidebar.title("Menu")
pages = [
    "Introduction",
    # "Datasets",
    "Examples",
    "Interactive Demo",
    "About",
    "References",
]
selected_page = st.sidebar.radio("Go to", pages)

# Show the appropriate page based on the selection.
# "EDA" is mapped as well, although `pages` does not currently offer it.
page_renderers = {
    "Introduction": introduction,
    "EDA": eda,
    "Examples": examples,
    "Interactive Demo": interactive_demo,
    "About": about,
    "References": references,
}
render_page = page_renderers.get(selected_page)
if render_page is not None:
    render_page()