File size: 2,999 Bytes
421645e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding:utf-8 -*-
import os
import logging
from . import bleu
from . import weighted_ngram_match
from . import syntax_match
from . import dataflow_match
def calc_codebleu(predictions, references, lang, tokenizer=None, params='0.25,0.25,0.25,0.25'):
    """Compute the CodeBLEU score together with its four component scores.

    The combined score is the weighted sum of the n-gram match (BLEU), the
    keyword-weighted n-gram match, the syntax match and the dataflow match.

    Args:
        predictions (list[str]): one predicted code string per sample.
        references (list[list[str]]): for every sample, a list of reference
            code strings. Must contain as many entries as ``predictions``.
        lang (str): language identifier, e.g. one of
            ['java','js','c_sharp','php','go','python','ruby']. It selects
            the ``keywords/<lang>.txt`` file located next to this module and
            is forwarded to the syntax and dataflow matchers.
        tokenizer (callable, optional): function turning a code string into
            a list of tokens. Defaults to whitespace splitting.
        params (str, optional): four comma-separated weights applied, in
            order, to the n-gram, weighted n-gram, syntax and dataflow
            scores. Defaults to '0.25,0.25,0.25,0.25'.

    Returns:
        dict: the keys 'CodeBLEU', 'ngram_match_score',
        'weighted_ngram_match_score', 'syntax_match_score' and
        'dataflow_match_score' mapped to their respective scores.

    Raises:
        ValueError: if ``params`` does not hold exactly four numbers, or if
            ``predictions`` and ``references`` differ in length.
    """
    weights = [float(x) for x in params.split(',')]
    if len(weights) != 4:
        raise ValueError(
            'params must contain exactly 4 comma-separated weights, got {0}: {1!r}'.format(
                len(weights), params))
    alpha, beta, gamma, theta = weights

    # preprocess inputs
    references = [[x.strip() for x in ref] for ref in references]
    hypothesis = [x.strip() for x in predictions]

    if len(references) != len(hypothesis):
        raise ValueError(
            'predictions and references must have the same length, '
            'got {0} predictions and {1} reference lists'.format(
                len(hypothesis), len(references)))

    # calculate ngram match (BLEU)
    if tokenizer is None:
        def tokenizer(text):
            return text.split()

    tokenized_hyps = [tokenizer(x) for x in hypothesis]
    tokenized_refs = [[tokenizer(x) for x in reference]
                      for reference in references]

    ngram_match_score = bleu.corpus_bleu(tokenized_refs, tokenized_hyps)

    # calculate weighted ngram match
    keywords_path = os.path.join(
        os.path.abspath(os.path.dirname(__file__)), 'keywords', lang + '.txt')
    # The context manager closes the file; a set makes the per-token
    # membership test below constant-time instead of a list scan.
    with open(keywords_path, 'r', encoding='utf-8') as keywords_file:
        keywords = {line.strip() for line in keywords_file}

    def make_weights(reference_tokens, key_word_list):
        # Language keywords count fully, every other token is down-weighted.
        return {token: 1 if token in key_word_list else 0.2
                for token in reference_tokens}

    tokenized_refs_with_weights = [
        [[reference_tokens, make_weights(reference_tokens, keywords)]
         for reference_tokens in reference]
        for reference in tokenized_refs
    ]

    weighted_ngram_match_score = weighted_ngram_match.corpus_bleu(
        tokenized_refs_with_weights, tokenized_hyps)

    # calculate syntax match
    syntax_match_score = syntax_match.corpus_syntax_match(
        references, hypothesis, lang)

    # calculate dataflow match
    dataflow_match_score = dataflow_match.corpus_dataflow_match(
        references, hypothesis, lang)

    code_bleu_score = (alpha * ngram_match_score
                       + beta * weighted_ngram_match_score
                       + gamma * syntax_match_score
                       + theta * dataflow_match_score)

    return {
        'CodeBLEU': code_bleu_score,
        'ngram_match_score': ngram_match_score,
        'weighted_ngram_match_score': weighted_ngram_match_score,
        'syntax_match_score': syntax_match_score,
        'dataflow_match_score': dataflow_match_score
    }
|