Spaces:
Build error
Build error
Commit
·
9a4e478
1
Parent(s):
8de88bd
Minor fixes
Browse files- backend/query_llm.py +6 -7
- backend/semantic_search.py +20 -7
backend/query_llm.py
CHANGED
@@ -1,17 +1,16 @@
|
|
1 |
import os
|
2 |
-
|
3 |
-
import requests
|
4 |
from os import getenv
|
5 |
|
|
|
|
|
6 |
|
7 |
API_URL = getenv('API_URL')
|
8 |
BEARER = getenv('BEARER')
|
9 |
|
10 |
-
|
11 |
headers = {
|
12 |
-
|
13 |
-
|
14 |
-
}
|
15 |
|
16 |
|
17 |
def call_jais(payload):
|
@@ -26,7 +25,7 @@ def call_jais(payload):
|
|
26 |
|
27 |
|
28 |
def generate(prompt: str):
|
29 |
-
payload = {'inputs': '', 'prompt':prompt}
|
30 |
response = call_jais(payload)
|
31 |
return response
|
32 |
|
|
|
1 |
import os
|
|
|
|
|
2 |
from os import getenv
|
3 |
|
4 |
+
import gradio as gr
|
5 |
+
import requests
|
6 |
|
7 |
API_URL = getenv('API_URL')
|
8 |
BEARER = getenv('BEARER')
|
9 |
|
|
|
10 |
headers = {
|
11 |
+
"Authorization": f"Bearer {BEARER}",
|
12 |
+
"Content-Type": "application/json"
|
13 |
+
}
|
14 |
|
15 |
|
16 |
def call_jais(payload):
|
|
|
25 |
|
26 |
|
27 |
def generate(prompt: str):
|
28 |
+
payload = {'inputs': '', 'prompt': prompt}
|
29 |
response = call_jais(payload)
|
30 |
return response
|
31 |
|
backend/semantic_search.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
import logging
|
2 |
-
from pathlib import Path
|
3 |
import time
|
|
|
4 |
|
5 |
import lancedb
|
6 |
from sentence_transformers import SentenceTransformer
|
|
|
7 |
import spaces
|
8 |
|
9 |
|
@@ -17,7 +18,7 @@ start_time = time.perf_counter()
|
|
17 |
proj_dir = Path(__file__).parents[1]
|
18 |
|
19 |
# Log the time taken to load the QdrantDocumentStore
|
20 |
-
db = lancedb.connect(proj_dir/"lancedb")
|
21 |
tbl = db.open_table('arabic-wiki')
|
22 |
lancedb_loading_time = time.perf_counter() - start_time
|
23 |
logger.info(f"Time taken to load LanceDB: {lancedb_loading_time:.6f} seconds")
|
@@ -25,23 +26,35 @@ logger.info(f"Time taken to load LanceDB: {lancedb_loading_time:.6f} seconds")
|
|
25 |
# Start the timer for loading the EmbeddingRetriever
|
26 |
start_time = time.perf_counter()
|
27 |
|
28 |
-
name="sentence-transformers/paraphrase-multilingual-minilm-l12-v2"
|
29 |
-
|
|
|
|
|
30 |
|
31 |
# used for both training and querying
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
@spaces.GPU
|
33 |
def embed_func(query):
|
34 |
-
return
|
|
|
35 |
|
36 |
def vector_search(query_vector, top_k):
|
37 |
return tbl.search(query_vector).limit(top_k).to_list()
|
38 |
|
|
|
39 |
def retriever(query, top_k=3):
|
40 |
-
query_vector =
|
41 |
documents = vector_search(query_vector, top_k)
|
42 |
return documents
|
43 |
|
44 |
|
45 |
# Log the time taken to load the EmbeddingRetriever
|
46 |
retriever_loading_time = time.perf_counter() - start_time
|
47 |
-
logger.info(f"Time taken to load EmbeddingRetriever: {retriever_loading_time:.6f} seconds")
|
|
|
1 |
import logging
|
|
|
2 |
import time
|
3 |
+
from pathlib import Path
|
4 |
|
5 |
import lancedb
|
6 |
from sentence_transformers import SentenceTransformer
|
7 |
+
|
8 |
import spaces
|
9 |
|
10 |
|
|
|
18 |
proj_dir = Path(__file__).parents[1]
|
19 |
|
20 |
# Log the time taken to load the QdrantDocumentStore
|
21 |
+
db = lancedb.connect(proj_dir / "lancedb")
|
22 |
tbl = db.open_table('arabic-wiki')
|
23 |
lancedb_loading_time = time.perf_counter() - start_time
|
24 |
logger.info(f"Time taken to load LanceDB: {lancedb_loading_time:.6f} seconds")
|
|
|
26 |
# Start the timer for loading the EmbeddingRetriever
|
27 |
start_time = time.perf_counter()
|
28 |
|
29 |
+
name = "sentence-transformers/paraphrase-multilingual-minilm-l12-v2"
|
30 |
+
st_model_gpu = SentenceTransformer(name, device='mps')
|
31 |
+
st_model_cpu = SentenceTransformer(name, device='cpu')
|
32 |
+
|
33 |
|
34 |
# used for both training and querying
|
35 |
+
def call_embed_func(query):
|
36 |
+
try:
|
37 |
+
return embed_func(query)
|
38 |
+
except:
|
39 |
+
logger.warning(f'Using CPU')
|
40 |
+
return st_model_cpu.encode(query)
|
41 |
+
|
42 |
+
|
43 |
@spaces.GPU
|
44 |
def embed_func(query):
|
45 |
+
return st_model_gpu.encode(query)
|
46 |
+
|
47 |
|
48 |
def vector_search(query_vector, top_k):
|
49 |
return tbl.search(query_vector).limit(top_k).to_list()
|
50 |
|
51 |
+
|
52 |
def retriever(query, top_k=3):
|
53 |
+
query_vector = call_embed_func(query)
|
54 |
documents = vector_search(query_vector, top_k)
|
55 |
return documents
|
56 |
|
57 |
|
58 |
# Log the time taken to load the EmbeddingRetriever
|
59 |
retriever_loading_time = time.perf_counter() - start_time
|
60 |
+
logger.info(f"Time taken to load EmbeddingRetriever: {retriever_loading_time:.6f} seconds")
|