import unittest
from unittest.mock import patch

import pandas as pd

import src.backend.model_operations as model_operations


class TestEvaluator(unittest.TestCase):
    """Unit tests for ``model_operations.EvaluationModel``.

    ``load_evaluation_model`` is patched in every test, so no real model is
    loaded. Where a test needs model output, it configures ``predict`` on the
    patched loader's ``return_value``.
    """

    def setUp(self):
        """Create the model path and the two-row input frame shared by the tests."""
        self.model_path = "test_model"
        self.summaries_df = pd.DataFrame({
            'source': ['source1', 'source2'],
            'summary': ['summary1', 'summary2'],
        })

    @patch("src.backend.model_operations.load_evaluation_model")
    def test_init(self, mock_load_evaluation_model):
        """Constructing the model calls the loader exactly once with the path."""
        model_operations.EvaluationModel(self.model_path)
        mock_load_evaluation_model.assert_called_once_with(self.model_path)

    @patch("src.backend.model_operations.load_evaluation_model")
    def test_evaluate_hallucination(self, mock_load_evaluation_model):
        """``evaluate_hallucination`` returns the scores the mocked ``predict`` yields."""
        model = model_operations.EvaluationModel(self.model_path)

        mock_load_evaluation_model.return_value.predict.return_value = [0.8, 0.2]

        scores = model.evaluate_hallucination(self.summaries_df)
        self.assertEqual(scores, [0.8, 0.2])

    @patch("src.backend.model_operations.load_evaluation_model")
    def test_evaluate_hallucination_exception(self, mock_load_evaluation_model):
        """``evaluate_hallucination`` raises when the mocked ``predict`` raises."""
        model = model_operations.EvaluationModel(self.model_path)

        mock_load_evaluation_model.return_value.predict.side_effect = Exception("Test exception")

        # Only the raise matters here, so the return value is not bound.
        with self.assertRaises(Exception):
            model.evaluate_hallucination(self.summaries_df)

    @patch("src.backend.model_operations.load_evaluation_model")
    def test_compute_accuracy(self, mock_load_evaluation_model):
        """``compute_accuracy`` reports 50.0 for the scores ``[0.8, 0.2]``."""
        model = model_operations.EvaluationModel(self.model_path)
        # Scores are injected directly so no prediction call is needed.
        # NOTE(review): presumably one of the two scores clears the model's
        # threshold, giving 50% -- confirm against compute_accuracy.
        model.scores = [0.8, 0.2]

        accuracy = model.compute_accuracy()
        expected_accuracy = 50.0
        # The accuracy is a computed float; avoid exact float equality.
        self.assertAlmostEqual(accuracy, expected_accuracy)


class TestLoadEvaluationModel(unittest.TestCase):
    """Unit tests for ``model_operations.load_evaluation_model``."""

    @patch("src.backend.model_operations.CrossEncoder")
    def test_load_evaluation_model(self, cross_encoder_cls):
        """The loader constructs ``CrossEncoder`` exactly once with the given path."""
        checkpoint = 'test_model_path'

        model_operations.load_evaluation_model(checkpoint)

        cross_encoder_cls.assert_called_once_with(checkpoint)


# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()