File size: 2,970 Bytes
f4c3446 dc52a82 f4c3446 dc52a82 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
from vllm import LLM, SamplingParams
import multiprocessing
import time
import gc
import torch
import pdb
import sqlite3
from concurrent.futures import ThreadPoolExecutor
#from openai_call import query_azure_openai_chatgpt_chat
def label_transform(label):
    """Map an integer NLI class id to its string label.

    0 -> 'entailment', 1 -> 'neutral', 2 -> 'contradiction';
    any other id returns None (same as the original if-chain falling through).
    """
    id_to_label = {0: 'entailment', 1: 'neutral', 2: 'contradiction'}
    return id_to_label.get(label)
# Shared decoding config for all generate() calls in this module:
# temperature=0.0 makes decoding deterministic (top_p is then effectively inert).
sampling_params = SamplingParams(temperature=0.0,max_tokens=600, top_p=0.95)
def valid_results_collect(model_path, valid_data, task):
    """Load a vLLM model from `model_path`, run NLI evaluation on `valid_data`,
    and return (failed_cases, correct_cases).

    `task` is currently unused but kept for interface compatibility with callers.
    GPU caches are cleared before loading and after teardown; the trailing
    sleep gives CUDA time to actually release memory before the caller
    loads another model.
    """
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    # multiprocessing.set_start_method('spawn')
    trained_model = LLM(model=model_path, gpu_memory_utilization=0.95)
    start_t = time.time()
    failed_cases, correct_cases = nli_evaluation(trained_model, valid_data)
    del trained_model
    end_t = time.time()
    # BUG FIX: was `start_t - end_t`, which always printed a negative duration.
    print('time', end_t - start_t)
    gc.collect()  # Run garbage collection to free up memory
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.synchronize()
    time.sleep(10)
    return failed_cases, correct_cases
def extract_answer_prediction_nli(predicted_output):
    """Pull the predicted NLI label out of a generated answer.

    Splits the generation into sentences, keeps only sentences mentioning
    'final' (e.g. "the final answer is ..."), and returns the first label
    keyword found by extract_answer(). Returns None when no sentence
    yields a label.
    """
    sentences = predicted_output.split('.')
    final_sentences = [sen for sen in sentences if 'final' in sen]
    for sen in final_sentences:
        # FIX: call extract_answer once per sentence instead of twice
        # (the original re-ran the scan to produce the return value).
        answer = extract_answer(sen)
        if answer:
            return answer
    return None
def extract_answer(text):
    """Return the first NLI label keyword contained in `text`, or None.

    Case-insensitive substring search, checked in the original priority
    order: 'neutral', then 'contradiction', then 'entailment'.
    """
    # Hoist the lowercasing: the original recomputed text.lower() per check.
    lowered = text.lower()
    for label in ('neutral', 'contradiction', 'entailment'):
        if label in lowered:
            return label
    return None
def process_batch(data_batch, trained_model, failed_cases, correct_cases):
    """Generate predictions for one batch and sort each row into
    failed_cases or correct_cases (appended in place and also returned).

    A row is correct only when the gold label appears in the prediction
    AND no other label keyword does; everything else is a failure.
    """
    batch_prompts = [data['Input'] for data in data_batch]
    outputs = trained_model.generate(batch_prompts, sampling_params)
    labels = ['entailment', 'contradiction', 'neutral']
    for data, output in zip(data_batch, outputs):
        predicted_output = output.outputs[0].text
        predicted_res = extract_answer_prediction_nli(predicted_output)
        # Gold label is parsed from text after the last 'is' in data['Output'];
        # assumes outputs look like "... answer is <label>" — TODO confirm.
        label = extract_answer(data['Output'].split('is')[-1])
        print(predicted_res, label, '\n')
        if not predicted_res:
            # No label extracted: fall back to the raw generation so the
            # substring checks below still apply.
            predicted_res = predicted_output
        non_labels = [lbl for lbl in labels if lbl != label]
        # BUG FIX: when extract_answer finds no keyword in the gold output,
        # `label` is None and `label not in predicted_res` raised TypeError;
        # such rows are now counted as failures instead of crashing.
        if label is None or label not in predicted_res or any(
                non_label in predicted_res for non_label in non_labels):
            failed_cases.append((data['Input'], predicted_res, label, data))
        else:
            correct_cases.append((data['Input'], predicted_res, label, data))
    return failed_cases, correct_cases
def nli_evaluation(trained_model, valid_data):
    """Run NLI evaluation over `valid_data` in batches of 500.

    Delegates each batch to process_batch(), threading the accumulator
    lists through every call; returns (failed_cases, correct_cases).
    """
    failed_cases, correct_cases = [], []
    batch_size = 500
    # Slice batches on the fly rather than materializing a list of batches.
    for start in range(0, len(valid_data), batch_size):
        chunk = valid_data[start:start + batch_size]
        failed_cases, correct_cases = process_batch(
            chunk, trained_model, failed_cases, correct_cases)
    return failed_cases, correct_cases