File size: 1,064 Bytes
8894e1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from zeno import distill, model, metric, ZenoOptions
from inspiredco.critique import Critique
import os

# from sentence_transformers import SentenceTransformer

# sentence_embed = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
# Module-level Critique client used by the functions in this file.
# The API key is read from the INSPIREDCO_API_KEY environment variable; because
# it is looked up with os.environ[...], importing this module raises KeyError
# when that variable is not set.
client = Critique(api_key=os.environ["INSPIREDCO_API_KEY"])


@model
def pred_fns(name):
    """Return the prediction function registered with Zeno.

    Args:
        name: Model name supplied by the caller. It is not used here: the
            returned function reads the same column regardless of its value.

    Returns:
        A callable ``pred(df, ops)`` that yields the model outputs for ``df``.
    """

    def pred(df, ops):
        """Return the ``"translation"`` column of ``df`` unchanged.

        ``ops`` is accepted for signature compatibility and is not read.
        """
        # NOTE: returning sentence embeddings alongside the outputs is
        # currently disabled; only the outputs themselves are returned.
        outputs = df["translation"]
        return outputs

    return pred


@distill
def bert_score(df, ops):
    """Compute a per-row BERTScore through the Critique client.

    Each row of ``df`` is turned into a record with the keys ``"source"``,
    ``"references"`` (a one-element list holding the row's ``"label"``) and
    ``"target"`` (the value of the ``ops.output_column`` column). The records
    are sent to ``client.evaluate`` with the ``bert_score`` metric configured
    for the ``bert-base-uncased`` model.

    Args:
        df: DataFrame containing the ``"source"`` and ``"label"`` columns and
            the column named by ``ops.output_column``.
        ops: Options object; only ``ops.output_column`` is read.

    Returns:
        A list with one score per evaluated example, each rounded to six
        decimal places, in the order returned by the service.
    """
    prediction_key = ops.output_column
    examples = df[["source", prediction_key, "label"]].to_dict("records")

    # Rename the fields in place to the keys sent to the service:
    # "label" -> "references" (wrapped in a list), model output -> "target".
    for example in examples:
        example["references"] = [example.pop("label")]
        example["target"] = example.pop(prediction_key)

    response = client.evaluate(
        metric="bert_score",
        config={"model": "bert-base-uncased"},
        dataset=examples,
    )

    per_example = response["examples"]
    return [round(entry["value"], 6) for entry in per_example]


@metric
def avg_bert_score(df, ops: ZenoOptions):
    """Return the mean of the distilled ``bert_score`` column over ``df``.

    Args:
        df: DataFrame (typically a slice) holding the distilled score column.
        ops: Options object; ``ops.distill_columns["bert_score"]`` gives the
            name of the column to average.

    Returns:
        The result of calling ``.mean()`` on that column.
    """
    score_column = ops.distill_columns["bert_score"]
    scores = df[score_column]
    return scores.mean()


@distill
def length(df, ops):
    """Return the length of each entry in the data column.

    Args:
        df: DataFrame containing the column named by ``ops.data_column``.
        ops: Options object; only ``ops.data_column`` is read.

    Returns:
        The per-row lengths computed with the pandas ``.str.len()`` accessor.
    """
    inputs = df[ops.data_column]
    return inputs.str.len()