from vllm import LLM, SamplingParams
import time
import gc
import torch
def label_transform(label):
    """Map integer NLI labels to their string names."""
    if label == 0:
        return 'entailment'
    if label == 1:
        return 'neutral'
    if label == 2:
        return 'contradiction'
    return None
# Deterministic (greedy) decoding with up to 600 new tokens.
sampling_params = SamplingParams(temperature=0.0, max_tokens=600, top_p=0.95)
def valid_results_collect(model_path, valid_data, task):
    # Clear any stale GPU memory before loading the model.
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    trained_model = LLM(model=model_path, gpu_memory_utilization=0.95)
    start_t = time.time()
    failed_cases, correct_cases = nli_evaluation(trained_model, valid_data)
    del trained_model
    end_t = time.time()
    print('time', end_t - start_t)
    # Reclaim GPU memory so a subsequent model load does not run out of memory.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.synchronize()
    time.sleep(10)
    return failed_cases, correct_cases
def extract_answer_prediction_nli(predicted_output):
    """Scan sentences containing 'final' for an NLI label; return the first match, else None."""
    sens = predicted_output.split('.')
    final_sens = [sen for sen in sens if 'final' in sen]
    for sen in final_sens:
        answer = extract_answer(sen)
        if answer:
            return answer
    return None
def extract_answer(text):
    """Return the first NLI label mentioned in the text, or None."""
    text = text.lower()
    if 'neutral' in text:
        return 'neutral'
    if 'contradiction' in text:
        return 'contradiction'
    if 'entailment' in text:
        return 'entailment'
    return None
def process_batch(data_batch, trained_model, failed_cases, correct_cases):
    batch_prompts = [data['Input'] for data in data_batch]
    outputs = trained_model.generate(batch_prompts, sampling_params)

    labels = ['entailment', 'contradiction', 'neutral']
    for data, output in zip(data_batch, outputs):
        predicted_output = output.outputs[0].text
        predicted_res = extract_answer_prediction_nli(predicted_output)
        # The gold label follows "is" in the reference, e.g. "... the answer is entailment".
        label = extract_answer(data['Output'].split('is')[-1])
        print(predicted_res, label, '\n')
        if not predicted_res:
            # No labelled "final" sentence was found; fall back to matching against the raw output.
            predicted_res = predicted_output
        # A prediction counts as correct only if it mentions the gold label and no other label.
        non_labels = [lbl for lbl in labels if lbl != label]
        if label not in predicted_res or any(non_label in predicted_res for non_label in non_labels):
            failed_cases.append((data['Input'], predicted_res, label, data))
        else:
            correct_cases.append((data['Input'], predicted_res, label, data))
    return failed_cases, correct_cases
def nli_evaluation(trained_model, valid_data):
    failed_cases = []
    correct_cases = []
    batch_size = 500
    # Split the validation set into batches so vLLM can schedule large prompt groups at once.
    batched_data = [valid_data[i:i + batch_size] for i in range(0, len(valid_data), batch_size)]
    for batch in batched_data:
        failed_cases, correct_cases = process_batch(batch, trained_model, failed_cases, correct_cases)
    return failed_cases, correct_cases
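
# --- Example usage: a minimal sketch of driving the evaluation loop. The
# --- checkpoint path and the exact structure of the `valid_data` records are
# --- assumptions for illustration, not part of the original module.
if __name__ == '__main__':
    # Each record is assumed to carry the prompt under 'Input' and a reference
    # answer whose gold label follows "is" under 'Output'.
    example_data = [
        {
            'Input': 'Premise: A man is playing a guitar. Hypothesis: A man is performing music. Label:',
            'Output': 'The final answer is entailment',
        },
    ]
    failed, correct = valid_results_collect('path/to/checkpoint', example_data, task='nli')
    print(f'{len(correct)} correct, {len(failed)} failed')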