sergeipetrov commited on
Commit
770c645
·
verified ·
1 Parent(s): 1ed1d13

Update src/main.py

Browse files
Files changed (1) hide show
  1. src/main.py +7 -1
src/main.py CHANGED
@@ -21,12 +21,18 @@ from src.models import chunk_config, embed_config, WebhookPayload
21
  logging.basicConfig(level=logging.INFO)
22
  logger = logging.getLogger(__name__)
23
 
 
24
  HF_TOKEN = os.getenv("HF_TOKEN")
25
 
 
26
  TEI_URL = os.getenv("TEI_URL")
 
27
  CHUNKED_DS_NAME = os.getenv("CHUNKED_DS_NAME")
 
28
  EMBED_DS_NAME = os.getenv("EMBED_DS_NAME")
 
29
  INPUT_SPLITS = os.getenv("INPUT_SPLITS")
 
30
  INPUT_TEXT_COL = os.getenv("INPUT_TEXT_COL")
31
 
32
  INPUT_SPLITS = [spl.strip() for spl in INPUT_SPLITS.split(",") if spl]
@@ -183,7 +189,7 @@ def wake_up_endpoint(url):
183
  def embed_dataset(ds_name):
184
  logger.info("Update detected, embedding is scheduled")
185
  wake_up_endpoint(TEI_URL)
186
- input_ds = load_dataset(ds_name, split="+".join(INPUT_SPLITS))
187
  with tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl") as temp_file:
188
  asyncio.run(embed(input_ds, temp_file))
189
 
 
21
  logging.basicConfig(level=logging.INFO)
22
  logger = logging.getLogger(__name__)
23
 
24
+ # you token from Settings
25
  HF_TOKEN = os.getenv("HF_TOKEN")
26
 
27
+ # URL of TEI endpoint
28
  TEI_URL = os.getenv("TEI_URL")
29
+ # name of chunked dataset
30
  CHUNKED_DS_NAME = os.getenv("CHUNKED_DS_NAME")
31
+ # name of embeddings dataset
32
  EMBED_DS_NAME = os.getenv("EMBED_DS_NAME")
33
+ # splits of input dataset to process, comma separated
34
  INPUT_SPLITS = os.getenv("INPUT_SPLITS")
35
+ # name of column to load from input dataset
36
  INPUT_TEXT_COL = os.getenv("INPUT_TEXT_COL")
37
 
38
  INPUT_SPLITS = [spl.strip() for spl in INPUT_SPLITS.split(",") if spl]
 
189
  def embed_dataset(ds_name):
190
  logger.info("Update detected, embedding is scheduled")
191
  wake_up_endpoint(TEI_URL)
192
+ input_ds = load_dataset(ds_name, split="train")
193
  with tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl") as temp_file:
194
  asyncio.run(embed(input_ds, temp_file))
195