File size: 6,080 Bytes
eca534f 48d16d8 eca534f 48d16d8 eca534f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import os
import uuid
import tqdm
import json
import traceback
from typing import Callable
from openfactcheck.lib.logger import logger
from openfactcheck.core.base import OpenFactCheck
from openfactcheck.core.state import FactCheckerState
class ResponseEvaluator:
    """
    Run a response through an OpenFactCheck solver pipeline and persist every
    intermediate solver state to a per-sample JSONL file.
    """

    def __init__(self, ofc: OpenFactCheck):
        """
        Initialize the ResponseEvaluator object.

        Args:
            ofc: The OpenFactCheck instance whose ``pipeline`` is executed and
                whose ``output_path`` is the root directory for persisted
                solver outputs.
        """
        # Set the OpenFactCheck object
        self.ofc = ofc

    def _output_file(self, sample_name) -> str:
        """Return the path of the JSONL file holding the outputs of *sample_name*."""
        # The f-string converts non-string names (e.g. the default 0) to str.
        return os.path.join(self.ofc.output_path, f'{sample_name}.jsonl')

    def persist_output(self, state: FactCheckerState, idx, solver_name, cont, sample_name=0):
        """
        Persist the output of the solver.

        Appends one JSON line ``{"idx", "solver", "continue", "state"}`` to
        ``<output_path>/<sample_name>.jsonl``. Any intermediate directories
        implied by ``sample_name`` (e.g. ``"run1/sample"``) are created first.

        Args:
            state: Solver state; serialized via ``state.to_dict()``.
            idx: Position of the solver in the pipeline.
            solver_name: Name of the solver that produced ``state``.
            cont: Continue flag returned by the solver.
            sample_name: File name (without extension) for this sample.
                Non-string values such as the default ``0`` are formatted
                with ``str``.
        """
        result = {
            "idx": idx,
            "solver": solver_name,
            "continue": cont,
            "state": state.to_dict()
        }

        output_file = self._output_file(sample_name)
        # Create the parent directory. Deriving it from the already-formatted
        # file path avoids os.path.dirname(<int>) raising TypeError for the
        # default sample_name. exist_ok=True removes the exists()/makedirs()
        # race. An empty dirname (output_path == "" and no sub-directory in
        # sample_name) means the current directory, and os.makedirs("") would
        # raise, so it is skipped.
        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Append the record to the sample's file
        with open(output_file, 'a', encoding="utf-8") as f:
            f.write(json.dumps(result, ensure_ascii=False) + '\n')

    def read_output(self, sample_name):
        """
        Read the output file for the given sample.

        Returns:
            A list with one parsed JSON object per line of
            ``<output_path>/<sample_name>.jsonl``, in file order.
        """
        with open(self._output_file(sample_name), 'r', encoding="utf-8") as f:
            return [json.loads(line) for line in f]

    def remove_output(self, sample_name):
        """
        Remove the output file for the given sample.

        Raises:
            FileNotFoundError: If the sample's file does not exist.
        """
        os.remove(self._output_file(sample_name))

    def evaluate(self, response: str, question: str = None, callback: Callable = None, **kwargs):
        """
        Evaluate the response using the pipeline and return the output.

        Solvers in ``self.ofc.pipeline`` run in iteration order, and each
        receives the state returned by the previous one. After every
        successful solver call, ``callback`` (if given) is invoked and the new
        state is persisted with ``persist_output``. The run stops early when a
        solver returns a falsy continue flag or when a step raises an
        ``Exception``. The error is logged and the run stops.

        Args:
            response: The response text to fact-check.
            question: Optional question the response answers.
            callback: Optional callable invoked after each solver with keyword
                arguments ``index``, ``sample_name``, ``solver_name``,
                ``input_name``, ``output_name``, ``input``, ``output`` and
                ``continue_run``.
            **kwargs: Forwarded to every solver. If ``sample_name`` is present
                it names the persisted output file; otherwise a random UUID
                is used.

        Returns:
            ``solver_output.get(output_name)`` on the final state. Here
            ``output_name`` is that of the last solver that ran, or that
            solver's ``input_name`` if its step raised. With an empty
            pipeline, the ``"response"`` key is read.
        """
        # Check if sample_name is provided in kwargs else generate a random one
        sample_name = kwargs.get("sample_name", str(uuid.uuid4()))
        # Initialize the state
        solver_output = FactCheckerState(question=question, response=response)
        # Key read from the final state; only used as-is if the pipeline is empty
        output_name = "response"
        steps = enumerate(self.ofc.pipeline.items())
        for idx, (name, (solver, input_name, output_name)) in tqdm.tqdm(steps, total=len(self.ofc.pipeline)):
            logger.info(f"Invoking solver: {idx}-{name}")
            logger.debug(f"State content: {solver_output}")
            try:
                # Solver input is the output of the previous solver
                solver_input = solver_output
                # Run the solver
                cont, solver_output = solver(solver_input, **kwargs)
                logger.debug(f"Latest result: {solver_output}")
                if callback:
                    callback(
                        index=idx,
                        sample_name=sample_name,
                        solver_name=name,
                        input_name=input_name,
                        output_name=output_name,
                        input=solver_input.__dict__,
                        output=solver_output.__dict__,
                        continue_run=cont
                    )
                # Persist the output
                self.persist_output(solver_output, idx, name, cont, sample_name=sample_name)
            except Exception:
                # Catch Exception rather than everything, so KeyboardInterrupt
                # and SystemExit still propagate instead of being logged as a
                # solver failure.
                logger.error(f"Error at {traceback.format_exc()}")
                cont = False
                # This step's output may be missing, so fall back to the key
                # the solver read from.
                output_name = input_name
            # Break if the solver returns False
            if not cont:
                logger.info(f"Break at {name}")
                break
        return solver_output.get(output_name)

    def evaluate_streaming(self, response: str, question: str = None, **kwargs):
        """
        Evaluate the response using the pipeline and stream the output.

        The pipeline runs the same way as in ``evaluate``, but this method
        returns a generator instead of a value. For each successful solver
        step, the generator yields a dict with the keys ``index``,
        ``solver_name``, ``input_name``, ``output_name``, ``input``,
        ``output`` and ``continue_run``.

        Nothing runs until the generator is first iterated. Each step is
        persisted with ``persist_output`` when the consumer resumes the
        generator after that step's yield. Closing the generator early
        therefore skips persisting the step that was yielded last.

        Args:
            response: The response text to fact-check.
            question: Optional question the response answers.
            **kwargs: Forwarded to every solver. If ``sample_name`` is present
                it names the persisted output file; otherwise a random UUID
                is used.
        """
        def evaluate_response():
            # Check if sample_name is provided in kwargs else generate a random one
            sample_name = kwargs.get("sample_name", str(uuid.uuid4()))
            # Initialize the state
            solver_output = FactCheckerState(question=question, response=response)
            # Initialize the output name
            output_name = "response"
            steps = enumerate(self.ofc.pipeline.items())
            for idx, (name, (solver, input_name, output_name)) in tqdm.tqdm(steps, total=len(self.ofc.pipeline)):
                logger.info(f"Invoking solver: {idx}-{name}")
                logger.debug(f"State content: {solver_output}")
                try:
                    # Solver input is the output of the previous solver
                    solver_input = solver_output
                    # Run the solver
                    cont, solver_output = solver(solver_input, **kwargs)
                    logger.debug(f"Latest result: {solver_output}")
                    # Stream the output
                    yield {
                        "index": idx,
                        "solver_name": name,
                        "input_name": input_name,
                        "output_name": output_name,
                        "input": solver_input.__dict__,
                        "output": solver_output.__dict__,
                        "continue_run": cont
                    }
                    # Persist the output
                    self.persist_output(solver_output, idx, name, cont, sample_name=sample_name)
                except Exception:
                    # Must not be a bare `except:`. When the consumer closes
                    # the stream, GeneratorExit is raised at the `yield` above
                    # and has to propagate. It should not be logged as a
                    # solver error. KeyboardInterrupt must propagate as well.
                    logger.error(f"Error at {traceback.format_exc()}")
                    cont = False
                    output_name = input_name
                # Break if the solver returns False
                if not cont:
                    logger.info(f"Break at {name}")
                    break

        # Return the (not yet started) generator; the pipeline runs as it is iterated
        return evaluate_response()