Spaces:
Sleeping
Sleeping
SebastianSchramm
commited on
Commit
·
fca97ef
1
Parent(s):
053ffc5
use models
Browse files- data/paris-2024-faq.json +0 -0
- server.py +98 -3
data/paris-2024-faq.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
server.py
CHANGED
@@ -1,14 +1,32 @@
|
|
1 |
import logging
|
|
|
|
|
|
|
|
|
2 |
|
3 |
from fastapi import FastAPI
|
4 |
from pydantic import BaseModel
|
|
|
|
|
|
|
5 |
|
6 |
|
|
|
|
|
7 |
logging.basicConfig()
|
8 |
logger = logging.getLogger(__name__)
|
9 |
logger.setLevel(logging.INFO)
|
10 |
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
class InputLoad(BaseModel):
|
13 |
question: str
|
14 |
|
@@ -17,7 +35,31 @@ class ResponseLoad(BaseModel):
|
|
17 |
answer: str
|
18 |
|
19 |
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
|
23 |
@app.get("/health")
|
@@ -26,5 +68,58 @@ def health_check():
|
|
26 |
|
27 |
|
28 |
@app.post("/answer/")
|
29 |
-
async def receive(input_load: InputLoad) -> ResponseLoad:
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import logging
|
2 |
+
import json
|
3 |
+
from contextlib import asynccontextmanager
|
4 |
+
from typing import Any, List, Tuple
|
5 |
+
import random
|
6 |
|
7 |
from fastapi import FastAPI
|
8 |
from pydantic import BaseModel
|
9 |
+
from FlagEmbedding import BGEM3FlagModel, FlagReranker
|
10 |
+
from starlette.requests import Request
|
11 |
+
import torch
|
12 |
|
13 |
|
14 |
+
random.seed(42)
|
15 |
+
|
16 |
logging.basicConfig()
|
17 |
logger = logging.getLogger(__name__)
|
18 |
logger.setLevel(logging.INFO)
|
19 |
|
20 |
|
21 |
+
def get_data(model):
|
22 |
+
with open("data/paris-2024-faq.json") as f:
|
23 |
+
data = json.load(f)
|
24 |
+
data = [it for it in data if it['lang'] == 'en']
|
25 |
+
questions = [it['label'] for it in data]
|
26 |
+
q_embeddings = model[0].encode(questions, return_dense=False, return_sparse=False, return_colbert_vecs=True)
|
27 |
+
return q_embeddings['colbert_vecs'], questions, [it['body'] for it in data]
|
28 |
+
|
29 |
+
|
30 |
class InputLoad(BaseModel):
|
31 |
question: str
|
32 |
|
|
|
35 |
answer: str
|
36 |
|
37 |
|
38 |
+
class ML(BaseModel):
|
39 |
+
retriever: Any
|
40 |
+
ranker: Any
|
41 |
+
data: Tuple[List[Any], List[str], List[str]]
|
42 |
+
|
43 |
+
|
44 |
+
def load_models(app: FastAPI) -> FastAPI:
|
45 |
+
retriever=BGEM3FlagModel('BAAI/bge-m3', use_fp16=True) ,
|
46 |
+
ranker=FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True)
|
47 |
+
ml = ML(
|
48 |
+
retriever=retriever,
|
49 |
+
ranker=ranker,
|
50 |
+
data=get_data(retriever)
|
51 |
+
)
|
52 |
+
app.ml = ml
|
53 |
+
return app
|
54 |
+
|
55 |
+
|
56 |
+
@asynccontextmanager
|
57 |
+
async def lifespan(app: FastAPI):
|
58 |
+
app = load_models(app=app)
|
59 |
+
yield
|
60 |
+
|
61 |
+
|
62 |
+
app = FastAPI(lifespan=lifespan)
|
63 |
|
64 |
|
65 |
@app.get("/health")
|
|
|
68 |
|
69 |
|
70 |
@app.post("/answer/")
|
71 |
+
async def receive(input_load: InputLoad, request: Request) -> ResponseLoad:
|
72 |
+
ml: ML = request.app.ml
|
73 |
+
candidate_indices, candidate_scores = get_candidates(input_load.question, ml)
|
74 |
+
answer_candidate, rank_score, retriever_score = rerank_candidates(input_load.question, candidate_indices, candidate_scores, ml)
|
75 |
+
answer = get_final_answer(answer_candidate, retriever_score)
|
76 |
+
return ResponseLoad(answer=answer)
|
77 |
+
|
78 |
+
|
79 |
+
def get_candidates(question, ml, topk=5):
|
80 |
+
question_emb = ml.retriever[0].encode([question], return_dense=False, return_sparse=False, return_colbert_vecs=True)
|
81 |
+
question_emb = question_emb['colbert_vecs'][0]
|
82 |
+
scores = [ml.retriever[0].colbert_score(question_emb, faq_emb) for faq_emb in ml.data[0]]
|
83 |
+
scores_tensor = torch.stack(scores)
|
84 |
+
top_values, top_indices = torch.topk(scores_tensor, topk)
|
85 |
+
return top_indices.tolist(), top_values.tolist()
|
86 |
+
|
87 |
+
|
88 |
+
def rerank_candidates(question, indices, values, ml):
|
89 |
+
candidate_answers = [ml.data[2][_ind] for _ind in indices]
|
90 |
+
scores = ml.ranker.compute_score([[question, it] for it in candidate_answers])
|
91 |
+
rank_score = max(scores)
|
92 |
+
rank_ind = scores.index(rank_score)
|
93 |
+
retriever_score = values[rank_ind]
|
94 |
+
return candidate_answers[rank_ind], rank_score, retriever_score
|
95 |
+
|
96 |
+
|
97 |
+
def get_final_answer(answer, retriever_score):
|
98 |
+
logger.info(f"Retriever score: {retriever_score}")
|
99 |
+
if retriever_score < 0.65:
|
100 |
+
# nothing relevant found!
|
101 |
+
return random.sample(NOT_FOUND_ANSWERS, k=1)[0]
|
102 |
+
elif retriever_score < 0.8:
|
103 |
+
# might be relevant, but let's be careful
|
104 |
+
return f"{random.sample(ROUGH_MATCH_INTROS, k=1)[0]}\n{answer}"
|
105 |
+
else:
|
106 |
+
# good match
|
107 |
+
return f"{random.sample(GOOD_MATCH_INTROS, k=1)[0]}\n{answer}\n{random.sample(GOOD_MATCH_ENDS, k=1)[0]}"
|
108 |
+
|
109 |
+
|
110 |
+
NOT_FOUND_ANSWERS = [
|
111 |
+
"I'm sorry, but I couldn't find any information related to your question in my knowledge base.",
|
112 |
+
"Apologies, but I don't have the information you're looking for at the moment.",
|
113 |
+
"I’m sorry, I couldn’t locate any relevant details in my current data.",
|
114 |
+
"Unfortunately, I wasn't able to find an answer to your query. Can I help with something else?",
|
115 |
+
"I'm afraid I don't have the information you need right now. Please feel free to ask another question.",
|
116 |
+
"Sorry, I couldn't find anything that matches your question in my knowledge base.",
|
117 |
+
"I apologize, but I wasn't able to retrieve information related to your query.",
|
118 |
+
"I'm sorry, but it looks like I don't have an answer for that. Is there anything else I can assist with?",
|
119 |
+
"Regrettably, I couldn't find the information you requested. Can I help you with anything else?",
|
120 |
+
"I’m sorry, but I don't have the details you're seeking in my knowledge database."
|
121 |
+
]
|
122 |
+
|
123 |
+
GOOD_MATCH_INTROS = ["Super!"]
|
124 |
+
GOOD_MATCH_ENDS = ["Hopes this helps!"]
|
125 |
+
ROUGH_MATCH_INTROS = ["Not sure if that answers your question!"]
|