"""Text-to-SQL running.

Turns natural-language instructions into SQL, either by prompting a model
through Manifest or by delegating to MotherDuck's ``prompt_sql`` function.
"""
import asyncio
import json
import re
import time
from typing import cast

import duckdb
import structlog
from manifest import Manifest
from manifest.response import Response, Usage
from prompt_formatters import RajkumarFormatter, MotherDuckFormatter
from schema import DEFAULT_TABLE_NAME, TextToSQLModelResponse, TextToSQLParams
from tqdm.auto import tqdm

logger = structlog.get_logger()

# Number of attempts made for a single (non-batched) Manifest request before
# the last error is re-raised to the caller.
_MAX_SINGLE_REQUEST_ATTEMPTS = 5

# Query returned by the MotherDuck path when ``prompt_sql`` fails for an
# instruction, so the rest of the batch still gets processed.
_MOTHERDUCK_FALLBACK_SQL = "SELECT * FROM hn.hacker_news LIMIT 1"


def clean_whitespace(sql: str) -> str:
    """Collapse every run of whitespace in ``sql`` into a single space.

    Leading and trailing whitespace is collapsed to one space, not stripped.
    """
    return re.sub(r"\s+", " ", sql)


def instruction_to_sql(
    params: TextToSQLParams,
    extra_context: list[str],
    manifest: Manifest,
    prompt_formatter: RajkumarFormatter | None = None,
    overwrite_manifest: bool = False,
    max_tokens: int = 300,
    temperature: float = 0.1,
    stop_sequences: list[str] | None = None,
    num_beams: int = 1,
) -> TextToSQLModelResponse:
    """Translate a single instruction into SQL.

    Thin wrapper around :func:`instruction_to_sql_list` for one input; see
    that function for the meaning of each argument.

    Returns:
        The model response generated for ``params``.
    """
    return instruction_to_sql_list(
        params=[params],
        extra_context=[extra_context],
        manifest=manifest,
        prompt_formatter=prompt_formatter,
        overwrite_manifest=overwrite_manifest,
        max_tokens=max_tokens,
        # Forward the caller's temperature (previously hard-coded to 0.1).
        temperature=temperature,
        stop_sequences=stop_sequences,
        num_beams=num_beams,
    )[0]


def run_motherduck_prompt_sql(
    params: list[TextToSQLParams],
) -> list[TextToSQLModelResponse]:
    """Generate SQL for each instruction using MotherDuck's ``prompt_sql``.

    One MotherDuck connection (``md:``) is opened for the whole list and is
    closed when done. If ``prompt_sql`` fails for an instruction, the error is
    logged and a fixed fallback query is used for that instruction instead.
    Token usage is not reported on this path, so it is recorded as zero.

    Args:
        params: Instructions to translate.

    Returns:
        One response per element of ``params``, in the same order.
    """
    if not params:
        return []

    results: list[TextToSQLModelResponse] = []
    with duckdb.connect("md:") as con:
        for param in params:
            try:
                sql_query = con.execute(
                    "CALL prompt_sql(?);", [param.instruction]
                ).fetchall()[0][0]
            except Exception as e:
                # Deliberate best effort: keep going with a placeholder query.
                logger.warning(f"MotherDuck prompt_sql failed, using fallback: {e}")
                sql_query = _MOTHERDUCK_FALLBACK_SQL
            usage = Usage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
            results.append(
                TextToSQLModelResponse(
                    output=sql_query,
                    raw_output=sql_query,
                    final_prompt=param.instruction,
                    usage=usage,
                )
            )
    return results


def instruction_to_sql_list(
    params: list[TextToSQLParams],
    extra_context: list[list[str]],
    manifest: Manifest,
    prompt_formatter: RajkumarFormatter | None = None,
    overwrite_manifest: bool = False,
    max_tokens: int = 300,
    temperature: float = 0.1,
    stop_sequences: list[str] | None = None,
    num_beams: int = 1,
    verbose: bool = False,
) -> list[TextToSQLModelResponse]:
    """Translate a list of instructions into SQL.

    A ``MotherDuckFormatter`` (exact type) routes the request to MotherDuck's
    ``prompt_sql``; otherwise prompts are built with ``prompt_formatter`` and
    sent through ``manifest``. A single prompt is sent with retries; several
    prompts are sent as one async batch without retries.

    Args:
        params: Instructions (and their tables) to translate.
        extra_context: Per-instruction retrieved context, indexed in parallel
            with ``params``; may be empty to mean "no context for any".
        manifest: Client used to call the model.
        prompt_formatter: Formatter used to build prompts and post-process
            outputs. Required unless it is a ``MotherDuckFormatter``.
        overwrite_manifest: Forwarded to Manifest as ``overwrite_cache``.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature; values ``<= 0`` disable sampling.
        stop_sequences: Stop sequences sent to the model (falling back to the
            formatter's when ``None``/empty) and used to truncate the output.
        num_beams: Forwarded to Manifest.
        verbose: Show a progress bar while constructing prompts.

    Returns:
        One response per element of ``params``, in the same order.

    Raises:
        ValueError: If no prompt formatter is given.
        Exception: Whatever the final failed Manifest attempt raised, when a
            single request fails on every retry.
    """
    if type(prompt_formatter) is MotherDuckFormatter:
        return run_motherduck_prompt_sql(params)
    if prompt_formatter is None:
        raise ValueError("Prompt formatter is required.")

    # If no inputs, return nothing
    if not params:
        return []

    prompts: list[str | list[dict]] = []
    for i, param in tqdm(
        enumerate(params),
        total=len(params),
        desc="Constructing prompts",
        disable=not verbose,
    ):
        context = extra_context[i] if extra_context else []
        prompt = _construct_prompt(param, context, prompt_formatter)
        # Only text prompts are stripped; chat-message lists are left as-is.
        if isinstance(prompt, str):
            prompt = prompt.lstrip()
        prompts.append(prompt)

    manifest_params = dict(
        max_tokens=max_tokens,
        overwrite_cache=overwrite_manifest,
        num_beams=num_beams,
        logprobs=5,
        temperature=temperature,
        # Sample only for a positive temperature; otherwise decode greedily.
        do_sample=temperature > 0,
        stop_sequences=stop_sequences or prompt_formatter.stop_sequences,
    )

    if len(prompts) == 1:
        return [
            _run_manifest_with_retries(
                prompts[0],
                manifest_params,
                prompt_formatter,
                manifest,
                stop_sequences=stop_sequences,
            )
        ]
    return _run_manifest_batch(
        prompts,
        manifest_params,
        prompt_formatter,
        manifest,
        stop_sequences=stop_sequences,
    )


def _construct_prompt(
    params: TextToSQLParams,
    context: list[str],
    prompt_formatter: RajkumarFormatter,
) -> str | list[dict]:
    """Build the prompt (text or chat messages) for one instruction."""
    if prompt_formatter.clean_whitespace:
        instruction = clean_whitespace(params.instruction)
    else:
        instruction = params.instruction

    table_texts = prompt_formatter.format_all_tables(
        params.tables, instruction=instruction
    )
    # table_texts can be list of chat messages. Only join list of str.
    if not table_texts:
        table_text = ""
    elif isinstance(table_texts[0], str):
        table_text = prompt_formatter.table_sep.join(table_texts)
    else:
        table_text = table_texts

    if context:
        context_text = prompt_formatter.format_retrieved_context(context)
    else:
        # Empty context in the same shape as table_text: "" for text, [] for chat.
        context_text = "" if isinstance(table_text, str) else []

    return prompt_formatter.format_prompt(instruction, table_text, context_text)


def _truncate_at_stop_sequences(
    sql_query: str, stop_sequences: list[str] | None
) -> str:
    """Keep only the text before each stop sequence, applied in order.

    ``None`` (or an empty list) means no truncation; empty-string tokens are
    skipped because ``str.split("")`` raises ``ValueError``.
    """
    for token in stop_sequences or ():
        if token:
            sql_query = sql_query.split(token)[0]
    return sql_query


def _run_manifest_with_retries(
    prompt: str | list[dict],
    manifest_params: dict,
    prompt_formatter: RajkumarFormatter,
    manifest: Manifest,
    stop_sequences: list[str] | None = None,
    max_attempts: int = _MAX_SINGLE_REQUEST_ATTEMPTS,
) -> TextToSQLModelResponse:
    """Run one prompt through :func:`_run_manifest`, retrying on errors.

    Each failure is logged; if every one of ``max_attempts`` attempts fails,
    the last error is re-raised with its original traceback.
    """
    for attempt in range(1, max_attempts + 1):
        try:
            return _run_manifest(
                prompt,
                manifest_params,
                prompt_formatter,
                manifest,
                stop_sequences=stop_sequences,
            )
        except Exception as e:
            if attempt == max_attempts:
                raise
            logger.warning(
                f"Manifest request failed (attempt {attempt}/{max_attempts}): {e}"
            )
    raise ValueError("max_attempts must be at least 1")


def _run_manifest_batch(
    prompts: list[str | list[dict]],
    manifest_params: dict,
    prompt_formatter: RajkumarFormatter,
    manifest: Manifest,
    stop_sequences: list[str] | None = None,
) -> list[TextToSQLModelResponse]:
    """Send ``prompts`` as one async Manifest batch and post-process outputs.

    We do not handle retry logic on parallel requests right now.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        response = cast(
            Response,
            loop.run_until_complete(
                manifest.arun_batch(
                    prompts,
                    **manifest_params,  # type: ignore
                ),
            ),
        )
    finally:
        # Close the loop even if the batch request raises.
        loop.close()

    # NOTE(review): every response below receives the same batch-level usage
    # object from get_usage(); confirm whether per-prompt usage is intended.
    response_usage = response.get_usage()
    response_text = response.get_parsed_response()

    results: list[TextToSQLModelResponse] = []
    for prompt, resp in zip(prompts, response_text):
        raw_output = cast(str, resp)
        # This will restitch the query in the case we force it to start with SELECT
        sql_query = prompt_formatter.format_model_output(raw_output, prompt)
        sql_query = _truncate_at_stop_sequences(sql_query, stop_sequences)
        logger.info(f"FINAL OUTPUT: {sql_query}")
        results.append(
            TextToSQLModelResponse(
                output=sql_query,
                raw_output=raw_output,
                final_prompt=prompt,
                usage=response_usage,
            )
        )
    return results


def _run_manifest(
    prompt: str | list[dict],
    manifest_params: dict,
    prompt_formatter: RajkumarFormatter,
    manifest: Manifest,
    stop_sequences: list[str] | None = None,
) -> TextToSQLModelResponse:
    """Run one prompt through Manifest and post-process the output.

    The prompt is either a text string or a list of chat messages with
    ``role``/``content`` keys. Token usage is summed over all usages that
    Manifest reports for the response.
    """
    logger.info(f"PARAMS: {manifest_params}")
    if isinstance(prompt, list):
        for message in prompt:
            logger.info(f"PROMPT: {message['role']}: {message['content']}")
    else:
        logger.info(f"PROMPT: {prompt}")

    start_time = time.perf_counter()
    response = cast(
        Response,
        manifest.run(
            prompt,
            return_response=True,
            client_timeout=1800,
            **manifest_params,  # type: ignore
        ),
    )
    logger.info(f"TIME: {time.perf_counter() - start_time: .2f}")

    summed_usage = Usage()
    for usage in response.get_usage_obj().usages:
        summed_usage.completion_tokens += usage.completion_tokens
        summed_usage.prompt_tokens += usage.prompt_tokens
        summed_usage.total_tokens += usage.total_tokens

    raw_output = cast(str, response.get_response())
    # This will restitch the query in the case we force it to start with SELECT
    sql_query = prompt_formatter.format_model_output(raw_output, prompt)
    sql_query = _truncate_at_stop_sequences(sql_query, stop_sequences)
    logger.info(f"OUTPUT: {sql_query}")

    return TextToSQLModelResponse(
        output=sql_query,
        raw_output=raw_output,
        final_prompt=prompt,
        usage=summed_usage,
    )