codebleu / my_codebleu.py
dvitel's picture
allow references to be simple list
08dc526
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding:utf-8 -*-
import os
import logging
from . import bleu
from . import weighted_ngram_match
from . import syntax_match
from . import dataflow_match
def calc_codebleu(predictions, references, lang, tokenizer=None, params='0.25,0.25,0.25,0.25', kw_dir=".", langso_dir="."):
    """Compute CodeBLEU and its four component scores for a corpus.

    The final score is the weighted sum of the n-gram match (BLEU), the
    keyword-weighted n-gram match, the syntax match and the dataflow match.

    Args:
        predictions (list[str]): one predicted code snippet per sample.
        references (list[str] | list[list[str]]): for each sample, either a
            single reference string or a list/tuple of reference strings.
        lang (str): ['java','js','c_sharp','php','go','python','ruby'].
            Also selects the keyword file ``<kw_dir>/keywords/<lang>.txt``.
        tokenizer (callable, optional): maps a string to a list of tokens.
            Defaults to whitespace splitting (``s.split()``).
        params (str, optional): four comma-separated weights
            ``alpha,beta,gamma,theta`` applied, in order, to the n-gram,
            weighted n-gram, syntax and dataflow scores.
            Defaults to '0.25,0.25,0.25,0.25'.
        kw_dir (str, optional): directory that contains the ``keywords``
            folder. Defaults to ".".
        langso_dir (str, optional): forwarded unchanged to the syntax and
            dataflow matchers (presumably the location of the compiled
            parser library -- confirm against those modules).

    Returns:
        dict: keys 'CodeBLEU', 'ngram_match_score',
        'weighted_ngram_match_score', 'syntax_match_score' and
        'dataflow_match_score'.

    Raises:
        ValueError: if ``params`` does not hold exactly four numbers, or if
            ``predictions`` and ``references`` differ in length.
        OSError: if the keyword file for ``lang`` cannot be opened.
    """
    alpha, beta, gamma, theta = [float(x) for x in params.split(',')]

    # preprocess inputs: every sample ends up with a list of stripped
    # reference strings, even when a bare string was supplied.
    references = [
        [x.strip() for x in ref] if isinstance(ref, (list, tuple)) else [ref.strip()]
        for ref in references
    ]
    hypothesis = [x.strip() for x in predictions]

    if len(references) != len(hypothesis):
        raise ValueError(
            'predictions and references must have the same length '
            '(got {0} predictions and {1} references)'.format(
                len(hypothesis), len(references)))

    # calculate ngram match (BLEU)
    if tokenizer is None:
        def tokenizer(s):
            return s.split()

    tokenized_hyps = [tokenizer(x) for x in hypothesis]
    tokenized_refs = [[tokenizer(x) for x in reference]
                      for reference in references]

    ngram_match_score = bleu.corpus_bleu(tokenized_refs, tokenized_hyps)

    # calculate weighted ngram match
    # Read the keyword list inside a context manager so the file handle is
    # closed deterministically; a set gives O(1) membership tests below.
    with open(kw_dir + '/keywords/' + lang + '.txt', 'r', encoding='utf-8') as kw_file:
        keywords = {line.strip() for line in kw_file}

    def make_weights(reference_tokens, key_word_list):
        # Language keywords get full weight, every other token 0.2.
        return {token: 1 if token in key_word_list else 0.2
                for token in reference_tokens}

    tokenized_refs_with_weights = [
        [[reference_tokens, make_weights(reference_tokens, keywords)]
         for reference_tokens in reference]
        for reference in tokenized_refs
    ]

    weighted_ngram_match_score = weighted_ngram_match.corpus_bleu(
        tokenized_refs_with_weights, tokenized_hyps)

    # calculate syntax match
    syntax_match_score = syntax_match.corpus_syntax_match(
        references, hypothesis, lang, langso_dir)

    # calculate dataflow match
    dataflow_match_score = dataflow_match.corpus_dataflow_match(
        references, hypothesis, lang, langso_dir)

    # A falsy dataflow score (0 or None) is replaced by 1 in the weighted sum
    # only; the raw value is still reported in the returned dict.
    # NOTE(review): presumably a zero score means no dataflow could be
    # extracted from the references -- confirm against dataflow_match.
    code_bleu_score = (alpha * ngram_match_score
                       + beta * weighted_ngram_match_score
                       + gamma * syntax_match_score
                       + theta * (dataflow_match_score or 1))

    return {
        'CodeBLEU': code_bleu_score,
        'ngram_match_score': ngram_match_score,
        'weighted_ngram_match_score': weighted_ngram_match_score,
        'syntax_match_score': syntax_match_score,
        'dataflow_match_score': dataflow_match_score
    }