ppsingh commited on
Commit
71aaf00
1 Parent(s): 34798a4

Create retriever.py

Browse files
Files changed (1) hide show
  1. auditqa/retriever.py +57 -0
auditqa/retriever.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from qdrant_client.http import models as rest
2
+ from auditqa.process_chunks import getconfig
3
+ from langchain.retrievers import ContextualCompressionRetriever
4
+ from langchain.retrievers.document_compressors import CrossEncoderReranker
5
+ from langchain_community.cross_encoders import HuggingFaceCrossEncoder
6
+ import logging
7
+
8
+ model_config = getconfig("model_params.cfg")
9
+
10
+ def create_filter(reports:list = [],sources:str =None,
11
+ subtype:str =None,year:str =None):
12
+ if len(reports) == 0:
13
+ print("defining filter for:{}:{}:{}".format(sources,subtype,year))
14
+ filter=rest.Filter(
15
+ must=[rest.FieldCondition(
16
+ key="metadata.source",
17
+ match=rest.MatchValue(value=sources)
18
+ ),
19
+ rest.FieldCondition(
20
+ key="metadata.subtype",
21
+ match=rest.MatchValue(value=subtype)
22
+ ),
23
+ rest.FieldCondition(
24
+ key="metadata.year",
25
+ match=rest.MatchAny(any=year)
26
+ ),])
27
+ else:
28
+ print("defining filter for allreports:",reports)
29
+ filter=rest.Filter(
30
+ must=[
31
+ rest.FieldCondition(
32
+ key="metadata.filename",
33
+ match=rest.MatchAny(any=reports)
34
+ )])
35
+
36
+ return filter
37
+
38
+
39
+ def get_context(vectorstore,query,reports,sources,subtype,year):
40
+ # create metadata filter
41
+ filter = create_filter(reports=reports,sources=sources,subtype=subtype,year=year)
42
+
43
+ # getting context
44
+ retriever = vectorstore.as_retriever(search_type="similarity_score_threshold",
45
+ search_kwargs={"score_threshold": 0.6,
46
+ "k": int(model_config.get('retriever','TOP_K')),
47
+ "filter":filter})
48
+ # re-ranking the retrieved results
49
+ model = HuggingFaceCrossEncoder(model_name=model_config.get('ranker','MODEL'))
50
+ compressor = CrossEncoderReranker(model=model, top_n=int(model_config.get('ranker','TOP_K')))
51
+ compression_retriever = ContextualCompressionRetriever(
52
+ base_compressor=compressor, base_retriever=retriever
53
+ )
54
+ context_retrieved = compression_retriever.invoke(query)
55
+ print(f"retrieved paragraphs:{len(context_retrieved)}")
56
+
57
+ return context_retrieved