File size: 6,080 Bytes
eca534f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48d16d8
eca534f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48d16d8
eca534f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import os
import uuid
import tqdm
import json
import traceback
from typing import Callable

from openfactcheck.lib.logger import logger
from openfactcheck.core.base import OpenFactCheck
from openfactcheck.core.state import FactCheckerState

class ResponseEvaluator:
    """
    Evaluate an LLM response by running it through the OpenFactCheck solver
    pipeline, persisting each solver's intermediate output to a per-sample
    JSONL file and optionally reporting progress via a callback or a
    streaming generator.
    """

    def __init__(self, ofc: OpenFactCheck):
        """
        Initialize the ResponseEvaluator object.

        Parameters
        ----------
        ofc : OpenFactCheck
            Provides the solver pipeline (``ofc.pipeline``) and the output
            directory (``ofc.output_path``) used to persist results.
        """

        # Set the OpenFactCheck object
        self.ofc = ofc

    def persist_output(self, state: FactCheckerState, idx, solver_name, cont, sample_name=0):
        """
        Append one solver step's output to the sample's JSONL file.

        Parameters
        ----------
        state : FactCheckerState
            State produced by the solver; serialized via ``state.to_dict()``.
        idx : int
            Zero-based position of the solver in the pipeline.
        solver_name : str
            Name of the solver that produced the state.
        cont : bool
            Whether the pipeline continues after this solver.
        sample_name : str | int, optional
            Identifier used as the output file stem; may contain a sub-path.
        """
        # Coerce to str so the int default (0) and other non-str identifiers
        # do not break os.path.dirname()/os.path.join() below.
        sample_name = str(sample_name)

        result = {
            "idx": idx,
            "solver": solver_name,
            "continue": cont,
            "state": state.to_dict()
        }

        # Create the output path (sample_name may carry a directory prefix).
        # exist_ok=True avoids the check-then-create race of the previous
        # os.path.exists() guard.
        output_path = os.path.join(self.ofc.output_path, os.path.dirname(sample_name))
        os.makedirs(output_path, exist_ok=True)

        # Append the result as a single JSON line
        with open(os.path.join(self.ofc.output_path, f'{sample_name}.jsonl'), 'a', encoding="utf-8") as f:
            f.write(json.dumps(result, ensure_ascii=False) + '\n')

    def read_output(self, sample_name):
        """
        Read the output file for the given sample.

        Returns
        -------
        list[dict]
            One decoded JSON object per persisted solver step.
        """
        with open(os.path.join(self.ofc.output_path, f'{sample_name}.jsonl'), 'r', encoding="utf-8") as f:
            return [json.loads(line) for line in f]

    def remove_output(self, sample_name):
        """
        Remove the output file for the given sample.
        """
        os.remove(os.path.join(self.ofc.output_path, f'{sample_name}.jsonl'))

    def evaluate(self, response: str, question: str = None, callback: Callable = None, **kwargs):
        """
        Evaluate the response using the pipeline and return the final output.

        Parameters
        ----------
        response : str
            The response text to fact-check.
        question : str, optional
            The question the response answers, if any.
        callback : Callable, optional
            Invoked after each successful solver step with keyword arguments
            describing the step (index, sample_name, solver_name, input/output
            names and states, continue_run).
        **kwargs
            Forwarded to each solver; may include ``sample_name``.

        Returns
        -------
        The value stored under the last solver's output name in the final
        pipeline state (or the last good state's input name on error).
        """

        # Check if sample_name is provided in kwargs else generate a random one
        sample_name = kwargs.get("sample_name", str(uuid.uuid4()))

        # Initialize the state
        solver_output = FactCheckerState(question=question, response=response)

        # Initialize the output name (fallback if the pipeline is empty)
        output_name = "response"
        for idx, (name, (solver, input_name, output_name)) in tqdm.tqdm(enumerate(self.ofc.pipeline.items()),
                                                            total=len(self.ofc.pipeline)):
            logger.info(f"Invoking solver: {idx}-{name}")
            logger.debug(f"State content: {solver_output}")

            try:
                # Solver input is the output of the previous solver
                solver_input = solver_output

                # Run the solver
                cont, solver_output = solver(solver_input, **kwargs)

                # Persist the output
                logger.debug(f"Latest result: {solver_output}")
                if callback:
                    callback(
                        index=idx,
                        sample_name=sample_name,
                        solver_name=name,
                        input_name=input_name,
                        output_name=output_name,
                        input=solver_input.__dict__,
                        output=solver_output.__dict__,
                        continue_run=cont
                    )

                self.persist_output(solver_output, idx, name, cont, sample_name=sample_name)

            # Narrowed from a bare except so KeyboardInterrupt/SystemExit
            # still propagate; solver failures end the pipeline gracefully.
            except Exception:
                logger.error(f"Error at {traceback.format_exc()}")
                cont = False
                # Fall back to reading the last good state's field
                output_name = input_name

            # Break if the solver returns False
            if not cont:
                logger.info(f"Break at {name}")
                break

        return solver_output.get(output_name)

    def evaluate_streaming(self, response: str, question: str = None, **kwargs):
        """
        Evaluate the response using the pipeline and stream the output.

        Returns a generator that yields one dict per solver step with the
        step index, solver name, input/output names and states, and whether
        the pipeline continues.
        """

        def evaluate_response():
            # Check if sample_name is provided in kwargs else generate a random one
            sample_name = kwargs.get("sample_name", str(uuid.uuid4()))

            # Initialize the state
            solver_output = FactCheckerState(question=question, response=response)

            # Initialize the output name (fallback if the pipeline is empty)
            output_name = "response"
            for idx, (name, (solver, input_name, output_name)) in tqdm.tqdm(enumerate(self.ofc.pipeline.items()),
                                                                total=len(self.ofc.pipeline)):
                logger.info(f"Invoking solver: {idx}-{name}")
                logger.debug(f"State content: {solver_output}")

                try:
                    # Solver input is the output of the previous solver
                    solver_input = solver_output

                    # Run the solver
                    cont, solver_output = solver(solver_input, **kwargs)

                    # Persist the output
                    logger.debug(f"Latest result: {solver_output}")

                    # Stream the output
                    yield {
                        "index": idx,
                        "solver_name": name,
                        "input_name": input_name,
                        "output_name": output_name,
                        "input": solver_input.__dict__,
                        "output": solver_output.__dict__,
                        "continue_run": cont
                    }

                    self.persist_output(solver_output, idx, name, cont, sample_name=sample_name)

                # Narrowed from a bare except so KeyboardInterrupt/SystemExit
                # still propagate; solver failures end the pipeline gracefully.
                except Exception:
                    logger.error(f"Error at {traceback.format_exc()}")
                    cont = False
                    output_name = input_name

                # Break if the solver returns False
                if not cont:
                    logger.info(f"Break at {name}")
                    break

        # Return the generator; iteration drives the pipeline lazily
        return evaluate_response()