Spaces:
Runtime error
Runtime error
bug fixes
Browse files
app.py
CHANGED
@@ -50,7 +50,7 @@ def predict(text: str = sample_text, top_k: int=3):
|
|
50 |
# query = prepare_query(tokenizer, text)
|
51 |
index_data, faiss_index = index
|
52 |
# takes only the [CLS] embedding (for now)
|
53 |
-
query = model(text)[0][0].numpy().reshape(1, -1)
|
54 |
|
55 |
scores, indices = faiss_index.search(query, top_k)
|
56 |
scores, indices = scores.tolist(), indices.tolist()
|
|
|
50 |
# query = prepare_query(tokenizer, text)
|
51 |
index_data, faiss_index = index
|
52 |
# takes only the [CLS] embedding (for now)
|
53 |
+
query = model(text, return_tensors='pt')[0][0].numpy().reshape(1, -1)
|
54 |
|
55 |
scores, indices = faiss_index.search(query, top_k)
|
56 |
scores, indices = scores.tolist(), indices.tolist()
|