|
import os |
|
import uuid |
|
import tqdm |
|
import json |
|
import traceback |
|
from typing import Callable |
|
|
|
from openfactcheck.lib.logger import logger |
|
from openfactcheck.core.base import OpenFactCheck |
|
from openfactcheck.core.state import FactCheckerState |
|
|
|
class ResponseEvaluator:
    """Run a response through the solver pipeline held by an OpenFactCheck object.

    The solvers in ``ofc.pipeline`` are invoked in order; the state returned by
    one solver is the input of the next. After every successful step the
    resulting state is appended to ``<ofc.output_path>/<sample_name>.jsonl``.
    """

    def __init__(self, ofc: "OpenFactCheck"):
        """
        Initialize the ResponseEvaluator object.

        Args:
            ofc: Object exposing ``pipeline`` (a mapping of solver name to a
                ``(solver, input_name, output_name)`` tuple) and
                ``output_path`` (directory that receives the JSONL files).
        """
        self.ofc = ofc

    def _output_file(self, sample_name):
        """Return the path of the JSONL file that stores the steps of ``sample_name``."""
        return os.path.join(self.ofc.output_path, f"{sample_name}.jsonl")

    def persist_output(self, state: "FactCheckerState", idx, solver_name, cont, sample_name=0):
        """
        Persist the output of the solver

        Appends one JSON object per call to the sample's JSONL file, creating
        the parent directory when it does not exist yet.

        Args:
            state: Solver state; only its ``to_dict()`` method is used.
            idx: Position of the solver in the pipeline.
            solver_name: Name of the solver that produced ``state``.
            cont: Whether the pipeline continues after this solver.
            sample_name: Name of the sample; may contain sub-directories.
                Non-string values (such as the default ``0``) are formatted
                into the file name.
        """
        result = {
            "idx": idx,
            "solver": solver_name,
            "continue": cont,
            "state": state.to_dict(),
        }

        output_file = self._output_file(sample_name)

        # Derive the directory from the final file path so that non-string
        # sample names work, and skip creation when there is no parent at all.
        parent = os.path.dirname(output_file)
        if parent:
            # exist_ok avoids the check-then-create race between two writers.
            os.makedirs(parent, exist_ok=True)

        with open(output_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(result, ensure_ascii=False) + "\n")

    def read_output(self, sample_name):
        """
        Read the output file for the given sample

        Returns:
            A list with one decoded JSON object per non-blank line of the file.

        Raises:
            FileNotFoundError: If nothing was persisted for ``sample_name``.
        """
        with open(self._output_file(sample_name), "r", encoding="utf-8") as f:
            # Blank lines carry no record; skipping them keeps json.loads
            # from failing on a stray newline.
            return [json.loads(line) for line in f if line.strip()]

    def remove_output(self, sample_name):
        """
        Remove the output file for the given sample

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        os.remove(self._output_file(sample_name))

    def evaluate(self, response: str, question: str = None, callback: Callable = None, **kwargs):
        """
        Evaluate the response using the pipeline and return the output

        Args:
            response: Response text placed in the initial state.
            question: Optional question placed in the initial state.
            callback: Optional callable invoked after every solver with the
                keyword arguments ``index``, ``sample_name``, ``solver_name``,
                ``input_name``, ``output_name``, ``input``, ``output`` and
                ``continue_run``.
            **kwargs: Forwarded unchanged to every solver. ``sample_name``
                selects the output file; a random UUID is used when absent.

        Returns:
            The value stored in the final state under the output name of the
            last solver that ran, or under that solver's input name when the
            step raised an exception.
        """
        sample_name = kwargs.get("sample_name", str(uuid.uuid4()))

        solver_output = FactCheckerState(question=question, response=response)

        # Used only when the pipeline is empty; the loop rebinds it otherwise.
        output_name = "response"
        pipeline = self.ofc.pipeline
        for idx, (name, (solver, input_name, output_name)) in tqdm.tqdm(enumerate(pipeline.items()),
                                                                        total=len(pipeline)):
            logger.info(f"Invoking solver: {idx}-{name}")
            logger.debug(f"State content: {solver_output}")

            try:
                solver_input = solver_output

                cont, solver_output = solver(solver_input, **kwargs)

                logger.debug(f"Latest result: {solver_output}")
                if callback:
                    callback(
                        index=idx,
                        sample_name=sample_name,
                        solver_name=name,
                        input_name=input_name,
                        output_name=output_name,
                        input=solver_input.__dict__,
                        output=solver_output.__dict__,
                        continue_run=cont,
                    )

                self.persist_output(solver_output, idx, name, cont, sample_name=sample_name)

            except Exception:
                # Best effort: a failing step stops the pipeline instead of
                # raising. KeyboardInterrupt/SystemExit are deliberately not
                # caught so the process can still be interrupted.
                logger.error(f"Error at {traceback.format_exc()}")
                cont = False
                # Fall back to the field the failing solver was meant to read.
                output_name = input_name

            if not cont:
                logger.info(f"Break at {name}")
                break

        return solver_output.get(output_name)

    def evaluate_streaming(self, response: str, question: str = None, **kwargs):
        """
        Evaluate the response using the pipeline and stream the output

        Args:
            response: Response text placed in the initial state.
            question: Optional question placed in the initial state.
            **kwargs: Forwarded unchanged to every solver. ``sample_name``
                selects the output file; a random UUID is used when absent.

        Returns:
            A generator yielding one dict per executed solver with the keys
            ``index``, ``solver_name``, ``input_name``, ``output_name``,
            ``input``, ``output`` and ``continue_run``. Nothing runs until
            the generator is iterated.
        """

        def evaluate_response():
            sample_name = kwargs.get("sample_name", str(uuid.uuid4()))

            solver_output = FactCheckerState(question=question, response=response)

            pipeline = self.ofc.pipeline
            for idx, (name, (solver, input_name, output_name)) in tqdm.tqdm(enumerate(pipeline.items()),
                                                                            total=len(pipeline)):
                logger.info(f"Invoking solver: {idx}-{name}")
                logger.debug(f"State content: {solver_output}")

                try:
                    solver_input = solver_output

                    cont, solver_output = solver(solver_input, **kwargs)

                    logger.debug(f"Latest result: {solver_output}")

                    yield {
                        "index": idx,
                        "solver_name": name,
                        "input_name": input_name,
                        "output_name": output_name,
                        "input": solver_input.__dict__,
                        "output": solver_output.__dict__,
                        "continue_run": cont,
                    }

                    # Runs when the consumer asks for the next item.
                    self.persist_output(solver_output, idx, name, cont, sample_name=sample_name)

                except Exception:
                    # GeneratorExit is not an Exception subclass, so a
                    # consumer that stops iterating early no longer gets
                    # reported as a solver error.
                    logger.error(f"Error at {traceback.format_exc()}")
                    cont = False

                if not cont:
                    logger.info(f"Break at {name}")
                    break

        return evaluate_response()