Upload ask.py
Browse files
ask.py
CHANGED
@@ -17,6 +17,7 @@ import duckdb
|
|
17 |
import gradio as gr
|
18 |
import requests
|
19 |
from bs4 import BeautifulSoup
|
|
|
20 |
from dotenv import load_dotenv
|
21 |
from jinja2 import BaseLoader, Environment
|
22 |
from openai import OpenAI
|
@@ -113,6 +114,10 @@ class Ask:
|
|
113 |
else:
|
114 |
self.logger = _get_logger("INFO")
|
115 |
|
|
|
|
|
|
|
|
|
116 |
self.db_con = duckdb.connect(":memory:")
|
117 |
|
118 |
self.db_con.install_extension("vss")
|
@@ -248,15 +253,10 @@ class Ask:
|
|
248 |
|
249 |
return scrape_results
|
250 |
|
251 |
-
def chunk_results(
|
252 |
-
self, scrape_results: Dict[str, str], size: int, overlap: int
|
253 |
-
) -> Dict[str, List[str]]:
|
254 |
chunking_results: Dict[str, List[str]] = {}
|
255 |
for url, text in scrape_results.items():
|
256 |
-
|
257 |
-
for pos in range(0, len(text), size - overlap):
|
258 |
-
chunks.append(text[pos : pos + size])
|
259 |
-
chunking_results[url] = chunks
|
260 |
return chunking_results
|
261 |
|
262 |
def get_embedding(self, client: OpenAI, texts: List[str]) -> List[List[float]]:
|
@@ -304,7 +304,7 @@ CREATE TABLE {table_name} (
|
|
304 |
)
|
305 |
return table_name
|
306 |
|
307 |
-
def save_chunks_to_db(self,
|
308 |
"""
|
309 |
The key of chunking_results is the URL and the value is the list of chunks.
|
310 |
"""
|
@@ -316,10 +316,10 @@ CREATE TABLE {table_name} (
|
|
316 |
table_name = self._create_table()
|
317 |
|
318 |
batches: List[Tuple[str, List[str]]] = []
|
319 |
-
for url, list_chunks in
|
320 |
for i in range(0, len(list_chunks), embed_batch_size):
|
321 |
-
|
322 |
-
batches.append((url,
|
323 |
|
324 |
self.logger.info(f"Embedding {len(batches)} batches of chunks ...")
|
325 |
partial_get_embedding = partial(self.batch_get_embedding, client)
|
@@ -327,9 +327,9 @@ CREATE TABLE {table_name} (
|
|
327 |
all_embeddings = executor.map(partial_get_embedding, batches)
|
328 |
self.logger.info(f"✅ Finished embedding.")
|
329 |
|
330 |
-
#
|
331 |
-
#
|
332 |
-
#
|
333 |
for chunk_batch, embeddings in all_embeddings:
|
334 |
url = chunk_batch[0]
|
335 |
list_chunks = chunk_batch[1]
|
@@ -678,19 +678,19 @@ Below is the provided content:
|
|
678 |
if settings.output_mode == OutputMode.answer:
|
679 |
logger.info("Chunking the text ...")
|
680 |
yield "", update_logs()
|
681 |
-
|
682 |
-
|
683 |
-
for url, chunks in
|
684 |
logger.debug(f"URL: {url}")
|
685 |
-
|
686 |
for i, chunk in enumerate(chunks):
|
687 |
-
logger.debug(f"Chunk {i+1}: {chunk}")
|
688 |
-
logger.info(f"✅ Generated {
|
689 |
yield "", update_logs()
|
690 |
|
691 |
-
logger.info(f"Saving {
|
692 |
yield "", update_logs()
|
693 |
-
table_name = self.save_chunks_to_db(
|
694 |
logger.info(f"✅ Successfully embedded and saved chunks to DB.")
|
695 |
yield "", update_logs()
|
696 |
|
@@ -940,7 +940,6 @@ def launch_gradio(
|
|
940 |
)
|
941 |
@click.option(
|
942 |
"--inference-model-name",
|
943 |
-
"-m",
|
944 |
required=False,
|
945 |
default="gpt-4o-mini",
|
946 |
help="Model name to use for inference",
|
@@ -951,9 +950,10 @@ def launch_gradio(
|
|
951 |
help="Use hybrid search mode with both vector search and full-text search",
|
952 |
)
|
953 |
@click.option(
|
954 |
-
"--
|
|
|
955 |
is_flag=True,
|
956 |
-
help="
|
957 |
)
|
958 |
@click.option(
|
959 |
"-l",
|
@@ -975,7 +975,7 @@ def search_extract_summarize(
|
|
975 |
extract_schema_file: str,
|
976 |
inference_model_name: str,
|
977 |
hybrid_search: bool,
|
978 |
-
|
979 |
log_level: str,
|
980 |
):
|
981 |
load_dotenv(dotenv_path=default_env_file, override=False)
|
@@ -996,7 +996,14 @@ def search_extract_summarize(
|
|
996 |
extract_schema_str=_read_extract_schema_str(extract_schema_file),
|
997 |
)
|
998 |
|
999 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1000 |
if os.environ.get("SHARE_GRADIO_UI", "false").lower() == "true":
|
1001 |
share_ui = True
|
1002 |
else:
|
@@ -1007,13 +1014,6 @@ def search_extract_summarize(
|
|
1007 |
share_ui=share_ui,
|
1008 |
logger=logger,
|
1009 |
)
|
1010 |
-
else:
|
1011 |
-
if query is None:
|
1012 |
-
raise Exception("Query is required for the command line mode")
|
1013 |
-
ask = Ask(logger=logger)
|
1014 |
-
|
1015 |
-
final_result = ask.run_query(query=query, settings=settings)
|
1016 |
-
click.echo(final_result)
|
1017 |
|
1018 |
|
1019 |
if __name__ == "__main__":
|
|
|
17 |
import gradio as gr
|
18 |
import requests
|
19 |
from bs4 import BeautifulSoup
|
20 |
+
from chonkie import Chunk, TokenChunker
|
21 |
from dotenv import load_dotenv
|
22 |
from jinja2 import BaseLoader, Environment
|
23 |
from openai import OpenAI
|
|
|
114 |
else:
|
115 |
self.logger = _get_logger("INFO")
|
116 |
|
117 |
+
self.logger.info("Initializing Chonkie ...")
|
118 |
+
self.chunker = TokenChunker(chunk_size=1000, chunk_overlap=100)
|
119 |
+
self.logger.info("✅ Successfully initialized Chonkie.")
|
120 |
+
|
121 |
self.db_con = duckdb.connect(":memory:")
|
122 |
|
123 |
self.db_con.install_extension("vss")
|
|
|
253 |
|
254 |
return scrape_results
|
255 |
|
256 |
+
def chunk_results(self, scrape_results: Dict[str, str]) -> Dict[str, List[Chunk]]:
|
|
|
|
|
257 |
chunking_results: Dict[str, List[str]] = {}
|
258 |
for url, text in scrape_results.items():
|
259 |
+
chunking_results[url] = self.chunker.chunk(text)
|
|
|
|
|
|
|
260 |
return chunking_results
|
261 |
|
262 |
def get_embedding(self, client: OpenAI, texts: List[str]) -> List[List[float]]:
|
|
|
304 |
)
|
305 |
return table_name
|
306 |
|
307 |
+
def save_chunks_to_db(self, all_chunks: Dict[str, List[Chunk]]) -> str:
|
308 |
"""
|
309 |
The key of chunking_results is the URL and the value is the list of chunks.
|
310 |
"""
|
|
|
316 |
table_name = self._create_table()
|
317 |
|
318 |
batches: List[Tuple[str, List[str]]] = []
|
319 |
+
for url, list_chunks in all_chunks.items():
|
320 |
for i in range(0, len(list_chunks), embed_batch_size):
|
321 |
+
batch = [chunk.text for chunk in list_chunks[i : i + embed_batch_size]]
|
322 |
+
batches.append((url, batch))
|
323 |
|
324 |
self.logger.info(f"Embedding {len(batches)} batches of chunks ...")
|
325 |
partial_get_embedding = partial(self.batch_get_embedding, client)
|
|
|
327 |
all_embeddings = executor.map(partial_get_embedding, batches)
|
328 |
self.logger.info(f"✅ Finished embedding.")
|
329 |
|
330 |
+
# We batch the insert data to speed up the insertion operation.
|
331 |
+
# Although the DuckDB doc says executeMany is optimized for batch insert,
|
332 |
+
# we found that it is faster to batch the insert data and run a single insert.
|
333 |
for chunk_batch, embeddings in all_embeddings:
|
334 |
url = chunk_batch[0]
|
335 |
list_chunks = chunk_batch[1]
|
|
|
678 |
if settings.output_mode == OutputMode.answer:
|
679 |
logger.info("Chunking the text ...")
|
680 |
yield "", update_logs()
|
681 |
+
all_chunks = self.chunk_results(scrape_results)
|
682 |
+
chunk_count = 0
|
683 |
+
for url, chunks in all_chunks.items():
|
684 |
logger.debug(f"URL: {url}")
|
685 |
+
chunk_count += len(chunks)
|
686 |
for i, chunk in enumerate(chunks):
|
687 |
+
logger.debug(f"Chunk {i+1}: {chunk.text}")
|
688 |
+
logger.info(f"✅ Generated {chunk_count} chunks ...")
|
689 |
yield "", update_logs()
|
690 |
|
691 |
+
logger.info(f"Saving {chunk_count} chunks to DB ...")
|
692 |
yield "", update_logs()
|
693 |
+
table_name = self.save_chunks_to_db(all_chunks)
|
694 |
logger.info(f"✅ Successfully embedded and saved chunks to DB.")
|
695 |
yield "", update_logs()
|
696 |
|
|
|
940 |
)
|
941 |
@click.option(
|
942 |
"--inference-model-name",
|
|
|
943 |
required=False,
|
944 |
default="gpt-4o-mini",
|
945 |
help="Model name to use for inference",
|
|
|
950 |
help="Use hybrid search mode with both vector search and full-text search",
|
951 |
)
|
952 |
@click.option(
|
953 |
+
"--run-cli",
|
954 |
+
"-c",
|
955 |
is_flag=True,
|
956 |
+
help="Run as a command line tool instead of launching the Gradio UI",
|
957 |
)
|
958 |
@click.option(
|
959 |
"-l",
|
|
|
975 |
extract_schema_file: str,
|
976 |
inference_model_name: str,
|
977 |
hybrid_search: bool,
|
978 |
+
run_cli: bool,
|
979 |
log_level: str,
|
980 |
):
|
981 |
load_dotenv(dotenv_path=default_env_file, override=False)
|
|
|
996 |
extract_schema_str=_read_extract_schema_str(extract_schema_file),
|
997 |
)
|
998 |
|
999 |
+
if run_cli:
|
1000 |
+
if query is None:
|
1001 |
+
raise Exception("Query is required for the command line mode")
|
1002 |
+
ask = Ask(logger=logger)
|
1003 |
+
|
1004 |
+
final_result = ask.run_query(query=query, settings=settings)
|
1005 |
+
click.echo(final_result)
|
1006 |
+
else:
|
1007 |
if os.environ.get("SHARE_GRADIO_UI", "false").lower() == "true":
|
1008 |
share_ui = True
|
1009 |
else:
|
|
|
1014 |
share_ui=share_ui,
|
1015 |
logger=logger,
|
1016 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1017 |
|
1018 |
|
1019 |
if __name__ == "__main__":
|