Upload 5 files
Browse files
- scripts/evaluate_factual_robustness.py +59 -40
- scripts/get_prediction_file.py +58 -0
- scripts/get_prediction_result.py +2 -2
- scripts/get_scores.py +21 -17
- scripts/process_data.py +33 -9
scripts/evaluate_factual_robustness.py
CHANGED
@@ -1,30 +1,31 @@
 import json
 import tqdm
 import logging
-from scripts.…
+from scripts.get_prediction_file import get_prediction_file
 from scripts.groq_client import GroqClient
-from scripts.helper import adaptive_delay, ensure_directory_exists, load_used_data
+from scripts.helper import adaptive_delay, ensure_directory_exists, load_used_data, update_config
 from scripts.prompt import get_factual_prompt

 def evaluate_factual_robustness(config):
-    """Evaluates negative rejection for a given model…
-    config['noise_rate'] = 0.4  # Time being to do clarification
+    """Evaluates negative rejection for a given model under multiple correct_rate/noise_rate conditions."""
     model_name = config['model_name']

     if model_name in config['models']:
-        model = GroqClient(plm=…
+        model = GroqClient(plm=model_name)
     else:
         logging.warning(f"Skipping unknown model: {model_name}")
         return

-    # …
+    # Define the conditions to test
+    conditions = [
+        {"correct_rate": 1.0, "noise_rate": 0.2, "label": "factual_only"},   # factual documents with some noisy documents
+        {"correct_rate": 0.0, "noise_rate": 0.4, "label": "counterfactual"}  # counterfactual + noise
+    ]
+
     base_path = "results/Counterfactual Robustness"
-    evalue_file = get_factual_evaluation(config)
-    print(f"Factual pred file {evalue_file}")
-    output_file = f"{base_path}/output_{config['output_file_extension']}.json"
     result_file = f"{base_path}/scores_{config['output_file_extension']}.json"
-
-
+    final_scores = {"conditions": []}
+
     def process_query(model, data, used_data, output_file):
         """Processes a single query, generates evaluation, and writes the result."""
         if data['id'] in used_data and data['query'] == used_data[data['id']]['query'] and data['ans'] == used_data[data['id']]['ans']:
@@ -33,8 +34,7 @@ def evaluate_factual_robustness(config):

         try:
             instruction = get_factual_prompt(data['query'], data['prediction'])
-
-            # Retry mechanism for evaluation
+            #eval_model = GroqClient(plm='llama3-70b-8192')
             for attempt in range(1, 4):
                 evaluation = model.generate(instruction)
                 if evaluation:
@@ -42,7 +42,7 @@ def evaluate_factual_robustness(config):
                 adaptive_delay(attempt)

             data['evaluation'] = evaluation
-
+            logging.info(f"Model Response for Factual robustness: {evaluation}")
             output_file.write(json.dumps(data, ensure_ascii=False) + '\n')
             return data

@@ -50,7 +50,7 @@ def evaluate_factual_robustness(config):
             print(f"Error processing query: {e}")
             return None

-    def calculate_scores(results, …
+    def calculate_scores(results, condition):
         """Calculates and returns rejection rates and other metrics."""
         rejecttt = 0
         tt = 0
@@ -64,35 +64,54 @@ def evaluate_factual_robustness(config):
                 tt += 1

         scores = {
-            'reject_rate': rejecttt/len(results),
-            'all_rate': …
-            'correct_rate': correct_tt/rejecttt if rejecttt > 0 else 0,
-            'tt': tt,
-            'rejecttt': rejecttt,
-            'correct_tt': correct_tt,
+            'reject_rate': rejecttt / len(results) if len(results) > 0 else 0,   # Error Detection Rate (ED)
+            'all_rate': tt / len(results) if len(results) > 0 else 0,
+            'correct_rate': correct_tt / rejecttt if rejecttt > 0 else 0,        # Error Correction Rate (CR)
+            'tt': tt,
+            'rejecttt': rejecttt,
+            'correct_tt': correct_tt,
             'nums': len(results),
-            'noise_rate': …
+            'noise_rate': condition['noise_rate'],
+            'condition_label': condition['label']
         }
         return scores
-    …
+
+    for condition in conditions:
+        logging.info(f"\nEvaluating condition: {condition['label']} (correct_rate={condition['correct_rate']}, noise_rate={condition['noise_rate']})")
+
+        # Update config with current condition's noise_rate
+        config['noise_rate'] = condition['noise_rate']
+        #config['passage_num'] = 10
+        update_config(config)
+
+        # File paths with condition-specific suffixes
+        pred_file = get_prediction_file(config, condition['correct_rate'])
+        output_file = f"{base_path}/output_{config['output_file_extension']}.json"
+
+        ensure_directory_exists(output_file)
+
+        logging.info(f"Factual pred file for {condition['label']}: {pred_file}")
+
+        # Load or recalculate data
+        used_data = []
+        results = []
+        if config['UsePreCalculatedValue']:
+            logging.info(f"Trying to use pre-calculated values for {condition['label']}")
+            used_data = load_used_data(output_file)
+        else:
+            logging.info(f"Recalculating the metrics for {condition['label']}...")

+        with open(output_file, 'w', encoding='utf-8') as f_out, open(pred_file, 'r', encoding='utf-8') as f_eval:
+            for line in tqdm.tqdm(f_eval):
+                data = json.loads(line)
+                processed_data = process_query(model, data, used_data, f_out)
+                if processed_data:
+                    results.append(processed_data)

+        # Compute and save scores
+        scores = calculate_scores(results, condition)
+        final_scores["conditions"].append(scores)
+        logging.info(f"Counterfactual Robustness Score for {condition['label']}: {scores}")

     with open(result_file, 'w', encoding='utf-8') as f_result:
-        json.dump(…
+        json.dump(final_scores, f_result, ensure_ascii=False, indent=4)
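Note on output format: with this change, the scores_*.json written by evaluate_factual_robustness.py switches from a single flat score object to one entry per condition. A minimal sketch of the new shape, using the field names from calculate_scores above; the numbers are illustrative only:

    {
        "conditions": [
            {"reject_rate": 0.0, "all_rate": 0.85, "correct_rate": 0, "tt": 17, "rejecttt": 0,
             "correct_tt": 0, "nums": 20, "noise_rate": 0.2, "condition_label": "factual_only"},
            {"reject_rate": 0.4, "all_rate": 0.55, "correct_rate": 0.5, "tt": 11, "rejecttt": 8,
             "correct_tt": 4, "nums": 20, "noise_rate": 0.4, "condition_label": "counterfactual"}
        ]
    }

load_counterfactual_robustness_scores in scripts/get_scores.py (below) is updated to read exactly this structure.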
scripts/get_prediction_file.py
ADDED
@@ -0,0 +1,58 @@
+import os
+import json
+import logging
+from scripts.get_prediction_result import get_prediction_result
+from scripts.helper import ensure_directory_exists, load_dataset
+
+
+# Set up logging configuration
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+# Builds the prediction file used by the counterfactual robustness evaluation
+def get_prediction_file(config, correct_rate=0):
+    result_path = config['result_path'] + 'Counterfactual Robustness/'
+    noise_rate = config['noise_rate']
+
+    # Prediction file path for this model/run
+    filename = os.path.join(result_path, f"prediction_{config['output_file_extension']}.json")
+    ensure_directory_exists(filename)
+
+    results = get_prediction_result(config, config['factual_file_name'], filename, correct_rate)  # Predictions for this model
+
+    # Save results to a file
+    with open(filename, 'w', encoding='utf-8') as f:
+        for result in results:
+            f.write(json.dumps(result, ensure_ascii=False) + '\n')
+
+    return filename
+    # Legacy per-model noise-robustness scoring, kept commented out for reference
+    '''tt = sum(1 for i in results if (noise_rate == 1 and i['label'][0] == -1) or (0 not in i['label'] and 1 in i['label']))
+    scores = {
+        'all_rate': (tt)/len(results),
+        'noise_rate': noise_rate,
+        'tt': tt,
+        'nums': len(results),
+    }
+    fact_tt = 0
+    correct_tt = 0
+    for i in results:
+        if i['factlabel'] == 1:
+            fact_tt += 1
+            if 0 not in i['label']:
+                correct_tt += 1
+    fact_check_rate = fact_tt/len(results)
+    if fact_tt > 0:
+        correct_rate = correct_tt/fact_tt
+    else:
+        correct_rate = 0
+    scores['fact_check_rate'] = fact_check_rate
+    scores['correct_rate'] = correct_rate
+    scores['fact_tt'] = fact_tt
+    scores['correct_tt'] = correct_tt
+
+    #logging.info(f"score: {scores}")
+    score_filename = os.path.join(result_path, f"scores_{config['output_file_extension']}.json")
+    with open(score_filename, 'w') as f:
+        json.dump(scores, f, ensure_ascii=False, indent=4)'''
+
+
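For context, this helper is what evaluate_factual_robustness.py now calls once per condition. A minimal usage sketch, assuming config is the already-loaded config dict and that update_config refreshes derived fields such as output_file_extension (an assumption; its implementation is not part of this diff):

    from scripts.helper import update_config
    from scripts.get_prediction_file import get_prediction_file

    config['noise_rate'] = 0.4                                  # taken from the current condition
    update_config(config)                                       # assumed to refresh derived config fields
    pred_file = get_prediction_file(config, correct_rate=0.0)   # writes prediction_<extension>.json, returns its path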
scripts/get_prediction_result.py
CHANGED
@@ -9,7 +9,7 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %…

 # Get prediction from LLM based on different dataset

-def get_prediction_result(config, data_file_name, prediction_file_name=''):
+def get_prediction_result(config, data_file_name, prediction_file_name='', correct_rate=0):
     results = []
     used_data = []
     dataset = load_dataset(data_file_name)
@@ -37,7 +37,7 @@ def get_prediction_result(config, data_file_name, prediction_file_name=''):
             continue

         logging.info(f"Executing Query {idx + 1} for Model: {modelname}")
-        query, ans, docs = process_data(instance, config['noise_rate'], config['passage_num'], data_file_name)
+        query, ans, docs = process_data(instance, config['noise_rate'], config['passage_num'], data_file_name, correct_rate)

         # Retry mechanism for prediction
         for attempt in range(1, config['retry_attempts'] + 1):
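Because correct_rate defaults to 0, existing call sites that do not pass it keep their old behaviour; only the counterfactual path forwards a non-default value down to process_data. A short sketch (variable names are placeholders):

    # existing callers: unchanged, correct_rate falls back to 0
    results = get_prediction_result(config, data_file_name)

    # counterfactual-robustness path: the value is threaded through to process_data
    results = get_prediction_result(config, config['factual_file_name'], prediction_file_name, correct_rate=1.0)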
scripts/get_scores.py
CHANGED
@@ -65,7 +65,7 @@ def load_negative_rejection_scores(config):
         return pd.DataFrame()

     if not os.path.exists(Negative_Rejection_DIR):
-        return pd.DataFrame(columns=["Model", "Rejection Rate"])
+        return pd.DataFrame(columns=["Model", "Rejection Rate %"])

     score_data = {}

@@ -80,44 +80,47 @@ def load_negative_rejection_scores(config):
             with open(filepath, "r") as f:
                 score = json.load(f)
                 reject_rate = score.get("reject_rate", "N/A")
-                score_data[model] = f"{reject_rate * 100…
+                score_data[model] = f"{reject_rate * 100}" if reject_rate != "N/A" else "N/A"
         else:
             score_data[model] = "N/A"

     # Convert to DataFrame
     df = pd.DataFrame([
-        {"Model": model, "Rejection Rate": score_data[model]}
+        {"Model": model, "Rejection Rate %": score_data[model]}
         for model in config["models"]
     ])

     return df

 def load_counterfactual_robustness_scores(config):
-
-    config['noise_rate'] = 0.4
+    """Load and format counterfactual robustness scores into a table with proper formatting."""
+    config['noise_rate'] = 0.4  # Hardcode noise rate
+
     if not os.path.exists(Counterfactual_Robustness_DIR):
-        return pd.DataFrame(columns=["Model", "Accuracy…
+        return pd.DataFrame(columns=["Model", "Accuracy (%)", "Acc_doc (%)", "Error Detection Rate (%)", "Correction Rate (%)"])

     score_data = {}

-    # Iterate over each model in config['models']
     for model in config["models"]:
-        # Expected filename pattern for each model
         expected_filename = f"scores_{model}_noise_{config['noise_rate']}_passage_{config['passage_num']}_num_queries_{config['num_queries']}.json"
         filepath = os.path.join(Counterfactual_Robustness_DIR, expected_filename)

-        # Check if file exists
         if os.path.exists(filepath):
             with open(filepath, "r") as f:
-                …
+                scores_json = json.load(f)  # Read the full JSON content
+                factual_score = next((s for s in scores_json["conditions"] if s["condition_label"] == "factual_only"), {})
+                counterfactual_score = next((s for s in scores_json["conditions"] if s["condition_label"] == "counterfactual"), {})
+
             score_data[model] = {
-                "Accuracy…
-                "…
-                "…
+                "Accuracy (%)": int(round(factual_score.get("all_rate", 0) * 100)) if factual_score else "N/A",
+                "Acc_doc (%)": int(round(counterfactual_score.get("all_rate", 0) * 100)) if counterfactual_score else "N/A",
+                "Error Detection Rate (%)": int(round(counterfactual_score.get("reject_rate", 0) * 100)) if counterfactual_score else "N/A",
+                "Correction Rate (%)": round(counterfactual_score.get("correct_rate", 0) * 100, 2) if counterfactual_score else "N/A"
             }
         else:
-            score_data[model] = {
-                "Accuracy…
+            score_data[model] = {
+                "Accuracy (%)": "N/A",
+                "Acc_doc (%)": "N/A",
                 "Error Detection Rate (%)": "N/A",
                 "Correction Rate (%)": "N/A"
             }
@@ -126,8 +129,9 @@ def load_counterfactual_robustness_scores(config):
     df = pd.DataFrame([
         {
             "Model": model,
-            "Accuracy…
-            "…
+            "Accuracy (%)": f"{score_data[model]['Accuracy (%)']}" if score_data[model]["Accuracy (%)"] != "N/A" else "N/A",
+            "Acc_doc (%)": f"{score_data[model]['Acc_doc (%)']}" if score_data[model]["Acc_doc (%)"] != "N/A" else "N/A",
+            "Error Detection Rate (%)": f"{score_data[model]['Error Detection Rate (%)']}" if score_data[model]["Error Detection Rate (%)"] != "N/A" else "N/A",
             "Correction Rate (%)": f"{score_data[model]['Correction Rate (%)']:.2f}" if score_data[model]["Correction Rate (%)"] != "N/A" else "N/A"
         }
         for model in config["models"]
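For reference, one row of the table this function now builds would look roughly like the following; the model name and numbers are made up, and the comments note which condition each column is taken from:

    import pandas as pd

    row = {
        "Model": "llama3-70b-8192",        # illustrative model name
        "Accuracy (%)": "85",              # all_rate of the factual_only condition * 100
        "Acc_doc (%)": "55",               # all_rate of the counterfactual condition * 100
        "Error Detection Rate (%)": "40",  # reject_rate of the counterfactual condition * 100
        "Correction Rate (%)": "50.00"     # correct_rate of the counterfactual condition * 100
    }
    df = pd.DataFrame([row])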
scripts/process_data.py
CHANGED
@@ -36,15 +36,39 @@ def process_data(instance, noise_rate, passage_num, filename, correct_rate=0):
     # Handling the '_fact' case in filename
     elif '_fact' in filename:
         correct_num = math.ceil(passage_num * correct_rate)
-        …
+        # Adjust correct_num to not exceed passage_num - neg_num, excluding positive_wrong
+        if correct_rate == 1.0:
+            # For factual-only with noise, use only positive and negative documents
+            correct_num = min(correct_num, passage_num - neg_num)
+            pos_num = 0  # No positive_wrong documents when correct_rate = 1.0
+        else:
+            # For other correct_rate values, calculate pos_num for positive_wrong
+            pos_num = passage_num - neg_num - correct_num
+            if pos_num < 0:
+                pos_num = 0  # Ensure pos_num is not negative
+
+        # Select positive documents (factual) first
+        indexs_positive = list(range(len(instance['positive'])))
+        selected_positive = random.sample(indexs_positive, min(len(indexs_positive), correct_num))
+        docs = [instance['positive'][i] for i in selected_positive]
+
+        # Add negative documents (noise) if needed
+        if neg_num > 0 and 'negative' in instance:
+            docs += instance['negative'][:min(neg_num, len(instance['negative']))]
+
+        # Only add positive_wrong documents if pos_num > 0 and correct_rate < 1.0
+        if pos_num > 0 and correct_rate < 1.0:
+            indexs_positive_wrong = list(range(len(instance['positive_wrong'])))
+            selected_positive_wrong = random.sample(indexs_positive_wrong, min(len(indexs_positive_wrong), pos_num))
+            docs += [instance['positive_wrong'][i] for i in selected_positive_wrong]
+
+        # Ensure docs length does not exceed passage_num
+        if len(docs) > passage_num:
+            random.shuffle(docs)
+            docs = docs[:passage_num]
+        elif len(docs) < passage_num and 'negative' in instance:
+            remaining = passage_num - len(docs)
+            docs += instance['negative'][:min(remaining, len(instance['negative']))]

     # Default case (when filename doesn't match '_int' or '_fact')
     else: