from vllm import LLM, SamplingParams
import multiprocessing
import time
import gc
import torch
# from openai_call import query_azure_openai_chatgpt_chat


def label_transform(label):
    """Map integer NLI labels to their string names."""
    if label == 1:
        return 'neutral'
    if label == 0:
        return 'entailment'
    if label == 2:
        return 'contradiction'
    return None


# Greedy decoding; top_p has no effect at temperature 0 but is kept for clarity.
sampling_params = SamplingParams(temperature=0.0, max_tokens=600, top_p=0.95)


def valid_results_collect(model_path, valid_data, task):
    """Load a fine-tuned model with vLLM, evaluate it on valid_data, and return (failed_cases, correct_cases)."""
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    # multiprocessing.set_start_method('spawn')
    trained_model = LLM(model=model_path, gpu_memory_utilization=0.95)
    start_t = time.time()
    failed_cases, correct_cases = nli_evaluation(trained_model, valid_data)
    del trained_model
    end_t = time.time()
    print('time', end_t - start_t)
    gc.collect()  # run garbage collection to free up memory
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.synchronize()
    time.sleep(10)  # give the GPU a moment to release memory before any subsequent model load
    return failed_cases, correct_cases


def extract_answer_prediction_nli(predicted_output):
    """Look for the predicted label in the sentence(s) containing 'final' (e.g. 'The final answer is ...')."""
    sens = predicted_output.split('.')
    final_sens = [sen for sen in sens if 'final' in sen]
    for sen in final_sens:
        answer = extract_answer(sen)
        if answer:
            return answer
    return None


def extract_answer(text):
    """Return the first NLI label mentioned in the text, or None."""
    if 'neutral' in text.lower():
        return 'neutral'
    if 'contradiction' in text.lower():
        return 'contradiction'
    if 'entailment' in text.lower():
        return 'entailment'
    return None


def process_batch(data_batch, trained_model, failed_cases, correct_cases):
    """Generate predictions for one batch and sort each example into failed_cases or correct_cases."""
    batch_prompts = [data['Input'] for data in data_batch]
    outputs = trained_model.generate(batch_prompts, sampling_params)
    labels = ['entailment', 'contradiction', 'neutral']
    for data, output in zip(data_batch, outputs):
        predicted_output = output.outputs[0].text
        predicted_res = extract_answer_prediction_nli(predicted_output)
        label = extract_answer(data['Output'].split('is')[-1])
        print(predicted_res, label, '\n')
        if not predicted_res:
            # No explicit 'final answer' sentence found; fall back to matching against the raw generation.
            predicted_res = predicted_output
        non_labels = [lbl for lbl in labels if lbl != label]
        # A prediction counts as correct only if it mentions the gold label and none of the other labels.
        if label not in predicted_res or any(non_label in predicted_res for non_label in non_labels):
            failed_cases.append((data['Input'], predicted_res, label, data))
        else:
            correct_cases.append((data['Input'], predicted_res, label, data))
    return failed_cases, correct_cases


def nli_evaluation(trained_model, valid_data):
    """Evaluate the model on valid_data in batches and accumulate failed/correct cases."""
    failed_cases = []
    correct_cases = []
    batch_size = 500
    batched_data = [valid_data[i:i + batch_size] for i in range(0, len(valid_data), batch_size)]
    for batch in batched_data:
        failed_cases, correct_cases = process_batch(batch, trained_model, failed_cases, correct_cases)
    return failed_cases, correct_cases
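

# --- Usage sketch (illustrative only) ---
# A minimal example of how valid_results_collect might be invoked. The model path
# and the validation record below are hypothetical placeholders; the 'Input'/'Output'
# fields follow the format this script expects (a prompt plus a gold answer whose
# final clause names the NLI label).
if __name__ == '__main__':
    example_valid_data = [
        {
            'Input': 'Premise: A man is playing a guitar. Hypothesis: A man is performing music. Label:',
            'Output': 'The final answer is entailment',
        },
    ]
    failed, correct = valid_results_collect(
        model_path='path/to/finetuned-model',  # hypothetical checkpoint directory
        valid_data=example_valid_data,
        task='nli',
    )
    print(f'correct: {len(correct)}, failed: {len(failed)}')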