gourisankar85 committed on
Commit
d48cd84
·
verified ·
1 Parent(s): 5b260bd

Upload 6 files

Browse files
scripts/evaluate_factual_robustness.py CHANGED
@@ -14,7 +14,7 @@ def evaluate_factual_robustness(config):
14
  noise_rate = config['noise_rate']
15
  passage_num = config['passage_num']
16
 
17
- if config['model_name'] in config["models"]:
18
  model = GroqClient(plm=config['model_name'])
19
  else:
20
  logging.warning(f"Skipping unknown model: {config['model_name']}")
 
14
  noise_rate = config['noise_rate']
15
  passage_num = config['passage_num']
16
 
17
+ if config['model_name'] in config['models']:
18
  model = GroqClient(plm=config['model_name'])
19
  else:
20
  logging.warning(f"Skipping unknown model: {config['model_name']}")
scripts/evaluate_information_integration.py CHANGED
@@ -10,12 +10,13 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
10
 
11
  # Improved function to evaluate noise robustness
12
  def evaluate_information_integration(config):
13
- result_path = config["result_path"] + 'Information Integration/'
14
  noise_rate = config['noise_rate']
15
  passage_num = config['passage_num']
 
16
 
17
  # Iterate over each model specified in the config
18
- filename = os.path.join(result_path, f'prediction_{config['model_name']}_noise_{noise_rate}_passage_{passage_num}.json')
19
  ensure_directory_exists(filename)
20
 
21
  # Load existing results if file exists
@@ -28,7 +29,7 @@ def evaluate_information_integration(config):
28
  data = json.loads(line)
29
  useddata[data['id']] = data'''
30
 
31
- results = get_prediction_result(config, config["integration_file_name"]) # Store results for this model
32
 
33
  # Save results to a file
34
  with open(filename, 'w', encoding='utf-8') as f:
@@ -45,7 +46,7 @@ def evaluate_information_integration(config):
45
 
46
  # Save the final score file with tt and all_rate
47
  scores = {
48
- 'model': config['model_name'],
49
  'accuracy': accuracy,
50
  'noise_rate': noise_rate,
51
  'correct_count': correct_count,
@@ -56,7 +57,7 @@ def evaluate_information_integration(config):
56
  logging.info(f"Score: {scores}")
57
  logging.info(f"Information Integration Accuracy: {accuracy:.2%}")
58
 
59
- score_filename = os.path.join(result_path, f'scores_{config['model_name']}_noise_{noise_rate}_passage_{passage_num}.json')
60
  with open(score_filename, 'w') as f:
61
  json.dump(scores, f, ensure_ascii=False, indent=4)
62
 
 
10
 
11
  # Improved function to evaluate noise robustness
12
  def evaluate_information_integration(config):
13
+ result_path = config['result_path'] + 'Information Integration/'
14
  noise_rate = config['noise_rate']
15
  passage_num = config['passage_num']
16
+ model_name = config['model_name']
17
 
18
  # Iterate over each model specified in the config
19
+ filename = os.path.join(result_path, f'prediction_{model_name}_noise_{noise_rate}_passage_{passage_num}.json')
20
  ensure_directory_exists(filename)
21
 
22
  # Load existing results if file exists
 
29
  data = json.loads(line)
30
  useddata[data['id']] = data'''
31
 
32
+ results = get_prediction_result(config, config['integration_file_name']) # Store results for this model
33
 
34
  # Save results to a file
35
  with open(filename, 'w', encoding='utf-8') as f:
 
46
 
47
  # Save the final score file with tt and all_rate
48
  scores = {
49
+ 'model': model_name,
50
  'accuracy': accuracy,
51
  'noise_rate': noise_rate,
52
  'correct_count': correct_count,
 
57
  logging.info(f"Score: {scores}")
58
  logging.info(f"Information Integration Accuracy: {accuracy:.2%}")
59
 
60
+ score_filename = os.path.join(result_path, f'scores_{model_name}_noise_{noise_rate}_passage_{passage_num}.json')
61
  with open(score_filename, 'w') as f:
62
  json.dump(scores, f, ensure_ascii=False, indent=4)
63
 
scripts/evaluate_negative_rejection.py CHANGED
@@ -15,7 +15,7 @@ def evaluate_negative_rejection(config):
15
  noise_rate = config['noise_rate']
16
  passage_num = config['passage_num']
17
 
18
- if config['model_name'] in config["models"]:
19
  model = GroqClient(plm=config['model_name'])
20
  else:
21
  logging.warning(f"Skipping unknown model: {config['model_name']}")
 
15
  noise_rate = config['noise_rate']
16
  passage_num = config['passage_num']
17
 
18
+ if config['model_name'] in config['models']:
19
  model = GroqClient(plm=config['model_name'])
20
  else:
21
  logging.warning(f"Skipping unknown model: {config['model_name']}")
scripts/evaluate_noise_robustness.py CHANGED
@@ -10,7 +10,7 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
10
 
11
  # Improved function to evaluate noise robustness
12
  def evaluate_noise_robustness(config):
13
- result_path = config["result_path"] + 'Noise Robustness/'
14
  noise_rate = config['noise_rate']
15
  passage_num = config['passage_num']
16
 
@@ -28,7 +28,7 @@ def evaluate_noise_robustness(config):
28
  data = json.loads(line)
29
  useddata[data['id']] = data'''
30
 
31
- results = get_prediction_result(config, config["robustness_file_name"]) # Store results for this model
32
 
33
  # Save results to a file
34
  with open(filename, 'w', encoding='utf-8') as f:
 
10
 
11
  # Improved function to evaluate noise robustness
12
  def evaluate_noise_robustness(config):
13
+ result_path = config['result_path'] + 'Noise Robustness/'
14
  noise_rate = config['noise_rate']
15
  passage_num = config['passage_num']
16
 
 
28
  data = json.loads(line)
29
  useddata[data['id']] = data'''
30
 
31
+ results = get_prediction_result(config, config['robustness_file_name']) # Store results for this model
32
 
33
  # Save results to a file
34
  with open(filename, 'w', encoding='utf-8') as f:
scripts/get_factual_evaluation.py CHANGED
@@ -10,7 +10,7 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
10
 
11
  # Improved function to evaluate noise robustness
12
  def get_factual_evaluation(config):
13
- result_path = config["result_path"] + 'Counterfactual Robustness/'
14
  noise_rate = config['noise_rate']
15
  passage_num = config['passage_num']
16
 
@@ -28,7 +28,7 @@ def get_factual_evaluation(config):
28
  data = json.loads(line)
29
  useddata[data['id']] = data'''
30
 
31
- results = get_prediction_result(config, config["factual_file_name"]) # Store results for this model
32
 
33
  # Save results to a file
34
  with open(filename, 'w', encoding='utf-8') as f:
 
10
 
11
  # Improved function to evaluate noise robustness
12
  def get_factual_evaluation(config):
13
+ result_path = config['result_path'] + 'Counterfactual Robustness/'
14
  noise_rate = config['noise_rate']
15
  passage_num = config['passage_num']
16
 
 
28
  data = json.loads(line)
29
  useddata[data['id']] = data'''
30
 
31
+ results = get_prediction_result(config, config['factual_file_name']) # Store results for this model
32
 
33
  # Save results to a file
34
  with open(filename, 'w', encoding='utf-8') as f:
scripts/get_prediction_result.py CHANGED
@@ -13,7 +13,7 @@ def get_prediction_result(config, data_file_name):
13
  results = []
14
  dataset = load_dataset(data_file_name)
15
  # Create GroqClient instance for supported models
16
- if config['model_name'] in config["models"]:
17
  model = GroqClient(plm=config['model_name'])
18
  else:
19
  logging.warning(f"Skipping unknown model: {config['model_name']}")
@@ -26,7 +26,7 @@ def get_prediction_result(config, data_file_name):
26
  query, ans, docs = process_data(instance, config['noise_rate'], config['passage_num'], data_file_name)
27
 
28
  # Retry mechanism for prediction
29
- for attempt in range(1, config["retry_attempts"] + 1):
30
  label, prediction, factlabel = predict(query, ans, docs, model, "Document:\n{DOCS} \n\nQuestion:\n{QUERY}", 0.7)
31
  if prediction: # If response is not empty, break retry loop
32
  break
 
13
  results = []
14
  dataset = load_dataset(data_file_name)
15
  # Create GroqClient instance for supported models
16
+ if config['model_name'] in config['models']:
17
  model = GroqClient(plm=config['model_name'])
18
  else:
19
  logging.warning(f"Skipping unknown model: {config['model_name']}")
 
26
  query, ans, docs = process_data(instance, config['noise_rate'], config['passage_num'], data_file_name)
27
 
28
  # Retry mechanism for prediction
29
+ for attempt in range(1, config['retry_attempts'] + 1):
30
  label, prediction, factlabel = predict(query, ans, docs, model, "Document:\n{DOCS} \n\nQuestion:\n{QUERY}", 0.7)
31
  if prediction: # If response is not empty, break retry loop
32
  break