LeetTools committed · verified
Commit 6eabcb6 · 1 Parent(s): 7f2676a

Upload ask.py

Files changed (1):
  1. ask.py +34 -34
ask.py CHANGED
@@ -17,6 +17,7 @@ import duckdb
 import gradio as gr
 import requests
 from bs4 import BeautifulSoup
+from chonkie import Chunk, TokenChunker
 from dotenv import load_dotenv
 from jinja2 import BaseLoader, Environment
 from openai import OpenAI
@@ -113,6 +114,10 @@ class Ask:
         else:
             self.logger = _get_logger("INFO")
 
+        self.logger.info("Initializing Chonkie ...")
+        self.chunker = TokenChunker(chunk_size=1000, chunk_overlap=100)
+        self.logger.info("✅ Successfully initialized Chonkie.")
+
         self.db_con = duckdb.connect(":memory:")
 
         self.db_con.install_extension("vss")
@@ -248,15 +253,10 @@ class Ask:
 
         return scrape_results
 
-    def chunk_results(
-        self, scrape_results: Dict[str, str], size: int, overlap: int
-    ) -> Dict[str, List[str]]:
+    def chunk_results(self, scrape_results: Dict[str, str]) -> Dict[str, List[Chunk]]:
         chunking_results: Dict[str, List[str]] = {}
         for url, text in scrape_results.items():
-            chunks = []
-            for pos in range(0, len(text), size - overlap):
-                chunks.append(text[pos : pos + size])
-            chunking_results[url] = chunks
+            chunking_results[url] = self.chunker.chunk(text)
         return chunking_results
 
     def get_embedding(self, client: OpenAI, texts: List[str]) -> List[List[float]]:
@@ -304,7 +304,7 @@ CREATE TABLE {table_name} (
         )
         return table_name
 
-    def save_chunks_to_db(self, chunking_results: Dict[str, List[str]]) -> str:
+    def save_chunks_to_db(self, all_chunks: Dict[str, List[Chunk]]) -> str:
         """
         The key of chunking_results is the URL and the value is the list of chunks.
         """
@@ -316,10 +316,10 @@ CREATE TABLE {table_name} (
         table_name = self._create_table()
 
         batches: List[Tuple[str, List[str]]] = []
-        for url, list_chunks in chunking_results.items():
+        for url, list_chunks in all_chunks.items():
             for i in range(0, len(list_chunks), embed_batch_size):
-                list_chunks = list_chunks[i : i + embed_batch_size]
-                batches.append((url, list_chunks))
+                batch = [chunk.text for chunk in list_chunks[i : i + embed_batch_size]]
+                batches.append((url, batch))
 
         self.logger.info(f"Embedding {len(batches)} batches of chunks ...")
         partial_get_embedding = partial(self.batch_get_embedding, client)
@@ -327,9 +327,9 @@ CREATE TABLE {table_name} (
             all_embeddings = executor.map(partial_get_embedding, batches)
         self.logger.info(f"✅ Finished embedding.")
 
-        # we batch the insert data to speed up the insertion operation
-        # although the DuckDB doc says executeMany is optimized for batch insert
-        # but we found that it is faster to batch the insert data and run a single insert
+        # We batch the insert data to speed up the insertion operation.
+        # Although the DuckDB doc says executeMany is optimized for batch insert,
+        # we found that it is faster to batch the insert data and run a single insert.
         for chunk_batch, embeddings in all_embeddings:
             url = chunk_batch[0]
             list_chunks = chunk_batch[1]
@@ -678,19 +678,19 @@ Below is the provided content:
         if settings.output_mode == OutputMode.answer:
             logger.info("Chunking the text ...")
             yield "", update_logs()
-            chunking_results = self.chunk_results(scrape_results, 1000, 100)
-            total_chunks = 0
-            for url, chunks in chunking_results.items():
+            all_chunks = self.chunk_results(scrape_results)
+            chunk_count = 0
+            for url, chunks in all_chunks.items():
                 logger.debug(f"URL: {url}")
-                total_chunks += len(chunks)
+                chunk_count += len(chunks)
                 for i, chunk in enumerate(chunks):
-                    logger.debug(f"Chunk {i+1}: {chunk}")
-            logger.info(f"✅ Generated {total_chunks} chunks ...")
+                    logger.debug(f"Chunk {i+1}: {chunk.text}")
+            logger.info(f"✅ Generated {chunk_count} chunks ...")
             yield "", update_logs()
 
-            logger.info(f"Saving {total_chunks} chunks to DB ...")
+            logger.info(f"Saving {chunk_count} chunks to DB ...")
             yield "", update_logs()
-            table_name = self.save_chunks_to_db(chunking_results)
+            table_name = self.save_chunks_to_db(all_chunks)
             logger.info(f"✅ Successfully embedded and saved chunks to DB.")
             yield "", update_logs()
 
@@ -940,7 +940,6 @@ def launch_gradio(
 )
 @click.option(
     "--inference-model-name",
-    "-m",
     required=False,
     default="gpt-4o-mini",
     help="Model name to use for inference",
@@ -951,9 +950,10 @@
     help="Use hybrid search mode with both vector search and full-text search",
 )
 @click.option(
-    "--web-ui",
+    "--run-cli",
+    "-c",
     is_flag=True,
-    help="Launch the web interface",
+    help="Run as a command line tool instead of launching the Gradio UI",
 )
 @click.option(
     "-l",
@@ -975,7 +975,7 @@ def search_extract_summarize(
     extract_schema_file: str,
     inference_model_name: str,
     hybrid_search: bool,
-    web_ui: bool,
+    run_cli: bool,
     log_level: str,
 ):
     load_dotenv(dotenv_path=default_env_file, override=False)
@@ -996,7 +996,14 @@ def search_extract_summarize(
         extract_schema_str=_read_extract_schema_str(extract_schema_file),
     )
 
-    if web_ui or os.environ.get("RUN_GRADIO_UI", "false").lower() != "false":
+    if run_cli:
+        if query is None:
+            raise Exception("Query is required for the command line mode")
+        ask = Ask(logger=logger)
+
+        final_result = ask.run_query(query=query, settings=settings)
+        click.echo(final_result)
+    else:
         if os.environ.get("SHARE_GRADIO_UI", "false").lower() == "true":
             share_ui = True
         else:
@@ -1007,13 +1014,6 @@ def search_extract_summarize(
             share_ui=share_ui,
             logger=logger,
         )
-    else:
-        if query is None:
-            raise Exception("Query is required for the command line mode")
-        ask = Ask(logger=logger)
-
-        final_result = ask.run_query(query=query, settings=settings)
-        click.echo(final_result)
 
 
 if __name__ == "__main__":
 
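A note on the chunking change: the hand-rolled character-window loop in chunk_results is replaced by chonkie's TokenChunker, so chunk size and overlap are now measured in tokens rather than characters, and run_query no longer passes the 1000/100 values at the call site. A minimal sketch of the new path, assuming "pip install chonkie" and reusing the parameters hard-coded in Ask.__init__ (the URL and text are placeholders):

from typing import Dict, List

from chonkie import Chunk, TokenChunker

# Same parameters the commit hard-codes in Ask.__init__.
chunker = TokenChunker(chunk_size=1000, chunk_overlap=100)

def chunk_results(scrape_results: Dict[str, str]) -> Dict[str, List[Chunk]]:
    # One List[Chunk] per URL; each Chunk exposes its text via .text,
    # which is what the embedding code reads downstream.
    return {url: chunker.chunk(text) for url, text in scrape_results.items()}

demo = chunk_results({"https://example.com": "some scraped text " * 500})
for url, chunks in demo.items():
    print(url, len(chunks), chunks[0].text[:40])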
 
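A note on the batching change in save_chunks_to_db: besides unwrapping Chunk objects into their .text payloads, the new batch variable fixes a latent bug. The old loop rebound list_chunks to its own slice, so every iteration after the first sliced an already-truncated list and produced empty batches. A self-contained sketch of the corrected batching, with a stub standing in for batch_get_embedding and the OpenAI client:

from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Dict, List, Tuple

def batch_get_embedding(client, batch: Tuple[str, List[str]]):
    url, texts = batch
    # Stub: the real method calls the OpenAI embeddings endpoint here.
    return batch, [[0.0, 0.0, 0.0] for _ in texts]

all_chunks: Dict[str, List[str]] = {
    "https://example.com": [f"chunk {i}" for i in range(25)]
}
embed_batch_size = 10

batches: List[Tuple[str, List[str]]] = []
for url, list_chunks in all_chunks.items():
    for i in range(0, len(list_chunks), embed_batch_size):
        # Fresh name per slice; with chonkie each element is a Chunk,
        # hence the chunk.text unwrap in the committed code.
        batches.append((url, list_chunks[i : i + embed_batch_size]))

client = None  # placeholder for the OpenAI client
with ThreadPoolExecutor() as executor:
    results = list(executor.map(partial(batch_get_embedding, client), batches))
print([len(texts) for (_, texts), _ in results])  # [10, 10, 5]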
 
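A note on the retained DuckDB comment: the reasoning is unchanged, only reworded: a single multi-row INSERT beat executemany in the authors' testing. A toy sketch of that pattern (two text columns only; the real table also stores the embedding vectors):

import duckdb

con = duckdb.connect(":memory:")
con.execute("CREATE TABLE chunks (url VARCHAR, chunk VARCHAR)")

rows = [("https://example.com", f"chunk {i}") for i in range(1000)]

# Row-by-row alternative the comment argues against:
#   con.executemany("INSERT INTO chunks VALUES (?, ?)", rows)
# Batched alternative: one VALUES list, one execute() call.
placeholders = ", ".join(["(?, ?)"] * len(rows))
params = [value for row in rows for value in row]
con.execute(f"INSERT INTO chunks VALUES {placeholders}", params)

print(con.execute("SELECT count(*) FROM chunks").fetchone())  # (1000,)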
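A note on the CLI change: the --web-ui flag and the RUN_GRADIO_UI environment check are removed, and the -m shorthand for --inference-model-name is dropped. The default is inverted: the Gradio UI now launches unless -c/--run-cli is passed, and CLI mode still requires a query. Hypothetical invocations (the query option is defined outside the hunks shown, so its flag name here is an assumption):

python ask.py                              # launches the Gradio UI
python ask.py -c --query "your question"   # prints the answer and exits

SHARE_GRADIO_UI is still consulted to decide whether the UI gets a public share link.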