Spaces: Running on Zero
Commit · 165ff3a
Parent(s): 3fa4157
Add zeroGPU
app.py CHANGED
@@ -14,11 +14,18 @@ from datasets import Dataset
 from datasets.search import FaissIndex
 import faiss
 from huggingface_hub import snapshot_download
+import numpy as np
+import numpy.typing as npt
 import gradio as gr
 import requests
 from sentence_transformers import SentenceTransformer
 import torch
 
+try:
+    import spaces
+except ImportError:
+    spaces = None
+
 
 class IndexParameters(TypedDict):
     recall: float  # in this case 10-recall@10
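The guarded import above lets the same app.py run both on a ZeroGPU Space, where the spaces package is available, and in an environment without it. A minimal sketch of the pattern; gpu_heavy is a hypothetical placeholder, not part of this commit:

try:
    import spaces              # available on Hugging Face ZeroGPU Spaces
except ImportError:
    spaces = None              # e.g. running locally without the package

def gpu_heavy(batch):
    ...                        # placeholder for work that actually needs a GPU

if spaces:
    # only request ZeroGPU scheduling when the decorator is available
    gpu_heavy = spaces.GPU(gpu_heavy)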
@@ -267,11 +274,14 @@ def main():
     print('warning: used "FP16" on CPU-only system, ignoring...', file=stderr)
     model.compile(mode="reduce-overhead")
 
+    def encode(query: list[str]) -> npt.NDArray[np.float16 | np.float32]:
+        return model.encode(query, prompt_name, normalize_embeddings=normalize)
+    if spaces:
+        encode = spaces.GPU(encode)
+
     # function signature: (expanded tuple of input batches) -> tuple of output batches
     def search(query: list[str]) -> tuple[list[str]]:
-        query_embedding = model.encode(
-            query, prompt_name, normalize_embeddings=normalize
-        )
+        query_embedding = encode(query)
         distances, faiss_ids = index.search_batch("embedding", query_embedding, k)
 
         faiss_ids_flat = list(chain(*faiss_ids))
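On ZeroGPU, spaces.GPU marks a function that needs the GPU: a device is attached only while that call runs, so factoring the embedding step out of search keeps the FAISS lookup outside the GPU-scoped call. Wrapping with spaces.GPU(encode) is equivalent to using it as a decorator; a minimal sketch of the decorator form (the model name and the optional duration argument are illustrative assumptions, not part of this commit):

import spaces
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative model choice

@spaces.GPU(duration=60)  # request a GPU for up to ~60 seconds per call
def encode(query: list[str]):
    # runs with the ZeroGPU device attached; returns a numpy array by default
    return model.encode(query, normalize_embeddings=True)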