# -*- coding: utf-8 -*-
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import logging
from . import bleu
from . import weighted_ngram_match
from . import syntax_match
from . import dataflow_match


def calc_codebleu(predictions, references, lang, tokenizer=None,
                  params='0.25,0.25,0.25,0.25', kw_dir='.', langso_dir='.'):
    """Compute CodeBLEU for a corpus of predictions against references.

    Args:
        predictions (list[str]): list of predicted code snippets
        references (list[str] or list[list[str]]): one reference string, or a
            list of reference strings, per prediction
        lang (str): one of 'java', 'js', 'c_sharp', 'php', 'go', 'python', 'ruby'
        tokenizer (callable, optional): tokenizer function; defaults to
            whitespace splitting (lambda s: s.split())
        params (str, optional): comma-separated weights 'alpha,beta,gamma,theta'
            for the four component scores. Defaults to '0.25,0.25,0.25,0.25'.
        kw_dir (str, optional): directory containing the keywords/<lang>.txt
            keyword lists. Defaults to '.'.
        langso_dir (str, optional): directory with the tree-sitter language
            library used for syntax and dataflow matching. Defaults to '.'.

    Returns:
        dict: the overall 'CodeBLEU' score plus its four component scores
    """

    # unpack the four component weights: ngram, weighted ngram, syntax, dataflow
    alpha, beta, gamma, theta = [float(x) for x in params.split(',')]

    # preprocess inputs: every entry in references becomes a list of stripped strings
    references = [[x.strip() for x in ref] if isinstance(ref, list) else [ref.strip()]
                  for ref in references]
    hypothesis = [x.strip() for x in predictions]

    if len(references) != len(hypothesis):
        raise ValueError('predictions and references must have the same length')

    # calculate ngram match (BLEU)
    if tokenizer is None:
        tokenizer = lambda s: s.split()

    tokenized_hyps = [tokenizer(x) for x in hypothesis]
    tokenized_refs = [[tokenizer(x) for x in reference]
                      for reference in references]
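    # shapes: tokenized_hyps is list[list[str]], tokenized_refs is
    # list[list[list[str]]] (one list of tokenized references per hypothesis)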

    ngram_match_score = bleu.corpus_bleu(tokenized_refs, tokenized_hyps)

    # calculate weighted ngram match
    keywords_path = os.path.join(kw_dir, 'keywords', lang + '.txt')
    with open(keywords_path, 'r', encoding='utf-8') as f:
        keywords = [x.strip() for x in f.readlines()]

    def make_weights(reference_tokens, key_word_list):
        # language keywords get weight 1.0, every other token 0.2
        return {token: 1 if token in key_word_list else 0.2
                for token in reference_tokens}
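    # e.g. with keywords ['def', 'return'], make_weights(['def', 'f', '(', ')'], keywords)
    # yields {'def': 1, 'f': 0.2, '(': 0.2, ')': 0.2} (hypothetical tokens)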
    tokenized_refs_with_weights = [[[reference_tokens, make_weights(reference_tokens, keywords)]
                                    for reference_tokens in reference] for reference in tokenized_refs]

    weighted_ngram_match_score = weighted_ngram_match.corpus_bleu(
        tokenized_refs_with_weights, tokenized_hyps)

    # calculate syntax match
    syntax_match_score = syntax_match.corpus_syntax_match(
        references, hypothesis, lang, langso_dir)

    # calculate dataflow match
    dataflow_match_score = dataflow_match.corpus_dataflow_match(
        references, hypothesis, lang, langso_dir)

    logging.debug('ngram match: %s, weighted ngram match: %s, syntax_match: %s, dataflow_match: %s',
                  ngram_match_score, weighted_ngram_match_score,
                  syntax_match_score, dataflow_match_score)

    # note: a dataflow score of 0 (e.g. when no dataflow could be extracted
    # from the inputs) is replaced by 1 so the term does not zero out its
    # contribution to the weighted sum
    code_bleu_score = alpha * ngram_match_score \
        + beta * weighted_ngram_match_score \
        + gamma * syntax_match_score \
        + theta * (dataflow_match_score or 1)

    logging.debug('CodeBLEU score: %s', code_bleu_score)

    return {
        'CodeBLEU': code_bleu_score,
        'ngram_match_score': ngram_match_score,
        'weighted_ngram_match_score': weighted_ngram_match_score,
        'syntax_match_score': syntax_match_score,
        'dataflow_match_score': dataflow_match_score
    }
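

# Minimal usage sketch (an illustration, not part of the original module).
# Because this module uses relative imports, run it with
# `python -m <package>.<module>` from the package root. The snippets below are
# hypothetical, and the keywords/ directory plus the tree-sitter language
# library are assumed to sit in the current working directory (the
# kw_dir/langso_dir defaults).
if __name__ == '__main__':
    predictions = ['def add(a, b):\n    return a + b']
    references = [['def add(x, y):\n    return x + y']]
    scores = calc_codebleu(predictions, references, lang='python')
    for name, value in scores.items():
        print(f'{name}: {value:.4f}')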