# File size: 3,729 Bytes
# bd3532f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
from typing import List, Dict, Union, Optional, Any
import numpy as np

from .embedding_provider import EmbeddingProvider
from .database.annoydb import AnnoyDB
from .keyword_search_provider import KeywordSearchProvider

class HybridSearch:
    """Combine semantic (vector) search and keyword search into one ranking.

    The documents are embedded with ``embedding_provider`` and added to an
    ``AnnoyDB`` vector index. The same documents are also handed to a
    ``KeywordSearchProvider``. ``hybrid_search`` queries both backends and
    merges their hits using a weighted sum of the backend scores.
    """

    def __init__(
        self,
        embedding_provider: EmbeddingProvider,
        documents: Optional[List[str]] = None,
        ann_filepath: Optional[str] = None,
        semantic_weight: float = 0.7,
        keyword_weight: float = 0.3
    ) -> None:
        """Embed ``documents`` and build both search backends.

        Args:
            embedding_provider: Used to embed the documents here and each
                query in ``hybrid_search``.
            documents: Texts to index. Required in practice; the ``None``
                default is kept only for signature compatibility.
            ann_filepath: Path to an existing Annoy index file.
                NOTE(review): currently does not load anything (see below).
            semantic_weight: Multiplier applied to semantic-search scores.
            keyword_weight: Multiplier applied to keyword-search scores.

        Raises:
            ValueError: If ``documents`` is ``None``, or if the provider
                returns a different number of embeddings than documents.
        """
        if documents is None:
            # Without this guard, construction fails later with an opaque
            # TypeError from zip(..., None).
            raise ValueError("documents must be provided")

        self.embedding_provider = embedding_provider
        self.documents = documents

        # NOTE(review): this stores the AnnoyDB *class*, not a loaded index,
        # and nothing in this class reads ``self.index``, so ``ann_filepath``
        # currently has no effect. The assignment is kept only so the
        # attribute still exists for external callers. TODO: load the saved
        # index once AnnoyDB's loading API is confirmed.
        if ann_filepath and os.path.exists(ann_filepath):
            self.index = AnnoyDB

        # Semantic search setup: one embedding row per document.
        self.embeddings = self.embedding_provider.embed_documents(documents)
        if self.embeddings.shape[0] != len(documents):
            # zip() would silently truncate and pair embeddings with the
            # wrong documents.
            raise ValueError(
                f"embedding provider returned {self.embeddings.shape[0]} "
                f"embeddings for {len(documents)} documents"
            )

        self.vector_db = AnnoyDB(
            embedding_dim=self.embeddings.shape[1]
        )
        for emb, doc in zip(self.embeddings, documents):
            self.vector_db.add_data(emb, doc)
        self.vector_db.build()

        # Keyword search setup
        self.keyword_search = KeywordSearchProvider(documents)

        # Weights for hybrid search. Unlike set_weights(), the constructor
        # does not validate them, so existing callers keep working.
        self.semantic_weight = semantic_weight
        self.keyword_weight = keyword_weight

    def hybrid_search(self, query: str, top_k: int = 5) -> List[Dict[str, Union[str, float]]]:
        """Search both backends for ``query`` and return the merged ranking.

        Each backend is asked for its own ``top_k`` hits.
        ``semantic_score`` and ``keyword_score`` are the backend scores
        already multiplied by their weight, or 0 when that backend did not
        return the document. ``hybrid_score`` is their sum.

        NOTE(review): raw backend scores are combined without normalization.
        The weighting and ordering will be skewed if the two backends use
        different scales, or if AnnoyDB's ``score`` is a distance (lower is
        better). Confirm against ``AnnoyDB.search`` and
        ``KeywordSearchProvider.search``.

        Args:
            query: Free-text query.
            top_k: Number of hits requested from each backend and the
                maximum number of merged results returned.

        Returns:
            Up to ``top_k`` dicts with the keys ``document``,
            ``semantic_score``, ``keyword_score`` and ``hybrid_score``,
            sorted by ``hybrid_score`` in descending order.
        """
        query_embedding = self.embedding_provider.embed_query(query)
        semantic_results = self.vector_db.search(query_embedding, top_k)
        keyword_results = self.keyword_search.search(query, top_k)

        # Keyed by document text. Insertion order (semantic hits first) is
        # preserved, so the stable sort below breaks ties in that order.
        combined_results: Dict[str, Dict[str, float]] = {}

        for result in semantic_results:
            weighted = result['score'] * self.semantic_weight
            combined_results[result['document']] = {
                'semantic_score': weighted,
                'keyword_score': 0,
                'hybrid_score': weighted,
            }

        for result in keyword_results:
            weighted = result['score'] * self.keyword_weight
            scores = combined_results.setdefault(
                result['document'],
                {'semantic_score': 0, 'keyword_score': 0, 'hybrid_score': 0},
            )
            scores['keyword_score'] = weighted
            scores['hybrid_score'] += weighted

        ranked = sorted(
            ({'document': doc, **scores} for doc, scores in combined_results.items()),
            key=lambda item: item['hybrid_score'],
            reverse=True,
        )
        return ranked[:top_k]

    def set_weights(self, semantic_weight: float, keyword_weight: float) -> None:
        """
        Dynamically update search weights.

        Args:
            semantic_weight: New weight for semantic search, in [0, 1].
            keyword_weight: New weight for keyword search, in [0, 1].

        Raises:
            ValueError: If either weight is outside [0, 1], or if the two
                weights do not sum to 1.0 (within ``np.isclose`` tolerance).
        """
        if not (0 <= semantic_weight <= 1 and 0 <= keyword_weight <= 1):
            raise ValueError("Weights must be between 0 and 1")

        # np.isclose tolerates float rounding, e.g. 0.7 + 0.3.
        if not np.isclose(semantic_weight + keyword_weight, 1.0):
            raise ValueError("Semantic and keyword weights must sum to 1.0")

        self.semantic_weight = semantic_weight
        self.keyword_weight = keyword_weight