Caleb Fahlgren committed on
Commit
9fc2d21
•
1 Parent(s): 4ed656e

add examples and handle incorrect label / data key output better

Browse files
Files changed (2) hide show
  1. README.md +2 -0
  2. app.py +20 -2
README.md CHANGED
@@ -22,3 +22,5 @@ Powered by [Hermes-2-Pro-Llama-3-8B](https://huggingface.co/NousResearch/Hermes-
22
  2. Get SQL DDL from Parquet column and data types
23
  3. Prompt the LLM for SQL
24
  4. Run the SQL against the Dataset Parquet
 
 
 
22
  2. Get SQL DDL from Parquet column and data types
23
  3. Prompt the LLM for SQL
24
  4. Run the SQL against the Dataset Parquet
25
+
26
+ Inspired by [Datasets-Text2SQL](https://huggingface.co/spaces/asoria/datasets-text2sql)
app.py CHANGED
@@ -46,10 +46,12 @@ class SQLResponse(BaseModel):
46
  None, description="The type of visualization to display"
47
  )
48
  data_key: Optional[str] = Field(
49
- None, description="The column name that contains the data for chart responses"
 
50
  )
51
  label_key: Optional[str] = Field(
52
- None, description="The column name that contains the labels for chart responses"
 
53
  )
54
 
55
 
@@ -95,6 +97,9 @@ def generate_query(dataset_id: str, query: str) -> str:
95
  ```
96
 
97
  Please assist the user by writing a SQL query that answers the user's question.
 
 
 
98
  """
99
 
100
  print("Calling LLM with system prompt: ", system_prompt)
@@ -124,6 +129,12 @@ def query_dataset(dataset_id: str, query: str) -> Tuple[pd.DataFrame, str, plt.F
124
 
125
  plot = None
126
 
 
 
 
 
 
 
127
  if response.visualization_type == OutputTypes.LINECHART:
128
  plot = df.plot(
129
  kind="line", x=response.label_key, y=response.data_key
@@ -150,6 +161,13 @@ with gr.Blocks() as demo:
150
  value="gretelai/synthetic_text_to_sql",
151
  )
152
  user_query = gr.Textbox("", label="Ask anything...")
 
 
 
 
 
 
 
153
 
154
  btn = gr.Button("Ask 🪄")
155
 
 
46
  None, description="The type of visualization to display"
47
  )
48
  data_key: Optional[str] = Field(
49
+ None,
50
+ description="The column name from the sql query that contains the data for chart responses",
51
  )
52
  label_key: Optional[str] = Field(
53
+ None,
54
+ description="The column name from the sql query that contains the labels for chart responses",
55
  )
56
 
57
 
 
97
  ```
98
 
99
  Please assist the user by writing a SQL query that answers the user's question.
100
+
101
+ Use Label Key as the column name for the x-axis and Data Key as the column name for the y-axis for chart responses. The
102
+ label key and data key must be present in the SQL output.
103
  """
104
 
105
  print("Calling LLM with system prompt: ", system_prompt)
 
129
 
130
  plot = None
131
 
132
+ # handle incorrect data and label keys better
133
+ if response.label_key and response.label_key not in df.columns:
134
+ response.label_key = None
135
+ if response.data_key and response.data_key not in df.columns:
136
+ response.data_key = None
137
+
138
  if response.visualization_type == OutputTypes.LINECHART:
139
  plot = df.plot(
140
  kind="line", x=response.label_key, y=response.data_key
 
161
  value="gretelai/synthetic_text_to_sql",
162
  )
163
  user_query = gr.Textbox("", label="Ask anything...")
164
+ examples = [
165
+ ["Show me a preview of the data"],
166
+ ["Show me something interesting"],
167
+ ["What is the largest length of sql query context?"],
168
+ ["show me counts by sql_query_type in a bar chart"],
169
+ ]
170
+ gr.Examples(examples=examples, inputs=[user_query], outputs=[])
171
 
172
  btn = gr.Button("Ask 🪄")
173