import numpy as np
import torch
from evaluate import load as load_metric
from sklearn.metrics import mean_squared_error
from tqdm.auto import tqdm

# maximum number of tokens generated per prediction
MAX_TARGET_LENGTH = 128

# load evaluation metrics
sacrebleu = load_metric('sacrebleu')
rouge = load_metric('rouge')
meteor = load_metric('meteor')
bertscore = load_metric('bertscore')

# use gpu if it's available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def flatten_list(l):
    """
    Utility function to convert a list of lists into a flattened list

    Params:
        l (list of lists): list to be flattened
    Returns:
        A flattened list with the elements of the original list
    """
    flattened = []
    for sublist in l:
        flattened.extend(sublist)
    return flattened


def parse_float(value):
    """
    Utility function to parse a string into a float

    Params:
        value (string): value to be converted to float
    Returns:
        The float representation of the given string, or -1 if the
        value could not be converted to a float
    """
    try:
        return float(value)
    except (ValueError, TypeError):
        # TypeError covers non-string inputs such as None
        return -1


def extract_scores(predictions):
    """
    Utility function to extract the scores from the predictions of the model

    The score is expected to be the text in front of the 'Feedback:' marker.
    If the marker is missing, the first whitespace-separated token of the
    prediction is used instead.

    Params:
        predictions (list): complete model predictions
    Returns:
        scores (list): extracted scores from the model's predictions;
            -1 marks predictions from which no score could be extracted
    """
    scores = []
    for pred in predictions:
        if 'Feedback:' in pred:
            score_string = pred.split('Feedback:', 1)[0].strip()
        else:
            # str.split never raises IndexError on index 0, so the fallback
            # has to be selected explicitly rather than via an except clause
            tokens = pred.split(maxsplit=1)
            score_string = tokens[0] if tokens else ''
        scores.append(parse_float(score_string))
    return scores


def extract_feedback(predictions):
    """
    Utility function to extract the feedback from the predictions of the model

    The feedback is the text after the first ':'; if there is no ':', the
    text after the first space; if there is neither, the whole prediction.

    Params:
        predictions (list): complete model predictions
    Returns:
        feedback (list): extracted feedback from the model's predictions
    """
    feedback = []
    for pred in predictions:
        _, colon, after_colon = pred.partition(':')
        if colon:
            fb = after_colon
        else:
            _, space, after_space = pred.partition(' ')
            fb = after_space if space else pred
        feedback.append(fb.strip())
    return feedback


def compute_mse(predictions, labels):
    """
    Utility function to compute the mean squared error of the score
    predictions in relation to the golden label scores

    Predictions equal to the -1 sentinel (or any other negative value) are
    treated as invalid and excluded from the computation. A predicted score
    of 0 is a valid prediction and is included.

    Params:
        predictions (list): model score predictions
        labels (list): golden label scores
    Returns:
        (float, int): mse of valid samples (capped at 1, and set to 1 when
            there are no valid samples) and number of invalid samples
    """
    predicted = np.asarray(predictions, dtype=float)
    golden = np.asarray(labels, dtype=float)

    # non-negative predictions are valid; -1 is the "could not parse" marker
    valid_mask = predicted >= 0
    num_invalid = len(labels) - int(valid_mask.sum())

    # without any valid prediction the mse is set to its maximum of 1
    if not valid_mask.any():
        return 1, num_invalid

    mse = mean_squared_error(golden[valid_mask], predicted[valid_mask])

    # cap mse at 1
    if mse > 1:
        mse = 1
    return mse, num_invalid


def compute_metrics(predictions, labels):
    """
    Compute evaluation metrics from the predictions of the model

    Params:
        predictions (list): complete model predictions
        labels (list): decoded golden labels, each of the form
            '<score> Feedback: <feedback>'
    Returns:
        results (dict): dictionary with the computed evaluation metrics
    Raises:
        IndexError: if a golden label does not contain 'Feedback:'
        ValueError: if the score part of a golden label is not a float
    """
    # extract feedback and scores from the model's predictions
    predicted_feedback = extract_feedback(predictions)
    predicted_scores = extract_scores(predictions)

    # extract feedback and scores from the golden labels
    label_parts = [x.split('Feedback:', 1) for x in labels]
    reference_feedback = [parts[1].strip() for parts in label_parts]
    reference_scores = [float(parts[0].strip()) for parts in label_parts]

    # compute HF metrics
    sacrebleu_score = sacrebleu.compute(
        predictions=predicted_feedback,
        references=[[x] for x in reference_feedback],
    )['score']
    rouge_score = rouge.compute(
        predictions=predicted_feedback,
        references=reference_feedback,
    )['rouge2']
    meteor_score = meteor.compute(
        predictions=predicted_feedback,
        references=reference_feedback,
    )['meteor']
    bert_score = bertscore.compute(
        predictions=predicted_feedback,
        references=reference_feedback,
        lang='de',
        model_type='bert-base-multilingual-cased',
        rescale_with_baseline=True,
    )

    # compute mse of score predictions
    mse, _ = compute_mse(predicted_scores, reference_scores)

    results = {
        'sacrebleu': sacrebleu_score,
        'rouge': rouge_score,
        'meteor': meteor_score,
        'bert_score': np.array(bert_score['f1']).mean().item(),
        'mse': mse,
    }
    return results


def evaluate(model, tokenizer, dataloader):
    """
    Evaluate model on the given dataset

    Params:
        model (PreTrainedModel): seq2seq model
        tokenizer (PreTrainedTokenizer): tokenizer from HuggingFace
        dataloader (torch Dataloader): dataloader of the dataset to be used
            for evaluation; each batch must provide 'input_ids',
            'attention_mask' and 'labels'
    Returns:
        results (dict): dictionary with the computed evaluation metrics
        predictions (list): list of the decoded predictions of the model
    """
    predictions, labels = [], []
    pad_token_id = getattr(tokenizer, 'pad_token_id', None)

    model.eval()
    # iterate through batches in the dataloader
    for batch in tqdm(dataloader):
        with torch.no_grad():
            batch = {k: v.to(device) for k, v in batch.items()}
            # generate tokens from batch
            generated_tokens = model.generate(
                batch['input_ids'],
                attention_mask=batch['attention_mask'],
                max_length=MAX_TARGET_LENGTH,
            )

        # get golden labels from batch
        labels_batch = batch['labels']
        # labels may be padded with -100 (the usual ignore index of the
        # loss), which is not a valid token id; swap it for the pad token
        # so the labels can be decoded (no-op if no -100 is present)
        if pad_token_id is not None:
            labels_batch = labels_batch.masked_fill(
                labels_batch == -100, pad_token_id
            )

        # decode model predictions and golden labels
        predictions.extend(
            tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        )
        labels.extend(
            tokenizer.batch_decode(labels_batch, skip_special_tokens=True)
        )

    # compute metrics based on predictions and golden labels
    results = compute_metrics(predictions, labels)
    return results, predictions