import os
import time
import json
import math
import copy
import collections
from typing import Optional, List, Dict, Tuple, Callable, Any, Union, NewType

import numpy as np
from tqdm import tqdm

import datasets

from transformers import AutoTokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from transformers.utils import logging
from transformers.trainer_utils import EvalPrediction, EvalLoopOutput

from .args import (
    HfArgumentParser,
    RetroArguments,
    TrainingArguments,
)
from .base import BaseReader
from . import constants as C
from .preprocess import (
    get_sketch_features,
    get_intensive_features
)
from .metrics import (
    compute_classification_metric,
    compute_squad_v2
)

DataClassType = NewType("DataClassType", Any)
logger = logging.get_logger(__name__)


class SketchReader(BaseReader):
    """Reader whose post-processing produces the External Front Verification (E-FV) score."""

    name: str = "sketch"

    def postprocess(
        self,
        output: Union[np.ndarray, EvalLoopOutput],
        eval_examples: datasets.Dataset,
        eval_dataset: datasets.Dataset,
        mode: str = "evaluate",
    ) -> Union[EvalPrediction, Dict[str, float]]:
        """
        Postprocess the output of the SketchReader model.

        The per-feature logits are averaged over all features belonging to the
        same example, the E-FV score (column 0 minus column 1) is computed per
        example, and the resulting ``{example id: score}`` map is written as
        JSON to ``<self.args.output_dir>/<C.SCORE_EXT_FILE_NAME>``.

        Args:
            output (Union[np.ndarray, EvalLoopOutput]): The model output, either
                the raw logits or an evaluation-loop output carrying them in
                ``predictions``.
            eval_examples (datasets.Dataset): The evaluation examples.
            eval_dataset (datasets.Dataset): The evaluation features; each one
                carries the ``example_id`` added by ``get_sketch_features``.
            mode (str, optional): The mode of operation. Defaults to "evaluate".

        Returns:
            Union[EvalPrediction, Dict[str, float]]: In "evaluate" mode an
            ``EvalPrediction`` with the raw logits and the labels of ``output``
            (``None`` when ``output`` has no ``label_ids``); otherwise the
            ``{example id: score}`` map.
        """
        # External Front Verification (E-FV)
        logits = output.predictions if isinstance(output, EvalLoopOutput) else output

        # Map every example ID to its position, then group feature indices by example.
        example_ids = eval_examples[C.ID_COLUMN_NAME]
        example_id_to_index = {k: i for i, k in enumerate(example_ids)}
        features_per_example = collections.defaultdict(list)
        for i, feature in enumerate(eval_dataset):
            features_per_example[example_id_to_index[feature["example_id"]]].append(i)

        # Average the logits over the features (strides) of each example.
        # NOTE(review): assumes `logits` holds one row per feature of
        # `eval_dataset`, in the same order -- confirm against the predict loop.
        # Examples without any feature keep a score of 0.
        n_examples = len(example_ids)
        logits_ans = np.zeros(n_examples)
        logits_na = np.zeros(n_examples)
        for example_index, feature_indices in tqdm(features_per_example.items()):
            n_strides = len(feature_indices)
            for feature_index in feature_indices:
                logits_ans[example_index] += logits[feature_index, 0] / n_strides
                logits_na[example_index] += logits[feature_index, 1] / n_strides

        # Calculate the E-FV score
        score_ext = logits_ans - logits_na

        # Save the E-FV score
        final_map = dict(zip(example_ids, score_ext.tolist()))
        with open(os.path.join(self.args.output_dir, C.SCORE_EXT_FILE_NAME), "w") as writer:
            writer.write(json.dumps(final_map, indent=4) + "\n")

        if mode == "evaluate":
            # A bare ndarray has no labels attached; fall back to None.
            return EvalPrediction(
                predictions=logits,
                label_ids=getattr(output, "label_ids", None),
            )
        return final_map


class IntensiveReader(BaseReader):
    """Reader that extracts answer spans and the Internal Front Verification (I-FV) score."""

    name: str = "intensive"

    def postprocess(
        self,
        output: EvalLoopOutput,
        eval_examples: datasets.Dataset,
        eval_dataset: datasets.Dataset,
        log_level: int = logging.WARNING,
        mode: str = "evaluate",
    ) -> Union[
        List[Dict[str, Any]],
        EvalPrediction,
        Tuple[Dict[str, List[Dict[str, Any]]], Optional[Dict[str, float]]],
    ]:
        """
        Post-processing step for the internal front verification (I-FV) and formatting the results.

        Args:
            output (EvalLoopOutput): The output of the model's evaluation loop.
            eval_examples (datasets.Dataset): The evaluation examples.
            eval_dataset (datasets.Dataset): The evaluation dataset.
            log_level (int, optional): The logging level. Defaults to logging.WARNING.
            mode (str, optional): The mode of the post-processing. Defaults to "evaluate".

        Returns:
            Depends on ``mode``:
              * "retro_inference": the tuple ``(nbest_json, scores_diff_json)``.
              * "predict": the list of formatted predictions.
              * anything else: an ``EvalPrediction`` holding the formatted
                predictions and the reference answers.
        """
        # Compute predictions
        predictions, nbest_json, scores_diff_json = self.compute_predictions(
            eval_examples,
            eval_dataset,
            output.predictions,
            version_2_with_negative=self.data_args.version_2_with_negative,
            n_best_size=self.data_args.n_best_size,
            max_answer_length=self.data_args.max_answer_length,
            null_score_diff_threshold=self.data_args.null_score_diff_threshold,
            output_dir=self.args.output_dir,
            log_level=log_level,
            n_tops=(self.data_args.start_n_top, self.data_args.end_n_top),
        )

        # Return the nbest_json and scores_diff_json if in retro_inference mode
        if mode == "retro_inference":
            return nbest_json, scores_diff_json

        # Format the predictions
        if self.data_args.version_2_with_negative:
            formatted_predictions = [
                {
                    "id": k,
                    "prediction_text": v,
                    "no_answer_probability": scores_diff_json[k],
                }
                for k, v in predictions.items()
            ]
        else:
            formatted_predictions = [
                {"id": k, "prediction_text": v} for k, v in predictions.items()
            ]

        # Return the formatted predictions if in predict mode
        if mode == "predict":
            return formatted_predictions

        # Format the evaluation predictions
        references = [
            {"id": ex[C.ID_COLUMN_NAME], "answers": ex[C.ANSWER_COLUMN_NAME]}
            for ex in eval_examples
        ]
        return EvalPrediction(
            predictions=formatted_predictions, label_ids=references
        )

    def compute_predictions(
        self,
        examples: datasets.Dataset,
        features: datasets.Dataset,
        predictions: Tuple[np.ndarray, np.ndarray],
        version_2_with_negative: bool = False,
        n_best_size: int = 20,
        max_answer_length: int = 30,
        null_score_diff_threshold: float = 0.0,
        output_dir: Optional[str] = None,
        log_level: Optional[int] = logging.WARNING,
        n_tops: Tuple[int, int] = (-1, -1),
        use_choice_logits: bool = False,
    ):
        """
        Compute predictions for a given set of examples based on the provided features and model predictions.

        Args:
            examples (datasets.Dataset): The dataset containing the examples.
            features (datasets.Dataset): The dataset containing the features.
            predictions (Tuple[np.ndarray, ...]): A tuple containing the start logits, the end logits and,
                optionally, the choice logits.
            version_2_with_negative (bool, optional): Whether null (no-answer) predictions are possible.
                Defaults to False.
            n_best_size (int, optional): The number of top predictions to consider. Defaults to 20.
            max_answer_length (int, optional): The maximum length of the answer, in tokens. Defaults to 30.
            null_score_diff_threshold (float, optional): The score difference threshold above which the
                null prediction is selected. Defaults to 0.0.
            output_dir (Optional[str], optional): The directory to save the predictions. Defaults to None.
            log_level (Optional[int], optional): The log level. Defaults to logging.WARNING.
            n_tops (Tuple[int, int], optional): Accepted for interface compatibility; not used by this
                implementation. Defaults to (-1, -1).
            use_choice_logits (bool, optional): Whether to take the null score from the choice logits
                (index 1) instead of the start/end logits at position 0. Defaults to False.

        Returns:
            Tuple[Dict[str, str], Dict[str, List[Dict[str, Union[str, float]]]], Optional[Dict[str, float]]]:
            A tuple containing all predictions, all n-best predictions, and the score differences
            (``None`` when ``version_2_with_negative`` is False).

        Raises:
            ValueError: If the length of predictions is not 2 or 3.
            EnvironmentError: If ``output_dir`` is given but is not a directory.
        """
        if len(predictions) not in [2, 3]:
            raise ValueError(
                "`predictions` should be a tuple with two elements (start_logits, end_logits) or three elements (start_logits, end_logits, choice_logits)."
            )

        all_start_logits, all_end_logits = predictions[:2]
        all_choice_logits = predictions[-1] if len(predictions) == 3 else None

        # Build a map example to its corresponding features.
        example_id_to_index = {k: i for i, k in enumerate(examples[C.ID_COLUMN_NAME])}
        features_per_example = collections.defaultdict(list)
        for i, feature in enumerate(features):
            features_per_example[example_id_to_index[feature["example_id"]]].append(i)

        all_predictions = collections.OrderedDict()
        all_nbest_json = collections.OrderedDict()
        scores_diff_json = collections.OrderedDict() if version_2_with_negative else None

        # Logging.
        logger.setLevel(log_level)
        logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

        # Looping through all the examples
        for example_index, example in enumerate(tqdm(examples)):
            example_id = example[C.ID_COLUMN_NAME]
            # Those are the indices of the features associated to the current example.
            feature_indices = features_per_example[example_index]

            min_null_prediction = None
            prelim_predictions = []

            # Looping through all the features associated to the current example.
            for feature_index in feature_indices:
                # We grab the predictions of the model for this feature.
                start_logits = all_start_logits[feature_index]
                end_logits = all_end_logits[feature_index]
                feature_null_score = start_logits[0] + end_logits[0]
                if all_choice_logits is not None and use_choice_logits:
                    feature_null_score = all_choice_logits[feature_index][1]

                # Fetch the feature row once; it provides both the offsets and
                # the optional max-context flags.
                feature = features[feature_index]
                # This is what will allow us to map some the positions
                # in our logits to span of texts in the original context.
                offset_mapping = feature["offset_mapping"]
                # Optional `token_is_max_context`,
                # if provided we will remove answers that do not have the maximum context
                # available in the current feature.
                token_is_max_context = feature.get("token_is_max_context", None)

                # Update minimum null prediction
                if (
                    min_null_prediction is None
                    or min_null_prediction["score"] > feature_null_score
                ):
                    min_null_prediction = {
                        "offsets": (0, 0),
                        "score": feature_null_score,
                        "start_logit": start_logits[0],
                        "end_logit": end_logits[0],
                    }

                # Go through all possibilities for the {top k} greater start and end logits
                start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
                end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
                for start_index in start_indexes:
                    for end_index in end_indexes:
                        # We could hypothetically create invalid predictions, e.g., predict
                        # that the start of the span is in the question. We throw out all
                        # invalid predictions.
                        if (
                            start_index >= len(offset_mapping)
                            or end_index >= len(offset_mapping)
                            or offset_mapping[start_index] is None
                            or offset_mapping[end_index] is None
                        ):
                            continue
                        # Don't consider answers with a length that is either < 0 or > max_answer_length.
                        if (
                            end_index < start_index
                            or end_index - start_index + 1 > max_answer_length
                        ):
                            continue
                        # Don't consider answer that don't have the maximum context available
                        if (
                            token_is_max_context is not None
                            and not token_is_max_context.get(str(start_index), False)
                        ):
                            continue
                        prelim_predictions.append(
                            {
                                "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
                                "score": start_logits[start_index] + end_logits[end_index],
                                "start_logit": start_logits[start_index],
                                "end_logit": end_logits[end_index],
                            }
                        )

            if version_2_with_negative:
                # Add the minimum null prediction
                prelim_predictions.append(min_null_prediction)
                null_score = min_null_prediction["score"]

            # Only keep the best `n_best_size` predictions.
            nbest = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]

            # Add back the minimum null prediction if it was removed because of its low score
            if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in nbest):
                nbest.append(min_null_prediction)

            # Use the offsets to gather the answer text in the original context
            context = example["context"]
            for pred in nbest:
                offsets = pred.pop("offsets")
                pred["text"] = context[offsets[0] : offsets[1]]

            # In the very rare edge case we have not a single non-null prediction,
            # we create a fake prediction to avoid failure.
            if len(nbest) == 0 or (len(nbest) == 1 and nbest[0]["text"] == ""):
                nbest.insert(0, {"text": "", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0})

            # Compute the softmax of all scores
            # (we do it with numpy to stay independent from torch/tf in this file,
            # using the LogSum trick).
            scores = np.array([pred.pop("score") for pred in nbest])
            exp_scores = np.exp(scores - np.max(scores))
            probs = exp_scores / exp_scores.sum()

            # Include the probabilities in our predictions.
            for prob, pred in zip(probs, nbest):
                pred["probability"] = prob

            # Pick the best prediction. If the null answer is not possible, this is easy.
            if not version_2_with_negative:
                all_predictions[example_id] = nbest[0]["text"]
            else:
                # Otherwise we first need to find the best non-empty prediction;
                # when every candidate is empty, fall back to the top one.
                best_non_null_pred = next(
                    (pred for pred in nbest if pred["text"] != ""), nbest[0]
                )

                # Then we compare to the null prediction using the threshold.
                score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
                scores_diff_json[example_id] = float(score_diff)  # To be JSON-serializable.
                if score_diff > null_score_diff_threshold:
                    all_predictions[example_id] = ""
                else:
                    all_predictions[example_id] = best_non_null_pred["text"]

            # Make the n-best list JSON-serializable by casting numpy floats back to float.
            all_nbest_json[example_id] = [
                {k: (float(v) if isinstance(v, np.floating) else v) for k, v in pred.items()}
                for pred in nbest
            ]

        # If we have an output_dir, let's save all those dicts.
        if output_dir is not None:
            if not os.path.isdir(output_dir):
                raise EnvironmentError(f"{output_dir} is not a directory.")

            prediction_file = os.path.join(output_dir, C.INTENSIVE_PRED_FILE_NAME)
            nbest_file = os.path.join(output_dir, C.NBEST_PRED_FILE_NAME)
            if version_2_with_negative:
                null_odds_file = os.path.join(output_dir, C.SCORE_DIFF_FILE_NAME)

            logger.info(f"Saving predictions to {prediction_file}.")
            with open(prediction_file, "w") as writer:
                writer.write(json.dumps(all_predictions, indent=4) + "\n")
            logger.info(f"Saving nbest_preds to {nbest_file}.")
            with open(nbest_file, "w") as writer:
                writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
            if version_2_with_negative:
                logger.info(f"Saving null_odds to {null_odds_file}.")
                with open(null_odds_file, "w") as writer:
                    writer.write(json.dumps(scores_diff_json, indent=4) + "\n")

        return all_predictions, all_nbest_json, scores_diff_json


class RearVerifier:
    """Combine the external and internal verification scores into a final answer decision."""

    def __init__(
        self,
        beta1: int = 1,
        beta2: int = 1,
        best_cof: int = 1,
        thresh: float = 0.0,
    ):
        """
        Args:
            beta1 (int): Weight applied to the external score (``score_ext``).
            beta2 (int): Weight applied to the score difference (``score_diff``).
            best_cof (int): Coefficient applied to every n-best probability.
            thresh (float): Combined-score threshold above which the answer is empty.
        """
        self.beta1 = beta1
        self.beta2 = beta2
        self.best_cof = best_cof
        self.thresh = thresh

    def __call__(
        self,
        score_ext: Dict[str, float],
        score_diff: Dict[str, float],
        nbest_preds: Dict[str, List[Dict[str, Any]]],
    ):
        """
        Combine the verification scores and pick the final answer text per question.

        Args:
            score_ext (Dict[str, float]): External verification score per question id.
            score_diff (Dict[str, float]): Score difference per question id. Must
                have exactly the same keys as ``score_ext``.
            nbest_preds (Dict[str, List[Dict[str, Any]]]): Per question id, the
                n-best entries; each entry provides "text" and "probability".

        Returns:
            Tuple[Dict[str, str], Dict[str, float]]: A tuple containing the output predictions and scores.
        """
        # Check if the keys of score_ext and score_diff are equal
        assert score_ext.keys() == score_diff.keys(), (
            "`score_ext` and `score_diff` must cover the same question ids."
        )

        # The combined score is the mean of the two weighted scores.
        output_scores = {
            key: (self.beta1 * score_ext[key] + self.beta2 * score_diff[key]) / 2.0
            for key in score_ext
        }

        output_predictions = {}
        for key, entries in nbest_preds.items():
            # Accumulate the (scaled) probability mass of identical answer texts.
            entry_map = collections.defaultdict(float)
            for entry in entries:
                entry_map[entry["text"]] += self.best_cof * entry["probability"]
            # Highest accumulated probability wins; ties keep the first text seen.
            # An empty n-best list yields the empty answer.
            output_predictions[key] = max(entry_map, key=entry_map.get, default="")

        # If the score for a question is above the threshold, set the prediction to empty string
        for qid in output_predictions:
            if output_scores[qid] > self.thresh:
                output_predictions[qid] = ""

        return output_predictions, output_scores


class RetroReader:
    """Pipeline tying together the sketch reader, the intensive reader and the rear verifier."""

    def __init__(
        self,
        args,
        sketch_reader: SketchReader,
        intensive_reader: IntensiveReader,
        rear_verifier: RearVerifier,
        prep_fn: Tuple[Callable, Callable],
    ):
        """
        Args:
            args: The retro arguments; read by the ``null_score_diff_threshold``
                and ``n_best_size`` properties.
            sketch_reader (SketchReader): The sketch reading module.
            intensive_reader (IntensiveReader): The intensive reading module.
            rear_verifier (RearVerifier): The rear verification module.
            prep_fn (Tuple[Callable, Callable]): The (sketch, intensive)
                preprocessing functions used at inference time.
        """
        self.args = args
        # Set submodules
        self.sketch_reader = sketch_reader
        self.intensive_reader = intensive_reader
        self.rear_verifier = rear_verifier
        # Set prep function for inference
        self.sketch_prep_fn, self.intensive_prep_fn = prep_fn

    @classmethod
    def load(
        cls,
        train_examples=None,
        sketch_train_dataset=None,
        intensive_train_dataset=None,
        eval_examples=None,
        sketch_eval_dataset=None,
        intensive_eval_dataset=None,
        config_file: str = C.DEFAULT_CONFIG_FILE,
        device: str = "cpu",
    ):
        """
        Build a ``RetroReader`` from a YAML configuration file.

        Tokenizers and models for both readers are loaded, the given examples
        are preprocessed into features when the corresponding dataset is not
        supplied, and the rear verifier is configured from the parsed arguments.

        Args:
            train_examples: Raw training examples; preprocessed when a train
                dataset is not given for a reader.
            sketch_train_dataset: Already-preprocessed sketch training features.
            intensive_train_dataset: Already-preprocessed intensive training features.
            eval_examples: Raw evaluation examples; preprocessed when an eval
                dataset is not given for a reader.
            sketch_eval_dataset: Already-preprocessed sketch evaluation features.
            intensive_eval_dataset: Already-preprocessed intensive evaluation features.
            config_file (str): Path of the YAML file parsed into
                ``RetroArguments`` and ``TrainingArguments``.
            device (str): Device both models are moved to. Defaults to "cpu".

        Returns:
            RetroReader: The assembled reader.
        """
        # Get arguments from yaml files
        parser = HfArgumentParser([RetroArguments, TrainingArguments])
        retro_args, training_args = parser.parse_yaml_file(yaml_file=config_file)

        # `run_name` / `metric_for_best_model` may hold "sketch,intensive" pairs.
        if training_args.run_name is not None and "," in training_args.run_name:
            sketch_run_name, intensive_run_name = training_args.run_name.split(",")
        else:
            sketch_run_name, intensive_run_name = None, None
        if training_args.metric_for_best_model is not None and "," in training_args.metric_for_best_model:
            sketch_best_metric, intensive_best_metric = training_args.metric_for_best_model.split(",")
        else:
            sketch_best_metric, intensive_best_metric = None, None

        # The sketch reader gets its own copy so that the per-reader changes
        # below (run name, output dir, metric) do not leak into each other.
        sketch_training_args = copy.deepcopy(training_args)
        intensive_training_args = training_args

        print(f"Loading sketch tokenizer from {retro_args.sketch_tokenizer_name} ...")
        sketch_tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=retro_args.sketch_tokenizer_name,
            use_auth_token=retro_args.use_auth_token,
            revision=retro_args.sketch_revision,
        )

        # If `train_examples` is provided, perform preprocessing
        if train_examples is not None and sketch_train_dataset is None:
            print("[Sketch] Preprocessing train examples ...")
            sketch_prep_fn, is_batched = get_sketch_features(sketch_tokenizer, "train", retro_args)
            sketch_train_dataset = train_examples.map(
                sketch_prep_fn,
                batched=is_batched,
                remove_columns=train_examples.column_names,
                num_proc=retro_args.preprocessing_num_workers,
                load_from_cache_file=not retro_args.overwrite_cache,
            )

        # If `eval_examples` is provided, perform preprocessing
        if eval_examples is not None and sketch_eval_dataset is None:
            print("[Sketch] Preprocessing eval examples ...")
            sketch_prep_fn, is_batched = get_sketch_features(sketch_tokenizer, "eval", retro_args)
            sketch_eval_dataset = eval_examples.map(
                sketch_prep_fn,
                batched=is_batched,
                remove_columns=eval_examples.column_names,
                num_proc=retro_args.preprocessing_num_workers,
                load_from_cache_file=not retro_args.overwrite_cache,
            )

        # Get preprocessing function for inference
        print("[Sketch] Preprocessing inference examples ...")
        sketch_prep_fn, _ = get_sketch_features(sketch_tokenizer, "test", retro_args)

        # Get model for sketch reader
        sketch_model_cls = retro_args.sketch_model_cls
        print(f"[Sketch] Loading sketch model from {retro_args.sketch_model_name} ...")
        sketch_model = sketch_model_cls.from_pretrained(
            pretrained_model_name_or_path=retro_args.sketch_model_name,
            use_auth_token=retro_args.use_auth_token,
            revision=retro_args.sketch_revision,
        )
        sketch_model.to(device)

        # Get sketch reader
        sketch_training_args.run_name = sketch_run_name
        sketch_training_args.output_dir += "/sketch"
        sketch_training_args.metric_for_best_model = sketch_best_metric
        sketch_reader = SketchReader(
            model=sketch_model,
            args=sketch_training_args,
            train_dataset=sketch_train_dataset,
            eval_dataset=sketch_eval_dataset,
            eval_examples=eval_examples,
            data_args=retro_args,
            tokenizer=sketch_tokenizer,
            compute_metrics=compute_classification_metric,
        )

        print(f"[Intensive] Loading intensive tokenizer from {retro_args.intensive_tokenizer_name} ...")
        intensive_tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=retro_args.intensive_tokenizer_name,
            use_auth_token=retro_args.use_auth_token,
            revision=retro_args.intensive_revision,
        )

        # If `train_examples` is provided, perform preprocessing
        if train_examples is not None and intensive_train_dataset is None:
            print("[Intensive] Preprocessing train examples ...")
            intensive_prep_fn, is_batched = get_intensive_features(intensive_tokenizer, "train", retro_args)
            intensive_train_dataset = train_examples.map(
                intensive_prep_fn,
                batched=is_batched,
                remove_columns=train_examples.column_names,
                num_proc=retro_args.preprocessing_num_workers,
                load_from_cache_file=not retro_args.overwrite_cache,
            )

        # If `eval_examples` is provided, perform preprocessing
        if eval_examples is not None and intensive_eval_dataset is None:
            print("[Intensive] Preprocessing eval examples ...")
            intensive_prep_fn, is_batched = get_intensive_features(intensive_tokenizer, "eval", retro_args)
            intensive_eval_dataset = eval_examples.map(
                intensive_prep_fn,
                batched=is_batched,
                remove_columns=eval_examples.column_names,
                num_proc=retro_args.preprocessing_num_workers,
                load_from_cache_file=not retro_args.overwrite_cache,
            )

        # Get preprocessing function for inference
        print("[Intensive] Preprocessing test examples ...")
        intensive_prep_fn, _ = get_intensive_features(intensive_tokenizer, "test", retro_args)

        # Get model for intensive reader
        intensive_model_cls = retro_args.intensive_model_cls
        print(f"[Intensive] Loading intensive model from {retro_args.intensive_model_name} ...")
        intensive_model = intensive_model_cls.from_pretrained(
            pretrained_model_name_or_path=retro_args.intensive_model_name,
            use_auth_token=retro_args.use_auth_token,
            revision=retro_args.intensive_revision,
        )
        intensive_model.to(device)

        # Freeze intensive weights for transfer learning: unless fine-tuning,
        # every parameter except the last five stops receiving gradients.
        if retro_args.intensive_model_mode != "finetune":
            print("[Intensive] Freezing intensive weights for transfer learning ...")
            for param in list(intensive_model.parameters())[:-5]:
                param.requires_grad_(False)

        # Get intensive reader
        intensive_training_args.run_name = intensive_run_name
        intensive_training_args.output_dir += "/intensive"
        intensive_training_args.metric_for_best_model = intensive_best_metric
        intensive_reader = IntensiveReader(
            model=intensive_model,
            args=intensive_training_args,
            train_dataset=intensive_train_dataset,
            eval_dataset=intensive_eval_dataset,
            eval_examples=eval_examples,
            data_args=retro_args,
            tokenizer=intensive_tokenizer,
            compute_metrics=compute_squad_v2,
        )

        # Get rear verifier
        rear_verifier = RearVerifier(
            beta1=retro_args.beta1,
            beta2=retro_args.beta2,
            best_cof=retro_args.best_cof,
            thresh=retro_args.rear_threshold,
        )

        return cls(
            args=retro_args,
            sketch_reader=sketch_reader,
            intensive_reader=intensive_reader,
            rear_verifier=rear_verifier,
            prep_fn=(sketch_prep_fn, intensive_prep_fn),
        )

    def __call__(
        self,
        query: str,
        context: Union[str, List[str]],
        return_submodule_outputs: bool = False,
    ) -> Tuple[Any]:
        """
        Performs inference on a given query and context.

        Args:
            query (str): The query to be answered.
            context (Union[str, List[str]]): The context in which the query is asked. If it is a list of
                strings, they are joined with single spaces.
            return_submodule_outputs (bool, optional): Whether to return the outputs of the submodules.
                Defaults to False.

        Returns:
            Tuple[Any]: A tuple containing the predictions, scores, and optionally the outputs of the
            submodules.
        """
        # If context is a list, join it into a single string
        if isinstance(context, list):
            context = " ".join(context)

        # Create a predict examples dataset with a single example
        predict_examples = datasets.Dataset.from_dict({
            "example_id": ["0"],
            C.ID_COLUMN_NAME: ["id-01"],
            C.QUESTION_COLUMN_NAME: [query],
            C.CONTEXT_COLUMN_NAME: [context],
        })

        # Perform inference on the predict examples dataset
        return self.inference(predict_examples, return_submodule_outputs=return_submodule_outputs)

    def train(self, module: str = "all", device: str = "cpu"):
        """
        Trains the specified module.

        Args:
            module (str, optional): The module to train. Defaults to "all".
                Possible values: "all", "sketch", "intensive" (case-insensitive).
            device (str, optional): Accepted for interface compatibility; not
                used by this method.
        """

        def wandb_finish(reader):
            """Finish the Weights & Biases (wandb) run attached to ``reader``, if any."""
            for callback in reader.callback_handler.callbacks:
                # Check if the callback is a wandb callback
                if "wandb" in str(type(callback)).lower():
                    # Finish the wandb run
                    if hasattr(callback, '_wandb'):
                        callback._wandb.finish()
                    # Reset the initialized flag
                    callback._initialized = False

        print(f"Starting training for module: {module}")
        target = module.lower()

        # Train sketch reader
        if target in ["all", "sketch"]:
            print("Training sketch reader")
            self.sketch_reader.train()
            print("Saving sketch reader")
            self.sketch_reader.save_model()
            print("Saving sketch reader state")
            self.sketch_reader.save_state()
            self.sketch_reader.free_memory()
            wandb_finish(self.sketch_reader)
            print("Sketch reader training finished")

        # Train intensive reader
        if target in ["all", "intensive"]:
            print("Training intensive reader")
            self.intensive_reader.train()
            print("Saving intensive reader")
            self.intensive_reader.save_model()
            print("Saving intensive reader state")
            self.intensive_reader.save_state()
            self.intensive_reader.free_memory()
            wandb_finish(self.intensive_reader)
            print("Intensive reader training finished")

        print("Training finished")

    def inference(self, predict_examples: datasets.Dataset, return_submodule_outputs: bool = True) -> Tuple[Any]:
        """
        Performs inference on the given predict examples dataset.

        Args:
            predict_examples (datasets.Dataset): The dataset containing the predict examples.
            return_submodule_outputs (bool, optional): Whether to return the outputs of the submodules.
                Defaults to True.

        Returns:
            Tuple[Any]: A tuple containing the predictions, scores, and optionally the outputs
            (score_ext, nbest_preds, score_diff) of the submodules.
        """
        # Add the example_id column if it doesn't exist
        if "example_id" not in predict_examples.column_names:
            predict_examples = predict_examples.map(
                lambda _, i: {"example_id": str(i)},
                with_indices=True,
            )

        # Prepare the features for sketch reader and intensive reader
        sketch_features = predict_examples.map(
            self.sketch_prep_fn,
            batched=True,
            remove_columns=predict_examples.column_names,
        )
        intensive_features = predict_examples.map(
            self.intensive_prep_fn,
            batched=True,
            remove_columns=predict_examples.column_names,
        )

        # Perform inference on sketch reader
        score_ext = self.sketch_reader.predict(sketch_features, predict_examples)

        # Perform inference on intensive reader
        nbest_preds, score_diff = self.intensive_reader.predict(
            intensive_features, predict_examples, mode="retro_inference")

        # Combine the outputs of the submodules
        predictions, scores = self.rear_verifier(score_ext, score_diff, nbest_preds)
        outputs = (predictions, scores)

        # Add the outputs of the submodules if required
        if return_submodule_outputs:
            outputs += (score_ext, nbest_preds, score_diff)

        return outputs

    def evaluate(self, test_dataset: datasets.Dataset) -> dict:
        """
        Evaluates the model on the given test dataset.

        Args:
            test_dataset (Dataset): The dataset containing the test examples and ground truth answers.

        Returns:
            dict: A dictionary containing the evaluation metrics.
        """
        # Perform inference on the test dataset
        predictions, scores, score_ext, nbest_preds, score_diff = self.inference(
            test_dataset, return_submodule_outputs=True
        )

        formatted_predictions = []
        formatted_references = []
        for example in test_dataset:
            example_id = example[C.ID_COLUMN_NAME]
            formatted_predictions.append({
                'id': example_id,
                # `predictions` is keyed by example id, so look the answer up
                # by id rather than relying on iteration order.
                'prediction_text': predictions[example_id],
                'no_answer_probability': 0.0  # Assuming no_answer_probability is 0 for simplicity
            })
            formatted_references.append({
                'id': example_id,
                'answers': example[C.ANSWER_COLUMN_NAME],
            })

        # Return the evaluation metrics
        return compute_squad_v2(
            EvalPrediction(predictions=formatted_predictions, label_ids=formatted_references)
        )

    @property
    def null_score_diff_threshold(self):
        return self.args.null_score_diff_threshold

    @null_score_diff_threshold.setter
    def null_score_diff_threshold(self, val):
        self.args.null_score_diff_threshold = val

    @property
    def n_best_size(self):
        return self.args.n_best_size

    @n_best_size.setter
    def n_best_size(self, val):
        self.args.n_best_size = val

    @property
    def beta1(self):
        return self.rear_verifier.beta1

    @beta1.setter
    def beta1(self, val):
        self.rear_verifier.beta1 = val

    @property
    def beta2(self):
        return self.rear_verifier.beta2

    @beta2.setter
    def beta2(self, val):
        self.rear_verifier.beta2 = val

    @property
    def best_cof(self):
        return self.rear_verifier.best_cof

    @best_cof.setter
    def best_cof(self, val):
        self.rear_verifier.best_cof = val

    @property
    def rear_threshold(self):
        return self.rear_verifier.thresh

    @rear_threshold.setter
    def rear_threshold(self, val):
        self.rear_verifier.thresh = val