File size: 1,794 Bytes
77d955f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65

from llmware.prompts import Prompt


def load_rag_benchmark_tester_ds(ds_name="llmware/rag_instruct_benchmark_tester", split="train"):

    """Load the llmware RAG benchmark test set from the HuggingFace Hub.

    Args:
        ds_name: HuggingFace dataset repo id. Defaults to the llmware
            200-question RAG benchmark tester.
        split: name of the split to return. Defaults to "train".

    Returns:
        A list with one entry per row of the requested split. ``run_test``
        in this file reads the "query", "context" and "answer" keys from
        each entry.
    """

    # Local import: ``datasets`` is only needed when this loader is called.
    from datasets import load_dataset

    dataset = load_dataset(ds_name)

    print("update: loading test dataset - ", dataset)

    # Materialize the split into a plain list so callers can iterate it
    # repeatedly and index it like any list.
    return list(dataset[split])


def run_test(model_name, prompt_list, *, temperature=0.3, from_hf=True):

    """Run every benchmark sample through a model and print the results.

    For each sample the model answers the sample's query given its context.
    The model output, the gold answer and the output of llmware's three
    evidence-check helpers are printed.

    Args:
        model_name: model identifier passed to ``Prompt().load_model``.
        prompt_list: iterable of sample dicts; each must provide the keys
            "query", "context" and "answer".
        temperature: keyword-only; forwarded to ``prompt_main``
            (default 0.3).
        from_hf: keyword-only; forwarded to ``load_model`` (default True).

    Returns:
        Always 0; results are only printed.
    """

    print("\nupdate: Starting RAG Benchmark Inference Test")

    prompter = Prompt().load_model(model_name, from_hf=from_hf)

    for i, sample in enumerate(prompt_list):

        prompt = sample["query"]
        context = sample["context"]

        response = prompter.prompt_main(prompt, context=context,
                                        prompt_name="default_with_context",
                                        temperature=temperature)

        # Post-inference evidence checks; each call returns an iterable of
        # dicts that are read by key below.
        fact_checks = prompter.evidence_check_numbers(response)
        comparison_stats = prompter.evidence_comparison_stats(response)
        source_reviews = prompter.evidence_check_sources(response)

        print("\nupdate: model inference output - ", i, response["llm_response"])
        print("update: gold_answer              - ", i, sample["answer"])

        # Use distinct loop names so the per-sample dict ``sample`` is never
        # overwritten by these inner loops.
        for fact_entry in fact_checks:
            print("update: fact check - ", fact_entry["fact_check"])

        for stats_entry in comparison_stats:
            print("update: comparison stats - ", stats_entry["comparison_stats"])

        for source_entry in source_reviews:
            print("update: sources - ", source_entry["source_review"])

    return 0


if __name__ == "__main__":

    # Model under test; run_test loads it with from_hf=True.
    bench_model = "llmware/bling-stable-lm-3b-4e1t-v0"

    # Fetch the benchmark samples, then run them through the model.
    benchmark_samples = load_rag_benchmark_tester_ds()
    output = run_test(bench_model, benchmark_samples)