import logging
import copy
import math
import os
import json
import yaml
import time
import re
from typing import List, Dict

from factool.utils.base.pipeline import pipeline
from factool.code.helper.postprocess import PostProcessor
from factool.code.helper.execution import evaluate_test_cases_multi_solution
from factool.utils.utils_json import CustomJSONEncoder


class code_pipeline(pipeline):
    """FacTool pipeline for the code domain: verifies model-generated code by
    synthesizing test-case inputs and alternative solutions, executing them,
    and checking result consistency."""

    def __init__(self, foundation_model, multi_solution_cnt, testcases_input_cnt):
        super().__init__('code', foundation_model)

        self.multi_solution_cnt = multi_solution_cnt
        self.testcases_input_cnt = testcases_input_cnt

        # Load the code-domain prompts used for test-case and solution generation.
        with open(os.path.join(self.prompts_path, "query_generation.yaml"), 'r') as file:
            data = yaml.load(file, Loader=yaml.FullLoader)
        self.query_generation_prompt = data['code']

    async def _testcases_input_generation(self, batch, testcases_input_cnt):
        messages_list = []
        if self.company == 'openai':
            messages_list = [
                [
                    {"role": "system", "content": self.query_generation_prompt['system']},
                    {"role": "user",
                     "content": self.query_generation_prompt[
                         'user_testcases_' + str(testcases_input_cnt)
                     ].format(input_question=sample['prompt'],
                              entry_point=sample['entry_point'])},
                ]
                for sample in batch
            ]
        elif self.company == 'anthropic':
            messages_list = [
                self.query_generation_prompt[
                    'user_testcases_' + str(testcases_input_cnt)
                ].format(input_question=sample['prompt'],
                         entry_point=sample['entry_point'])
                for sample in batch
            ]
        return await self.chat.async_run(messages_list, Dict)

    async def _multi_solution_generation(self, batch, multi_solution_cnt):
        """Sample `multi_solution_cnt` alternative solutions for every sample in the batch."""
        bsize = 15  # number of chat requests sent per chunk
        messages_list = [
            [
                {"role": "system", "content": self.query_generation_prompt['system']},
                {"role": "user", "content": self.query_generation_prompt[
                    'user_solutions'].format(input_question=sample['prompt'],
                                             entry_point=sample['entry_point'])},
            ]
            for sample in batch
        ]

        # Duplicate each message list so that every sample is queried
        # `multi_solution_cnt` times.
        final_messages_list = [copy.deepcopy(messages)
                               for messages in messages_list
                               for _ in range(multi_solution_cnt)]

        responses = []
        for i in range(0, len(final_messages_list), bsize):
            chunk = final_messages_list[i:i + bsize]
            responses += await self.chat.async_run(chunk, Dict)

        # Regroup the flat response list into one group of solutions per sample.
        responses_split = [responses[i:i + multi_solution_cnt]
                           for i in range(0, len(responses),
                                          multi_solution_cnt)]

        multi_solutions = []
        for solutions in responses_split:
            key_names = [f"python_solution_{i}"
                         for i in range(1, multi_solution_cnt + 1)]
            new_element = {key: solutions[i]['python_solution']
                           if solutions[i] is not None else "None"
                           for i, key in enumerate(key_names)}
            multi_solutions.append(new_element)

        return multi_solutions

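    # Verification strategy: for each sample, the claimed completion is executed
    # alongside the independently generated solutions on every generated test-case
    # input. Because the claimed completion is appended last to `multi_solutions`,
    # the last element of each per-test-case result list corresponds to it; the
    # claim is kept factual only if that result is not a FAILURE and agrees with
    # the majority of the remaining (non-failing) solutions.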
    async def run_with_tool_live(self, batch, batch_size):
        testcases_input = await self._testcases_input_generation(batch, self.testcases_input_cnt)
        multi_solutions = await self._multi_solution_generation(batch, self.multi_solution_cnt)

        if testcases_input is None or multi_solutions is None:
            return None

        responses = []
        for i in range(batch_size):
            response = {'testcases_input': [],
                        'multi_solutions': [], 'with_tool_classification': "None"}
            try:
                response['testcases_input'] = list(testcases_input[i].values())

                # The claimed completion is appended last so it can be compared
                # against the generated solutions.
                response['multi_solutions'] \
                    = [multi_solutions[i][f'python_solution_{j}']
                       for j in range(1, self.multi_solution_cnt + 1)] + \
                      [batch[i]['completion']]
            except Exception:
                # Fall back to placeholders if generation failed for this sample.
                response['testcases_input'] = ["None"] * self.testcases_input_cnt
                response['multi_solutions'] = ["None"] * (self.multi_solution_cnt + 1)

            exec_result = evaluate_test_cases_multi_solution(
                batch[i]['prompt'], response['testcases_input'],
                response['multi_solutions'], timeout=0.1)
            response['exec_result'] = exec_result

            response['with_tool_classification'] = True

            for testcase_result in exec_result:
                if isinstance(testcase_result[-1], str) \
                        and testcase_result[-1].startswith('FAILURE'):
                    # The claimed solution itself failed on this test input.
                    response['with_tool_classification'] = False
                else:
                    # Drop failing generated solutions, then reject the claim unless
                    # its output matches most of the surviving solutions' outputs.
                    failure_indices = [
                        idx for idx, res in enumerate(testcase_result[:-1])
                        if isinstance(res, str) and res.startswith('FAILURE')]
                    testcase_result = [
                        res for idx, res in enumerate(testcase_result)
                        if idx not in failure_indices]

                    try:
                        if testcase_result[:-1].count(testcase_result[-1]) \
                                < math.ceil(len(testcase_result) / 2):
                            response['with_tool_classification'] = False
                    except Exception:
                        # Outputs that cannot be compared count as a failed check.
                        response['with_tool_classification'] = False

            responses.append(response)

        return responses

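    # API entry point: extracts the fenced code block from each model response,
    # runs the execution-based check in batches, and attaches the factuality
    # verdict and intermediate artifacts to every sample.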
    async def run_with_tool_api_call(self, prompts, responses, entry_points):
        # Extract the claimed code snippet from each response, preferring a
        # ```python fence, then a generic ``` fence, then the raw text.
        claims = []
        for response in responses:
            if "```python" in response:
                match = re.search(r"```python\n(.*?)\n```", response, re.DOTALL)
                claims.append(match.group(1) if match else "")
            elif "```" in response:
                match = re.search(r"```\n(.*?)\n```", response, re.DOTALL)
                claims.append(match.group(1) if match else "")
            else:
                claims.append(response)

        batch_size = 5
        num_batches = math.ceil(len(prompts) / batch_size)

        self.sample_list = [
            {"prompt": prompt, "response": response,
             "entry_point": entry_point, "completion": claim,
             "category": 'code'}
            for prompt, response, entry_point, claim
            in zip(prompts, responses, entry_points, claims)]

        for i in range(num_batches):
            print(i)
            batch_start = i * batch_size
            batch_end = min((i + 1) * batch_size, len(responses))

            responses_returned = await self.run_with_tool_live(
                self.sample_list[batch_start:batch_end], batch_end - batch_start)
            if responses_returned is None:
                # Generation failed for this batch; leave its samples unannotated.
                continue

            for j, response_returned in enumerate(responses_returned):
                index = batch_start + j
                self.sample_list[index].update({
                    'claim': self.sample_list[index]['completion'],
                    'testcases_queries': response_returned['testcases_input'],
                    'potential_solutions_queries': response_returned['multi_solutions'],
                    'exec_results': response_returned['exec_result'],
                    'claim_level_factuality': response_returned['with_tool_classification'],
                    'response_level_factuality': response_returned['with_tool_classification']
                })
                del self.sample_list[index]["completion"]

        return self.sample_list

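    # Dataset entry point: reads a JSONL dataset (or a previously classified one
    # when rerun=True), re-checks the selected samples, and writes the updated
    # records to `with_tool_classified_dataset_path`.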
    async def run_with_tool_dataset(self, annotated_dataset_path: str, with_tool_classified_dataset_path: str, rerun: bool = False, rerun_indices: list = []):
        data_path = with_tool_classified_dataset_path if rerun else annotated_dataset_path
        with open(data_path, 'r') as f:
            data = [json.loads(line) for line in f]
        self.sample_list = data
        rerun_elements = self.sample_list if not rerun else [self.sample_list[i] for i in rerun_indices]

        batch_size = 5
        num_batches = math.ceil(len(rerun_elements) / batch_size)

        for i in range(num_batches):
            print(i)
            batch_start = i * batch_size
            batch_end = min((i + 1) * batch_size, len(rerun_elements))

            responses = await self.run_with_tool_live(
                rerun_elements[batch_start:batch_end], batch_end - batch_start)
            if responses is None:
                # Generation failed for this batch; leave its samples unchanged.
                continue

            for j, response in enumerate(responses):
                index = batch_start + j if not rerun else rerun_indices[batch_start + j]
                self.sample_list[index]['with_tool_classification'] = response['with_tool_classification'] if response is not None else 'None'
                if response is not None:
                    self.sample_list[index].update({
                        'testcases_input': response['testcases_input'],
                        'multi_solutions': response['multi_solutions'],
                        'exec_result': response['exec_result']
                    })

        with open(with_tool_classified_dataset_path, 'w') as f:
            for item in self.sample_list:
                try:
                    json_str = json.dumps(item, cls=CustomJSONEncoder)
                except (TypeError, ValueError):
                    # Skip items that cannot be serialized even with the custom encoder.
                    continue
                f.write(json_str + '\n')

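    # Self-check baseline: instead of executing code, the model is asked to judge
    # its own solution directly (zero-shot or 3-shot chain-of-thought) and return
    # a dict with 'factuality' and 'reasoning' fields.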
    async def run_self_check_live(self, fewshot, batch):
        # Note: `self.self_check_prompt` is not loaded in this class's __init__;
        # it must be available before the self-check methods are called.
        user_prompt_key = 'user_3_shot_CoT' if fewshot else 'user_zero_shot_CoT'
        messages_list = [
            [
                {"role": "system", "content": self.self_check_prompt['system']},
                {"role": "user", "content": self.self_check_prompt[user_prompt_key].format(
                    input_question=response['prompt'], input_solution=response['completion'])},
            ]
            for response in batch
        ]
        return await self.chat.async_run(messages_list, Dict)

    async def run_self_check_dataset(self, annotated_dataset_path: str, self_check_classified_dataset_path: str, fewshot: bool = False, rerun: bool = False, rerun_indices: list = []):
        if not rerun:
            with open(annotated_dataset_path, 'r') as f:
                self.sample_list = [json.loads(line) for line in f]
            rerun_elements = self.sample_list
        else:
            with open(self_check_classified_dataset_path, 'r') as f:
                self.sample_list = [json.loads(line) for line in f]
            rerun_elements = [self.sample_list[i] for i in rerun_indices]

        batch_size = 5
        num_batches = math.ceil(len(rerun_elements) / batch_size)

        for i in range(num_batches):
            print(i)
            batch_start = i * batch_size
            batch_end = (i + 1) * batch_size
            batch = rerun_elements[batch_start:batch_end]

            responses = await self.run_self_check_live(fewshot, batch)
            for j, response in enumerate(responses):
                index = batch_start + j if not rerun else rerun_indices[batch_start + j]
                self.sample_list[index]['self_check_classification'] = response.get('factuality', 'None') if response is not None else 'None'
                self.sample_list[index]['self_check_reasoning'] = response.get('reasoning', 'None') if response is not None else 'None'

        with open(self_check_classified_dataset_path, 'w') as f:
            for item in self.sample_list:
                json_str = json.dumps(item)
                f.write(json_str + '\n')
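

if __name__ == "__main__":
    # Minimal usage sketch, illustrative only: the model name and counts below are
    # assumptions rather than values prescribed by this module, and running it
    # requires a configured foundation-model backend and API credentials.
    import asyncio

    demo_pipeline = code_pipeline(
        foundation_model="gpt-3.5-turbo",  # assumed model identifier
        multi_solution_cnt=3,
        testcases_input_cnt=3,
    )
    demo_results = asyncio.run(demo_pipeline.run_with_tool_api_call(
        prompts=['def add(a, b):\n    """Return the sum of a and b."""\n'],
        responses=["```python\ndef add(a, b):\n    return a + b\n```"],
        entry_points=["add"],
    ))
    print(demo_results)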