File size: 38,905 Bytes
b98ea00
 
 
 
 
 
 
 
 
 
 
 
 
6ce5fee
b98ea00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ce5fee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e33acd6
6ce5fee
 
bfa2d3c
6ce5fee
 
 
 
 
b98ea00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ce5fee
de31a83
b98ea00
6ce5fee
b98ea00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15aee67
 
 
 
 
 
 
 
 
 
 
b98ea00
 
 
15aee67
b98ea00
 
 
 
15aee67
 
 
b98ea00
15aee67
b98ea00
15aee67
b98ea00
15aee67
b98ea00
15aee67
 
 
 
b98ea00
 
15aee67
b98ea00
 
 
 
 
 
 
 
15aee67
 
b98ea00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
import os
import re
import json
import torch

import openai
import logging
import asyncio
import aiohttp
import pandas as pd
import numpy as np
import evaluate
import qdrant_client
from pypdf import PdfReader
from pydantic import BaseModel, Field
from typing import Any, List, Tuple, Set, Dict, Optional, Union
from sklearn.metrics.pairwise import cosine_similarity

from unstructured.partition.pdf import partition_pdf

import llama_index
from llama_index import PromptTemplate
from llama_index.retrievers import VectorIndexRetriever, BaseRetriever, BM25Retriever
from llama_index.query_engine import RetrieverQueryEngine
from llama_index import get_response_synthesizer
from llama_index.schema import NodeWithScore
from llama_index.query_engine import RetrieverQueryEngine
from llama_index import VectorStoreIndex, ServiceContext
from llama_index.embeddings import OpenAIEmbedding
from llama_index.llms import HuggingFaceLLM
import requests
from llama_index.llms import (
    CustomLLM,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.llms.base import llm_completion_callback
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.storage.storage_context import StorageContext
from llama_index.postprocessor import SentenceTransformerRerank, LLMRerank

from tempfile import NamedTemporaryFile
# Root logging setup for this module: INFO and above, one timestamped line per record.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Logger named after this module.
logger = logging.getLogger(__name__)

class ConfigManager:
    """
    A class to manage loading and accessing named sets of configuration settings.

    Several JSON files can be loaded side by side; each one is stored under the
    name passed to ``load_config`` and looked up again by that name.

    Attributes:
        configs (dict): Mapping of configuration name -> parsed JSON content.

    Methods:
        load_config(config_name: str, config_path: str): Loads a JSON file under a given name.
        get_config_value(config_name: str, key: str): Retrieves a specific configuration value.
    """

    def __init__(self):
        self.configs: Dict[str, Any] = {}

    def load_config(self, config_name: str, config_path: str) -> None:
        """
        Loads configuration settings from a specified JSON file into a named configuration.

        Loading again under an existing name replaces the previous content.

        Args:
            config_name (str): The name to assign to this set of configurations.
            config_path (str): The path to the configuration file (read as UTF-8).

        Raises:
            FileNotFoundError: If the config file is not found.
            json.JSONDecodeError: If there is an error parsing the config file.
        """
        try:
            # JSON text is UTF-8; don't depend on the platform's default locale encoding.
            with open(config_path, 'r', encoding='utf-8') as f:
                self.configs[config_name] = json.load(f)
        except FileNotFoundError:
            logger.error(f"Config file not found at {config_path}")
            raise
        except json.JSONDecodeError as e:
            logger.error(f"Error decoding config file: {e}")
            raise

    def get_config_value(self, config_name: str, key: str) -> Any:
        """
        Retrieves a specific configuration value.

        Args:
            config_name (str): The name the configuration was loaded under.
            key (str): The key for the configuration setting.

        Returns:
            Any: The value of the configuration setting, as parsed from JSON
            (string, number, list, dict, ...).

        Raises:
            ValueError: If the configuration or key is missing, the value is null,
                or it is still the "ENTER_YOUR_TOKEN_HERE" placeholder.
        """
        value = self.configs.get(config_name, {}).get(key)
        if value is None or value == "ENTER_YOUR_TOKEN_HERE":
            raise ValueError(f"Please set your '{key}' in the config.json file.")
        return value

class base_utils:
    """
    A utility class providing miscellaneous static methods for processing and analyzing text data,
    particularly from PDF documents and filenames. This class also includes methods for file operations.

    This class encapsulates the functionality of extracting key information from text, such as scores,
    reasoning, and IDs, locating specific data within a DataFrame based on an ID extracted from a filename,
    and reading content from files.

    Attributes:
        None (This class contains only static methods and does not maintain any state)

    Methods:
        extract_score_reasoning(text: str) -> Dict[str, Optional[str]]:
            Extracts a score and reasoning from a given text using regular expressions.

        extract_id_from_filename(filename: str) -> Optional[int]:
            Extracts an ID from a given filename based on a specified pattern.

        find_row_for_pdf(pdf_filename: str, dataframe: pd.DataFrame) -> Union[pd.DataFrame, str]:
            Searches for the rows of a DataFrame that match an ID extracted from a PDF filename.

        read_from_file(file_path: str) -> str:
            Reads the content of a file and returns it as a string.
    """

    @staticmethod
    def read_from_file(file_path: str) -> str:
        """
        Reads the content of a text file (UTF-8) and returns it as a string.

        Args:
            file_path (str): The path to the file to be read.

        Returns:
            str: The content of the file.
        """
        with open(file_path, 'r', encoding='utf-8') as prompt_file:
            return prompt_file.read()

    @staticmethod
    def extract_id_from_filename(filename: str) -> Optional[int]:
        """
        Extracts an ID from a filename, assuming a specific format ('Id_{I}.pdf', where {I} is the ID).

        The pattern is searched anywhere in ``filename``, so directory prefixes are fine.

        Args:
            filename (str): The filename from which to extract the ID.

        Returns:
            Optional[int]: The extracted ID as an integer, or None if the pattern is not found.
        """
        # The dot is escaped so it matches a literal '.', not any character.
        match = re.search(r'Id_(\d+)\.pdf', filename)
        return int(match.group(1)) if match else None

    @staticmethod
    def extract_score_reasoning(text: str) -> Dict[str, Optional[str]]:
        """
        Extracts score and reasoning from a given text using regular expressions.

        Args:
            text (str): The text from which to extract the score and reasoning.

        Returns:
            dict: ``{"score": str | None, "reasoning": str | None}``. The score is the
            digit string after "Score: "; the reasoning is everything after
            "Reasoning: " to the end of the text, stripped.
        """
        score_match = re.search(r"Score: (\d+)", text)
        # re.DOTALL allows '.' to match newlines, so multi-line reasoning is captured.
        reasoning_match = re.search(r"Reasoning: (.+)", text, re.DOTALL)

        return {
            "score": score_match.group(1) if score_match else None,
            "reasoning": reasoning_match.group(1).strip() if reasoning_match else None,
        }

    @staticmethod
    def find_row_for_pdf(pdf_filename: str, dataframe: pd.DataFrame) -> Union[pd.DataFrame, str]:
        """
        Finds the rows in a dataframe corresponding to the ID extracted from a given PDF filename.

        Args:
            pdf_filename (str): The filename of the PDF.
            dataframe (pandas.DataFrame): The dataframe to search; its first column holds the IDs.

        Returns:
            pandas.DataFrame or str: The matching rows (a DataFrame slice), or the message
            "No matching row found." / "Invalid file name.".
        """
        pdf_id = base_utils.extract_id_from_filename(pdf_filename)
        if pdf_id is None:
            return "Invalid file name."

        # Assuming the first column contains the ID
        matched_rows = dataframe[dataframe.iloc[:, 0] == pdf_id]
        if matched_rows.empty:
            return "No matching row found."
        return matched_rows
            
            
class PDFProcessor_Unstructured:
    """
    A class to process PDF files, providing functionalities for extracting, categorizing,
    and merging elements from a PDF file.

    This class is designed to handle unstructured PDF documents, particularly useful for
    tasks involving text extraction, categorization, and data processing within PDFs.

    Attributes:
        file_path (str): The full path to the PDF file currently being processed.
        folder_path (str): The directory path where the PDF file is located.
        file_name (str): The base name of the PDF file.
        texts (List[str]): Text chunks extracted from the current PDF.
        tables (List[str]): Tables extracted from the current PDF.
        config (Dict[str, Any]): Extraction options. Keys missing from a caller-supplied
            config fall back to ``default_config()``.

    Methods:
        extract_pdf_elements() -> List:
            Extracts images, tables, and text chunks from a PDF file.

        categorize_elements(raw_pdf_elements: List) -> None:
            Categorizes extracted elements from a PDF into tables and texts.

        merge_chunks() -> List[str]:
            Merges text chunks based on punctuation and character case criteria.

        should_skip_chunk(chunk: str) -> bool:
            Determines if a chunk should be skipped based on its content.

        should_merge_with_next(current_chunk: str, next_chunk: str) -> bool:
            Determines if the current chunk should be merged with the next one.

        extract_title_from_pdf(uploaded_file) -> str:
            Reads the document title from the PDF metadata.

        process_pdf() -> Tuple[List[str], List[str]]:
            Processes the PDF by extracting, categorizing, and merging elements.

        process_pdf_file(uploaded_file) -> Tuple[Tuple[List[str], List[str]], str]:
            Processes an uploaded PDF file and returns its chunks, tables and title.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        self.file_path = None
        self.folder_path = None
        self.file_name = None
        self.texts = []
        self.tables = []
        self.config = config if config is not None else self.default_config()
        logger.info(f"Initialized PdfProcessor_Unstructured for file: {self.file_name}")

    @staticmethod
    def default_config() -> Dict[str, Any]:
        """
        Returns the default configuration for PDF processing.

        Returns:
            Dict[str, Any]: Default configuration options.
        """
        return {
            "extract_images": False,
            "infer_table_structure": True,
            "chunking_strategy": "by_title",
            "max_characters": 10000,
            "combine_text_under_n_chars": 100,
            "strategy": "fast",
            "model_name": "yolox"
        }

    def extract_pdf_elements(self) -> List:
        """
        Extracts images, tables, and text chunks from a PDF file using ``partition_pdf``.

        Options are taken from ``self.config``, with ``default_config()`` filling any
        missing keys. NOTE: the "model_name" option is not forwarded to ``partition_pdf``.

        Returns:
            List: A list of extracted elements from the PDF.

        Raises:
            Exception: Whatever ``partition_pdf`` raises is logged and re-raised.
        """
        logger.info("Starting extraction of PDF elements.")
        # Caller-supplied options override the defaults; absent keys fall back to them.
        options = {**self.default_config(), **(self.config or {})}
        try:
            extracted_elements = partition_pdf(
                filename=self.file_path,
                extract_images_in_pdf=options["extract_images"],
                infer_table_structure=options["infer_table_structure"],
                chunking_strategy=options["chunking_strategy"],
                strategy=options["strategy"],
                max_characters=options["max_characters"],
                combine_text_under_n_chars=options["combine_text_under_n_chars"],
                image_output_dir_path=self.folder_path,
            )
            logger.info("Extraction of PDF elements completed successfully.")
            return extracted_elements
        except Exception as e:
            logger.error(f"Error extracting PDF elements: {e}", exc_info=True)
            raise

    def categorize_elements(self, raw_pdf_elements: List) -> None:
        """
        Categorizes extracted elements from a PDF into tables and texts.

        Elements whose type name contains "unstructured.documents.elements.Table" are
        appended to ``self.tables``; "CompositeElement" ones to ``self.texts``. All other
        element types are ignored.

        Args:
            raw_pdf_elements (List): A list of elements extracted from the PDF.
        """
        logger.debug("Starting categorization of PDF elements.")
        for element in raw_pdf_elements:
            element_type = str(type(element))
            if "unstructured.documents.elements.Table" in element_type:
                self.tables.append(str(element))
            elif "unstructured.documents.elements.CompositeElement" in element_type:
                self.texts.append(str(element))

        logger.debug("Categorization of PDF elements completed.")

    def merge_chunks(self) -> List[str]:
        """
        Merges text chunks based on punctuation and character case criteria.

        Chunks for which ``should_skip_chunk`` is True are dropped. When a kept chunk
        should be merged with its successor, the two are joined with a single space
        and the successor is consumed, so no input chunk is emitted twice.

        Returns:
            List[str]: A list of merged text chunks (empty when there are no texts).
        """
        logger.debug("Starting merging of text chunks.")

        chunks = self.texts
        merged_chunks = []
        i = 0
        while i < len(chunks):
            current_chunk = chunks[i]

            if self.should_skip_chunk(current_chunk):
                i += 1
                continue

            has_next = i + 1 < len(chunks)
            if has_next and self.should_merge_with_next(current_chunk, chunks[i + 1]):
                merged_chunks.append(current_chunk + " " + chunks[i + 1])
                i += 2  # the successor was consumed by the merge
            else:
                merged_chunks.append(current_chunk)
                i += 1

        logger.debug("Merging of text chunks completed.")

        return merged_chunks

    @staticmethod
    def should_skip_chunk(chunk: str) -> bool:
        """
        Determines if a chunk should be skipped based on its content.

        A chunk is skipped when it is empty, starts with "figure"/"fig"/"table"
        (case-insensitive), starts with a non-alphanumeric character, or starts with
        a number followed by a dot (e.g. "3.").

        Args:
            chunk (str): The text chunk to be evaluated.

        Returns:
            bool: True if the chunk should be skipped, False otherwise.
        """
        if not chunk:
            return True
        return bool(chunk.lower().startswith(("figure", "fig", "table")) or
                    not chunk[0].isalnum() or
                    re.match(r'^\d+\.', chunk))

    @staticmethod
    def should_merge_with_next(current_chunk: str, next_chunk: str) -> bool:
        """
        Determines if the current chunk should be merged with the next one.

        Merge when the current chunk ends with a comma, or when it ends with a
        lowercase letter and the next chunk starts with one.

        Args:
            current_chunk (str): The current text chunk.
            next_chunk (str): The next text chunk.

        Returns:
            bool: True if the chunks should be merged, False otherwise.
        """
        if current_chunk.endswith(","):
            return True
        # Empty chunks have no first/last character to compare.
        if not current_chunk or not next_chunk:
            return False
        return current_chunk[-1].islower() and next_chunk[0].islower()

    def extract_title_from_pdf(self, uploaded_file):
        """
        Extracts the title from a PDF file's metadata.

        This function reads the metadata of a PDF file using pypdf's ``PdfReader`` and
        attempts to extract the title. If the title is present in the metadata, it is
        returned. Otherwise, 'Title not found' is returned.

        Parameters:
        uploaded_file (file): A file object or a path to the PDF file from which
                          to extract the title. A file object must be opened in binary mode.

        Returns:
        str: The title of the PDF file as a string. If no title is found, returns
             'Title not found'.
        """
        pdf_reader = PdfReader(uploaded_file)
        meta = pdf_reader.metadata
        title = meta.title if meta and meta.title else 'Title not found'
        return title

    def process_pdf(self) -> Tuple[List[str], List[str]]:
        """
        Processes the PDF at ``self.file_path`` by extracting, categorizing, and merging elements.

        Results from any previously processed file are discarded first.

        Returns:
            Tuple[List[str], List[str]]: A tuple of merged text chunks and tables.
        """
        logger.info("Starting processing of the PDF.")
        # Start from a clean slate so chunks from an earlier file don't leak into this one.
        self.texts = []
        self.tables = []
        try:
            raw_pdf_elements = self.extract_pdf_elements()
            self.categorize_elements(raw_pdf_elements)
            merged_chunks = self.merge_chunks()
            return merged_chunks, self.tables
        except Exception as e:
            logger.error(f"Error processing PDF: {e}", exc_info=True)
            raise

    def process_pdf_file(self, uploaded_file):
        """
        Process an uploaded PDF file.

        If a different file was processed before, that previous file is deleted from
        disk (best effort; failures are logged). The method updates the file path,
        processes the PDF, and returns the results.

        Parameters:
        uploaded_file: The new PDF file uploaded for processing (anything whose
            ``str()`` is its path).

        Returns:
        Tuple[Tuple[List[str], List[str]], str]: ``((merged_texts, tables), title)``.
        """
        new_path = str(uploaded_file)

        # Delete the previous file if it exists -- but never the file we are about
        # to process (re-processing the same path would otherwise fail).
        if (self.file_path
                and os.path.abspath(self.file_path) != os.path.abspath(new_path)
                and os.path.exists(self.file_path)):
            try:
                os.remove(self.file_path)
                logger.debug(f"Previous file {self.file_path} deleted.")
            except OSError as e:
                logger.warning(f"Error deleting previous file: {e}", exc_info=True)

        # Process the new file
        self.file_path = new_path
        self.folder_path = os.path.dirname(self.file_path)
        self.file_name = os.path.basename(self.file_path)
        logger.info(f"Starting to process the PDF file: {self.file_path}")

        try:
            logger.debug(f"Processing PDF at {self.file_path}")
            results = self.process_pdf()
            title = self.extract_title_from_pdf(self.file_path)
            logger.info("PDF processing completed successfully.")
            return results, title
        except Exception as e:
            logger.error(f"Error processing PDF file: {e}", exc_info=True)
            raise


class HybridRetriever(BaseRetriever):
    """
    Retriever that fuses BM25 (keyword) and vector (embedding) search results.

    Both underlying retrievers receive the same query. Their hits are concatenated,
    BM25 first and then vector, and de-duplicated by ``node_id``. Each node therefore
    appears once, at the position of its first occurrence.

    Attributes:
        vector_retriever: Retriever backed by the vector index.
        bm25_retriever: Retriever backed by BM25 scoring.

    Methods:
        __init__(vector_retriever, bm25_retriever): Store both retrievers.
        _retrieve(query, **kwargs): Query both retrievers and merge their hits.
        _combine_results(bm25_nodes, vector_nodes): Order-preserving de-duplication by node_id.
    """

    def __init__(self, vector_retriever, bm25_retriever):
        super().__init__()
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
        logger.info("HybridRetriever initialized with vector and BM25 retrievers.")

    def _retrieve(self, query: str, **kwargs) -> List:
        """
        Run the query through BM25 and then the vector retriever and merge the hits.

        Args:
            query: The query, forwarded unchanged to both retrievers' ``retrieve``.
                (NOTE(review): when invoked via the base class this is presumably a
                llama_index QueryBundle rather than a plain str -- confirm.)
            **kwargs: Extra keyword arguments forwarded to both ``retrieve`` calls.

        Returns:
            List: Unique nodes, BM25 hits first.

        Raises:
            Exception: Any retrieval error is logged and re-raised.
        """
        logger.info(f"Retrieving documents for query: {query}")
        try:
            keyword_hits = self.bm25_retriever.retrieve(query, **kwargs)
            embedding_hits = self.vector_retriever.retrieve(query, **kwargs)
            unique_hits = self._combine_results(keyword_hits, embedding_hits)

            logger.info(f"Retrieved {len(unique_hits)} unique nodes combining vector and BM25 retrievers.")
            return unique_hits
        except Exception as e:
            logger.error(f"Error in retrieval: {e}")
            raise

    @staticmethod
    def _combine_results(bm25_nodes: List, vector_nodes: List) -> List:
        """
        Concatenate BM25 and vector hits, keeping only the first node seen per ``node_id``.

        Args:
            bm25_nodes: Nodes retrieved by the BM25 retriever.
            vector_nodes: Nodes retrieved by the vector retriever.

        Returns:
            List: Unique nodes in first-seen order.
        """
        seen_ids: Set = set()
        unique_nodes = []

        for source in (bm25_nodes, vector_nodes):
            for candidate in source:
                candidate_id = candidate.node_id
                if candidate_id in seen_ids:
                    continue
                seen_ids.add(candidate_id)
                unique_nodes.append(candidate)

        return unique_nodes

        

class PDFQueryEngine:
    """
    Builds a hybrid-retrieval query engine over a set of documents and scores them
    against a list of evaluation queries with an LLM.

    ``setup_query_engine`` works in four steps:
      1. It indexes ``documents`` into an in-memory Qdrant vector store.
      2. It combines vector and BM25 retrieval through ``HybridRetriever``.
      3. It reranks hits with ``SentenceTransformerRerank``.
      4. It synthesizes answers using ``qa_prompt_tmpl``.

    Attributes:
        documents (List): Documents to be indexed.
        llm: Language model used by the service context for response synthesis.
        embed_model: Embedding model used to build the vector index.
        qa_prompt_tmpl (str): Template string for the question-answering prompt.
        base_utils (base_utils): Helper used to parse "Score:"/"Reasoning:" answers.
        config_manager (ConfigManager): Loads auxiliary JSON configuration.

    Methods:
        format_example(example): Render one few-shot example as text.
        setup_query_engine(): Build and return the configured query engine.
        evaluate_with_llm(reg_result, peer_result, guidelines_result, queries):
            Run the queries and aggregate the scores.
    """

    def __init__(self, documents: List[Any], llm: Any, embed_model: Any, qa_prompt_tmpl: Any):
        self.documents = documents
        self.llm = llm
        self.embed_model = embed_model
        self.qa_prompt_tmpl = qa_prompt_tmpl
        self.base_utils = base_utils()
        self.config_manager = ConfigManager()

        logger.info("PDFQueryEngine initialized.")

    def format_example(self, example):
        """
        Formats a few-shot example into a string.

        Args:
            example (dict): A dictionary containing 'query', 'score', and 'reasoning' for the few-shot example.

        Returns:
            str: Formatted few-shot example text.
        """
        return "Example:\nQuery: {}\nScore: {}\nReasoning: {}\n".format(
            example['query'], example['score'], example['reasoning']
        )

    def setup_query_engine(self):
        """
        Sets up the query engine over ``self.documents``.

        Steps:
          * a service context from ``self.llm`` and ``self.embed_model``;
          * a VectorStoreIndex backed by an in-memory Qdrant collection "med_library";
          * a HybridRetriever combining the index's vector retriever (top 3) with a
            BM25 retriever (top 3) over the parsed nodes;
          * a "compact" response synthesizer using ``self.qa_prompt_tmpl``;
          * a SentenceTransformerRerank ("BAAI/bge-reranker-base", top 3) postprocessor.

        Returns:
            Any: The configured RetrieverQueryEngine.

        Raises:
            Exception: Any setup error is logged (with traceback) and re-raised.
        """
        client = qdrant_client.QdrantClient(
            # ":memory:" keeps the vector store in-process; no Qdrant server is needed
            # (requires qdrant-client >= 1.1.1). For a real instance use
            # uri="http://<host>:<port>" and, for Qdrant Cloud, api_key="<qdrant-api-key>".
            location=":memory:"
            )
        try:
            logger.info("Initializing the service context for query engine setup.")
            service_context = ServiceContext.from_defaults(llm=self.llm, embed_model=self.embed_model)
            vector_store = QdrantVectorStore(client=client, collection_name="med_library")
            storage_context = StorageContext.from_defaults(vector_store=vector_store)

            logger.info("Creating an index from documents.")
            index = VectorStoreIndex.from_documents(documents=self.documents, storage_context=storage_context, service_context=service_context)
            # BM25 needs the parsed nodes directly, so parse them separately.
            nodes = service_context.node_parser.get_nodes_from_documents(self.documents)

            logger.info("Setting up vector and BM25 retrievers.")
            vector_retriever = index.as_retriever(similarity_top_k=3)
            bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=3)
            hybrid_retriever = HybridRetriever(vector_retriever, bm25_retriever)

            logger.info("Configuring the response synthesizer with the prompt template.")
            qa_prompt = PromptTemplate(self.qa_prompt_tmpl)
            response_synthesizer = get_response_synthesizer(
                service_context=service_context,
                text_qa_template=qa_prompt,
                response_mode="compact",
            )

            logger.info("Assembling the query engine with reranker and synthesizer.")
            reranker = SentenceTransformerRerank(top_n=3, model="BAAI/bge-reranker-base")
            query_engine = RetrieverQueryEngine.from_args(
                retriever=hybrid_retriever,
                node_postprocessors=[reranker],
                response_synthesizer=response_synthesizer,
            )

            logger.info("Query engine setup complete.")
            return query_engine
        except Exception as e:
            logger.error(f"Error during query engine setup: {e}", exc_info=True)
            raise

    def evaluate_with_llm(self, reg_result: Any, peer_result: Any, guidelines_result: Any, queries: List[str]) -> Tuple[int, int, float, List[Optional[str]]]:
        """
        Evaluate documents using a language model based on various criteria.

        Some query indices are answered directly from external results instead of the LLM:
          * index 1 uses ``reg_result``, with ``reg_result[0]`` as the reasoning;
          * index 2 uses ``guidelines_result``;
          * index 8 uses ``guidelines_result`` or ``peer_result``.

        Each of these is scored 1 when the corresponding result is truthy. Every other
        query goes to the query engine, and its "Score:"/"Reasoning:" lines are parsed.

        Args:
            reg_result (Any): Result related to registration.
            peer_result (Any): Result related to peer review.
            guidelines_result (Any): Result related to following guidelines.
            queries (List[str]): A list of queries to be processed, in criterion order.

        Returns:
            Tuple[int, int, float, List[Optional[str]]]:
                ``(total_score, criteria_met, score_percentage, reasoning)``:
                  * total_score is the sum of per-query scores;
                  * criteria_met is the number of queries scored > 0;
                  * score_percentage is ``total_score / len(queries) * 100``, or 0.0 for
                    no queries;
                  * reasoning holds per-query reasoning in query order (None when an
                    answer had no "Reasoning:" line).

        Raises:
            FileNotFoundError / json.JSONDecodeError: If "few_shot.json" cannot be loaded.
        """
        logger.info("Starting evaluation with LLM.")
        # NOTE(review): the few-shot config is loaded but not otherwise used in this method.
        self.config_manager.load_config("few_shot", "few_shot.json")
        query_engine = self.setup_query_engine()

        total_score = 0
        criteria_met = 0
        reasoning = []

        for j, query in enumerate(queries):
            # Handle special cases based on the value of j and other conditions
            if j == 1 and reg_result:
                extracted_data = {"score": 1, "reasoning": reg_result[0]}
            elif j == 2 and guidelines_result:
                extracted_data = {"score": 1, "reasoning": "The article is published in a journal following EQUATOR-NETWORK reporting guidelines"}
            elif j == 8 and (guidelines_result or peer_result):
                extracted_data = {"score": 1, "reasoning": "The article is published in a peer-reviewed journal."}
            else:
                result = query_engine.query(query).response
                extracted_data = self.base_utils.extract_score_reasoning(result)

            # A missing score counts as 0.
            raw_score = extracted_data.get("score")
            extracted_data_score = 0 if raw_score is None else int(raw_score)
            if extracted_data_score > 0:
                criteria_met += 1
            reasoning.append(extracted_data["reasoning"])
            total_score += extracted_data_score

        # Guard the empty-query case, which would otherwise divide by zero.
        score_percentage = (float(total_score) / len(queries)) * 100 if queries else 0.0
        logger.info("Evaluation completed.")
        return total_score, criteria_met, score_percentage, reasoning
        


def _extract_mixtral_answer(full_txt: str) -> str:
    """
    Extract the answer portion from the text returned by the Hugging Face API.

    Parsing strategy:
        1. Find the first context-separator line ("---------------------").
        2. Return whatever follows the first "Answer:" after that separator.

    Args:
        full_txt (str): The ``generated_text`` value returned by the API.

    Returns:
        str: The stripped answer text. If the separator is missing, the whole text is
        searched. If "Answer:" is missing, the searched text is returned stripped
        instead of being silently truncated.
    """
    separator = "---------------------"
    answer_marker = "Answer:"

    sep_pos = full_txt.find(separator)
    # find() returns -1 when the separator is absent; slicing from -1 would keep only
    # the last character, so fall back to searching the whole text.
    section = full_txt[sep_pos:] if sep_pos >= 0 else full_txt

    marker_pos = section.find(answer_marker)
    if marker_pos < 0:
        # Without the marker, "-1 + 7" would chop six arbitrary characters off the text.
        return section.strip()
    return section[marker_pos + len(answer_marker):].strip()


class MixtralLLM(CustomLLM):
    """
    A custom language model class for interfacing with the Hugging Face Inference API, specifically using the Mixtral model.

    Attributes:
        context_window (int): Number of tokens used for context during inference.
        num_output (int): Number of tokens to generate as output.
        temperature (float): Sampling temperature for token generation.
        model_name (str): Name of the model on Hugging Face's model hub.
        api_key (str): API key for authenticating with the Hugging Face API.

    Methods:
        metadata: Retrieves metadata about the model.
        do_hf_call: Makes an API call to the Hugging Face model.
        complete: Generates a complete response for a given prompt.
        stream_complete: Streams the completion for a given prompt one character at a time.
    """
    context_window: int = Field(..., description="Number of tokens used for context during inference.")
    num_output: int = Field(..., description="Number of tokens to generate as output.")
    temperature: float = Field(..., description="Sampling temperature for token generation.")
    model_name: str = Field(..., description="Name of the model on Hugging Face's model hub.")
    api_key: str = Field(..., description="API key for authenticating with the Hugging Face API.")


    @property
    def metadata(self) -> LLMMetadata:
        """
        Retrieves metadata for the Mixtral LLM.

        Returns:
            LLMMetadata: An object containing the context window, number of output tokens, and model name.
        """
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )

    def do_hf_call(self, prompt: str, timeout: float = 120.0) -> str:
        """
        Makes an API call to the Hugging Face model and retrieves the generated answer.

        Args:
            prompt (str): The input prompt for the model.
            timeout (float): Seconds to wait for the HTTP request before giving up.

        Returns:
            str: The answer text extracted from the model's generated text. On an HTTP
            error status, a non-JSON body, or an unexpected payload shape, the fixed
            message "Unable to answer for technical reasons." is returned instead.

        Raises:
            requests.RequestException: If the request itself fails (connection error, timeout, ...).
        """
        fallback = "Unable to answer for technical reasons."

        parameters = {}
        # The Inference API's text-generation parameter is the lowercase "temperature"
        # (the former "Temperature" key is not a documented parameter). Non-positive
        # values are omitted because text-generation backends may reject them.
        if self.temperature > 0:
            parameters["temperature"] = self.temperature
        data = {
            "inputs": prompt,
            "parameters": parameters,
        }

        # Makes a POST request to the Hugging Face API to get the model's response.
        # The body is read as a single JSON document, so no streaming is requested.
        response = requests.post(
            f'https://api-inference.huggingface.co/models/{self.model_name}',
            headers={
                'authorization': f'Bearer {self.api_key}',
                'content-type': 'application/json',
            },
            json=data,
            timeout=timeout,
        )

        if response.status_code != 200:
            print(f"Error: {response}")
            return fallback

        # Parse the body once; requests raises a ValueError subclass on non-JSON bodies.
        try:
            payload = response.json()
        except ValueError:
            print(f"Error: {response} returned a non-JSON body")
            return fallback

        # A successful call yields a non-empty list whose first item holds "generated_text";
        # anything else (e.g. a dict carrying an "error" key) is treated as a failure.
        first = payload[0] if isinstance(payload, list) and payload else None
        if not isinstance(first, dict) or not isinstance(first.get("generated_text"), str):
            print(f"Error: {response} {payload}")
            return fallback

        return _extract_mixtral_answer(first["generated_text"])


    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        """
        Generates a complete response for a given prompt using the Hugging Face API.

        Args:
            prompt (str): The input prompt for the model.
            **kwargs: Additional keyword arguments for the completion (currently unused).

        Returns:
            CompletionResponse: The complete response from the model.
        """
        response = self.do_hf_call(prompt)
        return CompletionResponse(text=response)


    @llm_completion_callback()
    def stream_complete(
            self, prompt: str, **kwargs: Any
    ) -> CompletionResponseGen:
        """
        Streams the completion for the given prompt.

        The full answer is fetched with a single blocking API call and then replayed one
        character at a time, so this is a simulated stream rather than true token streaming.

        Args:
            prompt (str): The input prompt for the model.
            **kwargs: Additional keyword arguments for the streaming completion (currently unused).

        Yields:
            CompletionResponse: The accumulated text so far, with the newest character as ``delta``.
        """
        response = ""
        # Iterating a str yields individual characters.
        for token in self.do_hf_call(prompt):
            response += token
            yield CompletionResponse(text=response, delta=token)



class KeywordSearch():
    """
    Rule-based (non-LLM) keyword searches.

    Attributes:
        chunks (list of str): Text chunks searched by check_registration.

    Methods:
        find_journal_name(response, journal_list): Whether any listed journal name occurs in a string.
        check_registration(): Find trial registration numbers or registry URLs in the chunks.
    """

    # Regex patterns for different registration number types, keyed by registry.
    # They are compiled once at class creation rather than on every call.
    _REGISTRATION_PATTERNS = {
        "NCT": re.compile(r"\(?(NCT#?\s*(No\s*)?)(\d{8})\)?"),
        "ISRCTN": re.compile(r"(ISRCTN\d{8})"),
        "EudraCT": re.compile(r"(\d{4}-\d{6}-\d{2})"),
        "UMIN-CTR": re.compile(r"(UMIN\d{9})"),
        "CTRI": re.compile(r"(CTRI/\d{4}/\d{2}/\d{6})"),
    }

    # Registry domains, lowercase and matched case-insensitively as substrings.
    # "www." variants need no separate entry because each contains the bare domain.
    _REGISTRY_DOMAINS = (
        "anzctr.org.au",
        "clinicaltrials.gov",
        "isrctn.org",
        "isrctn.com",
        "umin.ac.jp/ctr",
        "onderzoekmetmensen.nl/en",
        "eudract.ema.europa.eu",
    )

    # Sentence boundary: one or more spaces preceded by ".", "!" or "?".
    _SENTENCE_SPLIT = re.compile(r'(?<=[.!?]) +')

    def __init__(self, chunks):
        self.chunks = chunks

    def find_journal_name(self, response: str, journal_list: list) -> bool:
        """
        Check whether any known journal name occurs in a response string.

        The comparison is a case-insensitive substring test.

        Args:
            response (str): The response string to search for a journal name.
            journal_list (list): Journal names to look for in the response.

        Returns:
            bool: True if at least one (non-blank) journal name occurs in the response, else False.
        """
        response_lower = response.lower()
        return any(
            journal.lower() in response_lower
            for journal in journal_list
            # A blank name is a substring of every response and would always match.
            if journal.strip()
        )

    def check_registration(self):
        """
        Check self.chunks for trial registration numbers or URLs of registries.

        Each chunk is split into sentences, and every sentence is tested against the
        registration-number patterns. The first sentence that matches is returned on its
        own. If no sentence matches, every chunk that mentions a registry domain
        (case-insensitively) is returned.

        Returns:
            list of str: A one-element list with the matching sentence, the chunks mentioning
            a registry URL, or an empty list if nothing is found.
        """
        for chunk in self.chunks:
            for sentence in self._SENTENCE_SPLIT.split(chunk):
                if any(pattern.search(sentence) for pattern in self._REGISTRATION_PATTERNS.values()):
                    return [sentence]  # Return immediately if a registration number is found

        # No registration number found: fall back to chunks that mention a registry URL.
        matching_chunks = []
        for chunk in self.chunks:
            chunk_lower = chunk.lower()
            if any(domain in chunk_lower for domain in self._REGISTRY_DOMAINS):
                matching_chunks.append(chunk)
        return matching_chunks

    

class StringExtraction():

    """
    A class to handle the extraction of answer strings from complete LLM responses.

    This class extracts the binary (Yes/No) answer and the reasoning text from LLM
    responses, and the ground truth from labelled data. Note that LLMs may format
    their answers differently depending on the model or the prompting technique. In
    such cases extract_original_prompt may not give satisfactory results, and writing
    a dedicated string extraction method is the best option.


    Methods:
        extract_original_prompt(result): Split an LLM response into its Yes/No line and its reasoning text.
        extraction_ground_truth(paper_name, labelled_data): Look up the labelled answers for one paper.
    """

    # Whole-word Yes/No, so words such as "Not", "None" or "Note" are not taken as an answer.
    _BINARY_PATTERN = re.compile(r"\b(?:Yes|No)\b")
    _REASONING_MARKER = "Reasoning:"
    _NO_INFO_EXPLANATION = "The article does not provide any relevant information."

    def extract_original_prompt(self, result):
        """
        Split an LLM response into a binary answer line and a reasoning explanation.

        Args:
            result: An object with a ``response`` string attribute (the raw LLM answer).

        Returns:
            tuple[str, str]: The first line containing the word "Yes" or "No" ("" if none),
            and the text following "Reasoning:" on every reasoning line, joined by spaces.
        """
        binary_response = ""
        explanation_parts = []
        for line in result.response.strip().split("\n"):
            if not binary_response and self._BINARY_PATTERN.search(line):
                binary_response = line
            elif self._REASONING_MARKER in line:
                # Cut at the "Reasoning:" marker itself, not at the first colon on the line.
                start = line.find(self._REASONING_MARKER) + len(self._REASONING_MARKER)
                part = line[start:].strip()
                if part:
                    explanation_parts.append(part)

        return binary_response, " ".join(explanation_parts)

    @staticmethod
    def _has_label(value):
        """Return True if a labelled cell holds non-blank content (NaN/None count as empty)."""
        import math

        if value is None:
            return False
        # pandas stores empty CSV cells as float NaN, for which len() would raise TypeError.
        if isinstance(value, float) and math.isnan(value):
            return False
        return len(str(value).strip()) > 0

    def extraction_ground_truth(self, paper_name, labelled_data):
        """
        Look up the ground-truth answers for one paper in the labelled data.

        The numeric paper id is the text between the first "_" and ".pdf" in paper_name
        (or the end of the name if ".pdf" is absent). The row whose "id" column equals
        that value is selected, and positional columns 2 through 10 are read as the
        labelled explanations.

        Args:
            paper_name (str): File name such as "paper_12.pdf".
            labelled_data (pandas.DataFrame): Table with an "id" column followed by label columns.

        Returns:
            tuple[list[str], list]: "Yes"/"No" per column (Yes when the cell is non-empty),
            and the cell text or a fixed "no information" sentence for empty cells.

        Raises:
            IndexError: If no row of labelled_data has the parsed id.
        """
        start = paper_name.find("_") + 1
        end = paper_name.find(".pdf")
        if end < 0:
            # Without ".pdf", find() returns -1 and the slice would drop the last id digit.
            end = len(paper_name)
        paper_id = int(paper_name[start:end])

        id_row = labelled_data[labelled_data["id"] == paper_id]
        if id_row.empty:
            raise IndexError(f"No labelled row with id {paper_id} for paper {paper_name!r}.")
        ground_truth = id_row.iloc[:, 2:11].values.tolist()[0]

        binary_ground_truth = []
        explanation_ground_truth = []
        for cell in ground_truth:
            if self._has_label(cell):
                binary_ground_truth.append("Yes")
                explanation_ground_truth.append(cell)
            else:
                binary_ground_truth.append("No")
                explanation_ground_truth.append(self._NO_INFO_EXPLANATION)
        return binary_ground_truth, explanation_ground_truth
    

  
class EvaluationMetrics():
    """
    This class encapsulates the evaluation methods that have been used in the project.

    Attributes:
        explanation_response: list of detailed responses from the LLM, one per query.
        explanation_ground_truth: list of ground-truth explanations, one per query.
        embedding_model: model whose ``encode`` method is called with a list of strings
            to obtain their embeddings.

    Methods:
        metric_cosine_similarity(): Cosine similarity between the i-th ground truth and the i-th response.
        metric_rouge(): ROUGE scores of the responses against the ground truth.
        binary_accuracy(binary_response, binary_ground_truth): Fraction of matching Yes/No answers.
    """


    def __init__(self, explanation_response, explanation_ground_truth, embedding_model):
        self.explanation_response = explanation_response
        self.explanation_ground_truth = explanation_ground_truth
        self.embedding_model = embedding_model

    def metric_cosine_similarity(self):
        """
        Compute the pairwise cosine similarity of ground truth i vs. response i.

        Both lists are embedded with ``embedding_model.encode``. The diagonal of the full
        similarity matrix is returned, so element i pairs the i-th ground truth with the
        i-th response (the lists are expected to have equal length).

        Returns:
            numpy.ndarray: One similarity value per paired element.
        """
        ground_truth_embedding = self.embedding_model.encode(self.explanation_ground_truth)
        explanation_response_embedding = self.embedding_model.encode(self.explanation_response)
        return np.diag(cosine_similarity(ground_truth_embedding, explanation_response_embedding))

    def metric_rouge(self):
        """
        Compute ROUGE scores of the LLM responses against the ground truth.

        The "rouge" metric is loaded through ``evaluate.load`` on every call.

        Returns:
            The result of the metric's ``compute`` (predictions = responses, references = ground truth).
        """
        rouge = evaluate.load("rouge")
        results = rouge.compute(predictions=self.explanation_response, references=self.explanation_ground_truth)
        return results

    def binary_accuracy(self, binary_response, binary_ground_truth):
        """
        Fraction of positions where the binary response equals the binary ground truth.

        Args:
            binary_response (sequence): Predicted answers, e.g. "Yes"/"No".
            binary_ground_truth (sequence): Expected answers, same length as binary_response.

        Returns:
            The accuracy rounded to 2 decimals (via ``np.round``), or the message
            "Arrays which are to be compared has different lengths." when the lengths differ.

        Raises:
            ZeroDivisionError: If both sequences are empty.
        """
        if len(binary_response) != len(binary_ground_truth):
            return "Arrays which are to be compared has different lengths."
        matches = sum(
            predicted == expected
            for predicted, expected in zip(binary_response, binary_ground_truth)
        )
        return np.round(matches / len(binary_response), 2)