from vllm import LLM, SamplingParams
import multiprocessing
import time
import gc
import torch
import pdb
import sqlite3
from concurrent.futures import ThreadPoolExecutor


def label_transform(label):
    # Map integer NLI dataset labels to their string names.
    if label == 1:
        return 'neutral'
    if label == 0:
        return 'entailment'
    if label == 2:
        return 'contradiction'


# Deterministic (greedy) decoding with up to 600 generated tokens per prompt.
sampling_params = SamplingParams(temperature=0.0, max_tokens=600, top_p=0.95)


def valid_results_collect(model_path, valid_data, task):
    # Clear any cached GPU memory before loading the model under evaluation.
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

    trained_model = LLM(model=model_path, gpu_memory_utilization=0.95)
    start_t = time.time()
    failed_cases, correct_cases = nli_evaluation(trained_model, valid_data)
    del trained_model
    end_t = time.time()
    print('time', end_t - start_t)

    # Release the GPU resources held by the evaluated model before returning.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.synchronize()
    time.sleep(10)
    return failed_cases, correct_cases


def extract_answer_prediction_nli(predicted_output):
    # Keep only sentences mentioning 'final' (e.g. "The final answer is ..."), then extract the label.
    sens = predicted_output.split('.')
    final_sens = [sen for sen in sens if 'final' in sen]
    for sen in final_sens:
        answer = extract_answer(sen)
        if answer:
            return answer
    return None


def extract_answer(text):
    # Return the first NLI label mentioned in the text, or None if no label is found.
    if 'neutral' in text.lower():
        return 'neutral'
    if 'contradiction' in text.lower():
        return 'contradiction'
    if 'entailment' in text.lower():
        return 'entailment'
    return None
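
# Illustrative behaviour of the two extractors above (examples, not executed here):
#   extract_answer('The answer is Entailment')  -> 'entailment'
#   extract_answer('no labels here')            -> None
#   extract_answer_prediction_nli('Reasoning ... The final answer is neutral.') -> 'neutral'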


def process_batch(data_batch, trained_model, failed_cases, correct_cases):
    batch_prompts = [data['Input'] for data in data_batch]
    outputs = trained_model.generate(batch_prompts, sampling_params)

    labels = ['entailment', 'contradiction', 'neutral']
    for data, output in zip(data_batch, outputs):
        predicted_output = output.outputs[0].text
        predicted_res = extract_answer_prediction_nli(predicted_output)
        # The gold label is parsed from the text after the last "is" in the reference Output.
        label = extract_answer(data['Output'].split('is')[-1])
        print(predicted_res, label, '\n')
        if not predicted_res:
            # Fall back to matching labels against the raw generation when no final answer was extracted.
            predicted_res = predicted_output
        non_labels = [lbl for lbl in labels if lbl != label]
        # A prediction counts as correct only if it mentions the gold label and none of the other labels.
        if label not in predicted_res or any(non_label in predicted_res for non_label in non_labels):
            failed_cases.append((data['Input'], predicted_res, label, data))
        else:
            correct_cases.append((data['Input'], predicted_res, label, data))
    return failed_cases, correct_cases


def nli_evaluation(trained_model, valid_data):
    failed_cases = []
    correct_cases = []
    batch_size = 500
    # Evaluate in batches of 500 prompts; vLLM schedules generation within each batch.
    batched_data = [valid_data[i:i + batch_size] for i in range(0, len(valid_data), batch_size)]
    for batch in batched_data:
        failed_cases, correct_cases = process_batch(batch, trained_model, failed_cases, correct_cases)
    return failed_cases, correct_cases
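

# Usage sketch (hypothetical): one way valid_results_collect might be invoked.
# It assumes valid_data is a list of dicts with 'Input' (prompt) and 'Output'
# (reference answer stating the gold label) keys, and that a fine-tuned
# checkpoint exists at the path below; both are placeholders, not defined here.
if __name__ == '__main__':
    example_valid_data = [
        {
            'Input': 'Premise: A man is playing a guitar. Hypothesis: A man is playing an instrument. '
                     'Is the relationship entailment, contradiction, or neutral? '
                     'State the final answer at the end.',
            'Output': 'The final answer is entailment.',
        },
    ]
    failed, correct = valid_results_collect('path/to/finetuned-model', example_valid_data, task='nli')
    print('accuracy:', len(correct) / (len(correct) + len(failed)))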