sdiazlor HF staff commited on
Commit
d4af916
·
1 Parent(s): 0d14ea5

fix minor errors and improve prompt

Browse files
src/synthetic_dataset_generator/apps/rag.py CHANGED
@@ -116,7 +116,7 @@ def _preprocess_input_data(file_paths, num_rows, progress=gr.Progress(track_tqdm
116
  return (
117
  dataframe,
118
  gr.Dropdown(
119
- choices=["chucks"],
120
  label="Documents column",
121
  value=col_doc,
122
  interactive=(False if col_doc == "" else True),
@@ -170,7 +170,7 @@ def generate_dataset(
170
  progress=gr.Progress(),
171
  ):
172
  num_rows = test_max_num_rows(num_rows)
173
- progress(0.0, desc="Generating questions")
174
  if input_type == "prompt-input":
175
  chunk_generator = get_chunks_generator(
176
  temperature=temperature, is_sample=is_sample
@@ -399,7 +399,9 @@ def push_dataset(
399
  retrieval = "Retrieval" in retrieval_reranking
400
  reranking = "Reranking" in retrieval_reranking
401
 
402
- if input_type != "prompt-input":
 
 
403
  dataframe, _ = load_dataset_file(
404
  repo_id=original_repo_id,
405
  file_paths=file_paths,
@@ -522,8 +524,12 @@ def push_dataset(
522
  )
523
 
524
  for item in ["context", "question", "response"]:
525
- dataframe[f"{item}_length"] = dataframe[item].apply(len)
526
- dataframe[f"{item}_embeddings"] = get_embeddings(dataframe[item].to_list())
 
 
 
 
527
 
528
  rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
529
  if rg_dataset is None:
 
116
  return (
117
  dataframe,
118
  gr.Dropdown(
119
+ choices=["chunks"],
120
  label="Documents column",
121
  value=col_doc,
122
  interactive=(False if col_doc == "" else True),
 
170
  progress=gr.Progress(),
171
  ):
172
  num_rows = test_max_num_rows(num_rows)
173
+ progress(0.0, desc="Initializing dataset generation")
174
  if input_type == "prompt-input":
175
  chunk_generator = get_chunks_generator(
176
  temperature=temperature, is_sample=is_sample
 
399
  retrieval = "Retrieval" in retrieval_reranking
400
  reranking = "Reranking" in retrieval_reranking
401
 
402
+ if input_type == "prompt-input":
403
+ dataframe = pd.DataFrame(columns=["context", "question", "response"])
404
+ else:
405
  dataframe, _ = load_dataset_file(
406
  repo_id=original_repo_id,
407
  file_paths=file_paths,
 
524
  )
525
 
526
  for item in ["context", "question", "response"]:
527
+ dataframe[f"{item}_length"] = dataframe[item].apply(
528
+ lambda x: len(x) if x is not None else 0
529
+ )
530
+ dataframe[f"{item}_embeddings"] = get_embeddings(
531
+ dataframe[item].apply(lambda x: x if x is not None else "").to_list()
532
+ )
533
 
534
  rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
535
  if rg_dataset is None:
src/synthetic_dataset_generator/pipelines/rag.py CHANGED
@@ -18,11 +18,11 @@ DEFAULT_DATASET_DESCRIPTIONS = [
18
 
19
  PROMPT_CREATION_PROMPT = """
20
 
21
- You are an AI assistant specialized in designing retrieval-augmented generation (RAG) tasks for dataset creation.
22
 
23
- Your task is to generate a well-structured and descriptive prompt based on the provided dataset description and company context. Respond with only the generated prompt and nothing else.
24
 
25
- The prompt should closely follow the style and structure of the example prompts below. Ensure that you include all relevant details from the dataset description and reflect the company context accurately.
26
 
27
  Description: A dataset to retrieve information from legal documents.
28
  Output: A dataset to retrieve information from a collection of legal documents related to the US law system and the status of contracts.
@@ -48,9 +48,9 @@ Do not include or reference the retrieval task itself in the generated chunks.
48
 
49
  CHUNKS_TEMPLATE = """You have been assigned to generate text chunks based on the following retrieval task: {{ task }}.
50
 
51
- Provide only the text chunks without explaining your process or reasoning.
52
 
53
- Ensure the chunks are clear, accurate, and directly relevant to the task.
54
 
55
  Use your general knowledge to create informative and precise outputs.
56
  """
@@ -145,12 +145,12 @@ def generate_pipeline_code(
145
  retrieval_reranking: list[str],
146
  num_rows: int = 10,
147
  ) -> str:
148
- if repo_id is None:
149
- subset = "default"
150
- split = "train"
151
- else:
152
  subset = get_dataset_config_names(repo_id)[0]
153
  split = get_dataset_split_names(repo_id, subset)[0]
 
 
 
154
  retrieval = "Retrieval" in retrieval_reranking
155
  reranking = "Reranking" in retrieval_reranking
156
  base_code = f"""
 
18
 
19
  PROMPT_CREATION_PROMPT = """
20
 
21
+ You are an AI assistant specialized in designing retrieval-augmented generation (RAG) tasks for dataset generation.
22
 
23
+ Your task is to generate a well-structured and descriptive prompt based on the provided dataset description. Respond with only the generated prompt and nothing else.
24
 
25
+ The prompt should closely follow the style and structure of the example prompts below. Ensure that you include all relevant details from the dataset description.
26
 
27
  Description: A dataset to retrieve information from legal documents.
28
  Output: A dataset to retrieve information from a collection of legal documents related to the US law system and the status of contracts.
 
48
 
49
  CHUNKS_TEMPLATE = """You have been assigned to generate text chunks based on the following retrieval task: {{ task }}.
50
 
51
+ Provide only the text chunks without explaining your process or reasoning. Do not include any additional information. Do not indicate that it is a text chunk.
52
 
53
+ Ensure the chunks are concise, clear, and directly relevant to the task.
54
 
55
  Use your general knowledge to create informative and precise outputs.
56
  """
 
145
  retrieval_reranking: list[str],
146
  num_rows: int = 10,
147
  ) -> str:
148
+ if input_type == "dataset-input" and repo_id is not None:
 
 
 
149
  subset = get_dataset_config_names(repo_id)[0]
150
  split = get_dataset_split_names(repo_id, subset)[0]
151
+ else:
152
+ subset = "default"
153
+ split = "train"
154
  retrieval = "Retrieval" in retrieval_reranking
155
  reranking = "Reranking" in retrieval_reranking
156
  base_code = f"""