nsthorat committed on
Commit
d9bd6c3
·
1 Parent(s): 4e98454
src/data/duckdb_utils.py CHANGED
@@ -1,11 +1,15 @@
1
  """Utils for duckdb."""
2
  import duckdb
3
 
4
- from ..config import CONFIG
5
 
6
 
7
  def duckdb_gcs_setup(con: duckdb.DuckDBPyConnection) -> str:
8
  """Setup DuckDB for GCS."""
 
 
 
 
9
  con.install_extension('httpfs')
10
  con.load_extension('httpfs')
11
 
 
1
  """Utils for duckdb."""
2
  import duckdb
3
 
4
+ from ..config import CONFIG, data_path
5
 
6
 
7
  def duckdb_gcs_setup(con: duckdb.DuckDBPyConnection) -> str:
8
  """Setup DuckDB for GCS."""
9
+ con.execute(f"""
10
+ SET extension_directory='{data_path()}';
11
+ """)
12
+
13
  con.install_extension('httpfs')
14
  con.load_extension('httpfs')
15
 
src/data/sources/json_source.py CHANGED
@@ -6,7 +6,6 @@ import pandas as pd
6
  from pydantic import Field as PydanticField
7
  from typing_extensions import override
8
 
9
- from ...config import data_path
10
  from ...schema import Item
11
  from ...utils import download_http_files
12
  from ..duckdb_utils import duckdb_gcs_setup
@@ -39,7 +38,6 @@ class JSONDataset(Source):
39
  # DuckDB expects s3 protocol: https://duckdb.org/docs/guides/import/s3_import.html.
40
  s3_filepaths = [path.replace('gs://', 's3://') for path in filepaths]
41
 
42
- con.execute(f"""SET extension_directory='{data_path()}';""")
43
  # NOTE: We use duckdb here to increase parallelism for multiple files.
44
  self._df = con.execute(f"""
45
  {duckdb_gcs_setup(con)}
 
6
  from pydantic import Field as PydanticField
7
  from typing_extensions import override
8
 
 
9
  from ...schema import Item
10
  from ...utils import download_http_files
11
  from ..duckdb_utils import duckdb_gcs_setup
 
38
  # DuckDB expects s3 protocol: https://duckdb.org/docs/guides/import/s3_import.html.
39
  s3_filepaths = [path.replace('gs://', 's3://') for path in filepaths]
40
 
 
41
  # NOTE: We use duckdb here to increase parallelism for multiple files.
42
  self._df = con.execute(f"""
43
  {duckdb_gcs_setup(con)}
src/embeddings/sbert.py CHANGED
@@ -6,6 +6,7 @@ import torch
6
  from sentence_transformers import SentenceTransformer
7
  from typing_extensions import override
8
 
 
9
  from ..schema import Item, RichData
10
  from ..signals.signal import TextEmbeddingSignal
11
  from ..signals.splitters.chunk_splitter import split_text
@@ -32,7 +33,8 @@ def _sbert() -> tuple[Optional[str], SentenceTransformer]:
32
  preferred_device = 'mps'
33
  elif not torch.backends.mps.is_built():
34
  log('MPS not available because the current PyTorch install was not built with MPS enabled.')
35
- return preferred_device, SentenceTransformer(MODEL_NAME, device=preferred_device)
 
36
 
37
 
38
  def _optimal_batch_size(preferred_device: Optional[str]) -> int:
 
6
  from sentence_transformers import SentenceTransformer
7
  from typing_extensions import override
8
 
9
+ from ..config import data_path
10
  from ..schema import Item, RichData
11
  from ..signals.signal import TextEmbeddingSignal
12
  from ..signals.splitters.chunk_splitter import split_text
 
33
  preferred_device = 'mps'
34
  elif not torch.backends.mps.is_built():
35
  log('MPS not available because the current PyTorch install was not built with MPS enabled.')
36
+ return preferred_device, SentenceTransformer(
37
+ MODEL_NAME, device=preferred_device, cache_folder=data_path())
38
 
39
 
40
  def _optimal_batch_size(preferred_device: Optional[str]) -> int: