File size: 4,399 Bytes
eca534f
8360ec7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eca534f
 
 
8360ec7
eca534f
 
 
 
 
 
 
 
8360ec7
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import sys
import os
import importlib

from openfactcheck.lib.logger import logger
from openfactcheck.core.state import FactCheckerState

# Global solver registry: Solver.register stores each decorated solver class
# here, keyed by the name it was registered under.
SOLVER_REGISTRY: dict = dict()

class StandardTaskSolver:
    """
    A class to represent a standard task solver. A standard task solver is a
    class that implements a specific task in a fact-checking system. It
    receives a FactCheckerState object as input and returns a
    ``(bool, FactCheckerState)`` tuple from ``__call__``.

    Parameters
    ----------
    args : dict
        A dictionary containing the arguments to be passed to the solver.

    Attributes
    ----------
    name : str or None
        Name the class was registered under via ``Solver.register``;
        ``None`` on classes that were never registered.
    input_name : str or None
        ``input_name`` passed to ``Solver.register``; ``None`` by default.
    output_name : str or None
        ``output_name`` passed to ``Solver.register``; ``None`` by default.
    global_config : dict
        A single class-level dict object. Unless a subclass rebinds it,
        every subclass and instance shares it, so in-place mutations are
        visible to all solvers.
    """

    # Plain class attributes: Solver.register overwrites them on the
    # registered subclass. They are deliberately NOT shadowed by properties
    # of the same name. Such a property would replace these defaults in the
    # class namespace, and on unregistered subclasses it would return the
    # property object itself instead of None.
    name: str = None
    input_name: str = None
    output_name: str = None
    global_config: dict = dict()

    def __init__(self, args: dict):
        self.logger = logger
        self.args = args

        logger.debug(self.args)

    def __call__(self, state: "FactCheckerState", **kwargs) -> "tuple[bool, FactCheckerState]":
        """
        Run the solver on *state*.

        Not implemented in the base class; subclasses must override it.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    @classmethod
    def build_solver(cls, args):
        """
        Class-level construction hook.

        Not implemented in the base class; subclasses may override it.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    def __str__(self):
        cls = self.__class__
        return f'[name:"{cls.name}", input: "{cls.input_name}", output: "{cls.output_name}"]'

class Solver:
    """
    Class to handle the registration and loading of solvers.

    ``register`` builds a class decorator that records solver classes in the
    module-level ``SOLVER_REGISTRY``. The ``load*`` methods import Python
    modules found on disk by dotted name; importing a module executes any
    ``Solver.register`` decorators it contains.
    """

    def __init__(self):
        pass

    @staticmethod
    def register(name, input_name=None, output_name=None):
        """
        Return a class decorator that registers a solver under *name*.

        Parameters
        ----------
        name : str
            Key under which the class is stored in ``SOLVER_REGISTRY``; also
            assigned to the class as ``cls.name``.
        input_name, output_name : str, optional
            Assigned to the decorated class as ``input_name`` / ``output_name``.

        Returns
        -------
        callable
            The decorator. If *name* is already registered, the decorator
            returns the previously registered class and does not store or
            modify the newly decorated one. Otherwise it stores the decorated
            class, sets its ``name``/``input_name``/``output_name`` and
            returns it.

        Raises
        ------
        ValueError
            Raised by the decorator when the decorated object is not a
            subclass of StandardTaskSolver.
        """
        def decorator(cls):
            # First registration wins. A repeated name (presumably the same
            # module being imported twice) yields the class already stored.
            if name in SOLVER_REGISTRY:
                logger.debug(f"Solver '{name}' already registered; returning the existing class")
                return SOLVER_REGISTRY[name]

            # isinstance guard: issubclass() itself raises TypeError for
            # non-class objects such as a decorated function.
            if not (isinstance(cls, type) and issubclass(cls, StandardTaskSolver)):
                got = getattr(cls, "__name__", type(cls).__name__)
                message = f"Solver '{name}' must extend StandardTaskSolver, got {got}."
                logger.error(message)
                raise ValueError(message)

            SOLVER_REGISTRY[name] = cls
            cls.name = name
            cls.input_name = input_name
            cls.output_name = output_name

            logger.info(f"Solver '{name}' registered")
            return cls

        return decorator

    @staticmethod
    def _qualify(namespace, name):
        """Join *namespace* and *name* with a dot; an empty namespace yields *name* alone."""
        return f"{namespace}.{name}" if namespace else name

    @staticmethod
    def load_from_directory(directory, namespace):
        """
        Recursively import every ``.py`` module under *directory*.

        Entries whose names start with ``_`` or ``.`` are skipped, which
        includes ``__init__.py`` and ``__pycache__``. At each level,
        sub-directories are visited before files. Within each group, entries
        are processed in name order, so import order does not depend on the
        arbitrary order returned by ``os.listdir``. Each sub-directory
        appends its own name to the dotted *namespace*.
        """
        entries = []
        for item in os.listdir(directory):
            if item.startswith(("_", ".")):
                continue
            full_path = os.path.join(directory, item)
            # False sorts before True, so directories come first.
            entries.append((not os.path.isdir(full_path), item, full_path))

        for is_file, item, full_path in sorted(entries):
            if is_file:
                Solver.load_from_file(full_path, namespace)
            else:
                Solver.load_from_directory(full_path, Solver._qualify(namespace, item))

    @staticmethod
    def load_from_file(file_path, namespace):
        """
        Import the module ``<namespace>.<file stem>`` for a ``.py`` file.

        The module is imported by dotted name through ``importlib``, not by
        path. *namespace* must therefore be importable from ``sys.path``;
        *file_path* only supplies the module's base name.

        Returns
        -------
        str or None
            The dotted module name that was imported, or ``None`` when
            *file_path* does not end in ``.py``. In that case nothing is
            imported.

        Raises
        ------
        Exception
            If the import fails. The original error is chained as
            ``__cause__``.
        """
        if not file_path.endswith(".py"):
            return None

        solver_name = os.path.basename(file_path)[:-3]
        module_name = Solver._qualify(namespace, solver_name)

        logger.debug(f"Attempting to import {module_name} from {file_path}")
        try:
            importlib.import_module(module_name)
        except Exception as e:
            logger.error(f"Failed to import {module_name}: {e}")
            raise Exception(f"Failed to import {module_name}: {e}") from e
        logger.debug(f"Successfully imported {module_name}")

        return module_name

    @staticmethod
    def load(path, namespace):
        """
        Import solver modules from *path* under the dotted *namespace*.

        A directory is walked recursively via ``load_from_directory``; any
        other path is handed to ``load_from_file``. Returns ``None``.
        """
        if os.path.isdir(path):
            Solver.load_from_directory(path, namespace)
        else:
            Solver.load_from_file(path, namespace)