Upload 6 files
Browse files
scripts/evaluate_factual_robustness.py
CHANGED
@@ -14,7 +14,7 @@ def evaluate_factual_robustness(config):
|
|
14 |
noise_rate = config['noise_rate']
|
15 |
passage_num = config['passage_num']
|
16 |
|
17 |
-
if config['model_name'] in config[
|
18 |
model = GroqClient(plm=config['model_name'])
|
19 |
else:
|
20 |
logging.warning(f"Skipping unknown model: {config['model_name']}")
|
|
|
14 |
noise_rate = config['noise_rate']
|
15 |
passage_num = config['passage_num']
|
16 |
|
17 |
+
if config['model_name'] in config['models']:
|
18 |
model = GroqClient(plm=config['model_name'])
|
19 |
else:
|
20 |
logging.warning(f"Skipping unknown model: {config['model_name']}")
|
scripts/evaluate_information_integration.py
CHANGED
@@ -10,12 +10,13 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
|
|
10 |
|
11 |
# Improved function to evaluate information integration
|
12 |
def evaluate_information_integration(config):
|
13 |
-
result_path = config[
|
14 |
noise_rate = config['noise_rate']
|
15 |
passage_num = config['passage_num']
|
|
|
16 |
|
17 |
# Build the path of the prediction output file
|
18 |
-
filename = os.path.join(result_path, f'prediction_{
|
19 |
ensure_directory_exists(filename)
|
20 |
|
21 |
# Load existing results if file exists
|
@@ -28,7 +29,7 @@ def evaluate_information_integration(config):
|
|
28 |
data = json.loads(line)
|
29 |
useddata[data['id']] = data'''
|
30 |
|
31 |
-
results = get_prediction_result(config, config[
|
32 |
|
33 |
# Save results to a file
|
34 |
with open(filename, 'w', encoding='utf-8') as f:
|
@@ -45,7 +46,7 @@ def evaluate_information_integration(config):
|
|
45 |
|
46 |
# Save the final score file with tt and all_rate
|
47 |
scores = {
|
48 |
-
'model':
|
49 |
'accuracy': accuracy,
|
50 |
'noise_rate': noise_rate,
|
51 |
'correct_count': correct_count,
|
@@ -56,7 +57,7 @@ def evaluate_information_integration(config):
|
|
56 |
logging.info(f"Score: {scores}")
|
57 |
logging.info(f"Information Integration Accuracy: {accuracy:.2%}")
|
58 |
|
59 |
-
score_filename = os.path.join(result_path, f'scores_{
|
60 |
with open(score_filename, 'w') as f:
|
61 |
json.dump(scores, f, ensure_ascii=False, indent=4)
|
62 |
|
|
|
10 |
|
11 |
# Improved function to evaluate information integration
|
12 |
def evaluate_information_integration(config):
|
13 |
+
result_path = config['result_path'] + 'Information Integration/'
|
14 |
noise_rate = config['noise_rate']
|
15 |
passage_num = config['passage_num']
|
16 |
+
model_name = config['model_name']
|
17 |
|
18 |
# Build the path of the prediction output file
|
19 |
+
filename = os.path.join(result_path, f'prediction_{model_name}_noise_{noise_rate}_passage_{passage_num}.json')
|
20 |
ensure_directory_exists(filename)
|
21 |
|
22 |
# Load existing results if file exists
|
|
|
29 |
data = json.loads(line)
|
30 |
useddata[data['id']] = data'''
|
31 |
|
32 |
+
results = get_prediction_result(config, config['integration_file_name']) # Store results for this model
|
33 |
|
34 |
# Save results to a file
|
35 |
with open(filename, 'w', encoding='utf-8') as f:
|
|
|
46 |
|
47 |
# Save the final score file with tt and all_rate
|
48 |
scores = {
|
49 |
+
'model': model_name,
|
50 |
'accuracy': accuracy,
|
51 |
'noise_rate': noise_rate,
|
52 |
'correct_count': correct_count,
|
|
|
57 |
logging.info(f"Score: {scores}")
|
58 |
logging.info(f"Information Integration Accuracy: {accuracy:.2%}")
|
59 |
|
60 |
+
score_filename = os.path.join(result_path, f'scores_{model_name}_noise_{noise_rate}_passage_{passage_num}.json')
|
61 |
with open(score_filename, 'w') as f:
|
62 |
json.dump(scores, f, ensure_ascii=False, indent=4)
|
63 |
|
scripts/evaluate_negative_rejection.py
CHANGED
@@ -15,7 +15,7 @@ def evaluate_negative_rejection(config):
|
|
15 |
noise_rate = config['noise_rate']
|
16 |
passage_num = config['passage_num']
|
17 |
|
18 |
-
if config['model_name'] in config[
|
19 |
model = GroqClient(plm=config['model_name'])
|
20 |
else:
|
21 |
logging.warning(f"Skipping unknown model: {config['model_name']}")
|
|
|
15 |
noise_rate = config['noise_rate']
|
16 |
passage_num = config['passage_num']
|
17 |
|
18 |
+
if config['model_name'] in config['models']:
|
19 |
model = GroqClient(plm=config['model_name'])
|
20 |
else:
|
21 |
logging.warning(f"Skipping unknown model: {config['model_name']}")
|
scripts/evaluate_noise_robustness.py
CHANGED
@@ -10,7 +10,7 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
|
|
10 |
|
11 |
# Improved function to evaluate noise robustness
|
12 |
def evaluate_noise_robustness(config):
|
13 |
-
result_path = config[
|
14 |
noise_rate = config['noise_rate']
|
15 |
passage_num = config['passage_num']
|
16 |
|
@@ -28,7 +28,7 @@ def evaluate_noise_robustness(config):
|
|
28 |
data = json.loads(line)
|
29 |
useddata[data['id']] = data'''
|
30 |
|
31 |
-
results = get_prediction_result(config, config[
|
32 |
|
33 |
# Save results to a file
|
34 |
with open(filename, 'w', encoding='utf-8') as f:
|
|
|
10 |
|
11 |
# Improved function to evaluate noise robustness
|
12 |
def evaluate_noise_robustness(config):
|
13 |
+
result_path = config['result_path'] + 'Noise Robustness/'
|
14 |
noise_rate = config['noise_rate']
|
15 |
passage_num = config['passage_num']
|
16 |
|
|
|
28 |
data = json.loads(line)
|
29 |
useddata[data['id']] = data'''
|
30 |
|
31 |
+
results = get_prediction_result(config, config['robustness_file_name']) # Store results for this model
|
32 |
|
33 |
# Save results to a file
|
34 |
with open(filename, 'w', encoding='utf-8') as f:
|
scripts/get_factual_evaluation.py
CHANGED
@@ -10,7 +10,7 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
|
|
10 |
|
11 |
# Improved function to evaluate counterfactual robustness
|
12 |
def get_factual_evaluation(config):
|
13 |
-
result_path = config[
|
14 |
noise_rate = config['noise_rate']
|
15 |
passage_num = config['passage_num']
|
16 |
|
@@ -28,7 +28,7 @@ def get_factual_evaluation(config):
|
|
28 |
data = json.loads(line)
|
29 |
useddata[data['id']] = data'''
|
30 |
|
31 |
-
results = get_prediction_result(config, config[
|
32 |
|
33 |
# Save results to a file
|
34 |
with open(filename, 'w', encoding='utf-8') as f:
|
|
|
10 |
|
11 |
# Improved function to evaluate counterfactual robustness
|
12 |
def get_factual_evaluation(config):
|
13 |
+
result_path = config['result_path'] + 'Counterfactual Robustness/'
|
14 |
noise_rate = config['noise_rate']
|
15 |
passage_num = config['passage_num']
|
16 |
|
|
|
28 |
data = json.loads(line)
|
29 |
useddata[data['id']] = data'''
|
30 |
|
31 |
+
results = get_prediction_result(config, config['factual_file_name']) # Store results for this model
|
32 |
|
33 |
# Save results to a file
|
34 |
with open(filename, 'w', encoding='utf-8') as f:
|
scripts/get_prediction_result.py
CHANGED
@@ -13,7 +13,7 @@ def get_prediction_result(config, data_file_name):
|
|
13 |
results = []
|
14 |
dataset = load_dataset(data_file_name)
|
15 |
# Create GroqClient instance for supported models
|
16 |
-
if config['model_name'] in config[
|
17 |
model = GroqClient(plm=config['model_name'])
|
18 |
else:
|
19 |
logging.warning(f"Skipping unknown model: {config['model_name']}")
|
@@ -26,7 +26,7 @@ def get_prediction_result(config, data_file_name):
|
|
26 |
query, ans, docs = process_data(instance, config['noise_rate'], config['passage_num'], data_file_name)
|
27 |
|
28 |
# Retry mechanism for prediction
|
29 |
-
for attempt in range(1, config[
|
30 |
label, prediction, factlabel = predict(query, ans, docs, model, "Document:\n{DOCS} \n\nQuestion:\n{QUERY}", 0.7)
|
31 |
if prediction: # If response is not empty, break retry loop
|
32 |
break
|
|
|
13 |
results = []
|
14 |
dataset = load_dataset(data_file_name)
|
15 |
# Create GroqClient instance for supported models
|
16 |
+
if config['model_name'] in config['models']:
|
17 |
model = GroqClient(plm=config['model_name'])
|
18 |
else:
|
19 |
logging.warning(f"Skipping unknown model: {config['model_name']}")
|
|
|
26 |
query, ans, docs = process_data(instance, config['noise_rate'], config['passage_num'], data_file_name)
|
27 |
|
28 |
# Retry mechanism for prediction
|
29 |
+
for attempt in range(1, config['retry_attempts'] + 1):
|
30 |
label, prediction, factlabel = predict(query, ans, docs, model, "Document:\n{DOCS} \n\nQuestion:\n{QUERY}", 0.7)
|
31 |
if prediction: # If response is not empty, break retry loop
|
32 |
break
|