colonelwatch committed on
Commit
165ff3a
·
1 Parent(s): 3fa4157

Add zeroGPU

Browse files
Files changed (1) hide show
  1. app.py +13 -3
app.py CHANGED
@@ -14,11 +14,18 @@ from datasets import Dataset
14
  from datasets.search import FaissIndex
15
  import faiss
16
  from huggingface_hub import snapshot_download
 
 
17
  import gradio as gr
18
  import requests
19
  from sentence_transformers import SentenceTransformer
20
  import torch
21
 
 
 
 
 
 
22
 
23
  class IndexParameters(TypedDict):
24
  recall: float # in this case 10-recall@10
@@ -267,11 +274,14 @@ def main():
267
  print('warning: used "FP16" on CPU-only system, ignoring...', file=stderr)
268
  model.compile(mode="reduce-overhead")
269
 
 
 
 
 
 
270
  # function signature: (expanded tuple of input batches) -> tuple of output batches
271
  def search(query: list[str]) -> tuple[list[str]]:
272
- query_embedding = model.encode(
273
- query, prompt_name, normalize_embeddings=normalize
274
- )
275
  distances, faiss_ids = index.search_batch("embedding", query_embedding, k)
276
 
277
  faiss_ids_flat = list(chain(*faiss_ids))
 
14
  from datasets.search import FaissIndex
15
  import faiss
16
  from huggingface_hub import snapshot_download
17
+ import numpy as np
18
+ import numpy.typing as npt
19
  import gradio as gr
20
  import requests
21
  from sentence_transformers import SentenceTransformer
22
  import torch
23
 
24
+ try:
25
+ import spaces
26
+ except ImportError:
27
+ spaces = None
28
+
29
 
30
  class IndexParameters(TypedDict):
31
  recall: float # in this case 10-recall@10
 
274
  print('warning: used "FP16" on CPU-only system, ignoring...', file=stderr)
275
  model.compile(mode="reduce-overhead")
276
 
277
+ def encode(query: list[str]) -> npt.NDArray[np.float16 | np.float32]:
278
+ return model.encode(query, prompt_name, normalize_embeddings=normalize)
279
+ if spaces:
280
+ encode = spaces.GPU(encode)
281
+
282
  # function signature: (expanded tuple of input batches) -> tuple of output batches
283
  def search(query: list[str]) -> tuple[list[str]]:
284
+ query_embedding = encode(query)
 
 
285
  distances, faiss_ids = index.search_batch("embedding", query_embedding, k)
286
 
287
  faiss_ids_flat = list(chain(*faiss_ids))