ajanz commited on
Commit
efae79d
·
1 Parent(s): 4d716db
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -50,7 +50,7 @@ def predict(text: str = sample_text, top_k: int=3):
50
  # query = prepare_query(tokenizer, text)
51
  index_data, faiss_index = index
52
  # takes only the [CLS] embedding (for now)
53
- query = model(text)[0][0].numpy().reshape(1, -1)
54
 
55
  scores, indices = faiss_index.search(query, top_k)
56
  scores, indices = scores.tolist(), indices.tolist()
 
50
  # query = prepare_query(tokenizer, text)
51
  index_data, faiss_index = index
52
  # takes only the [CLS] embedding (for now)
53
+ query = model(text, return_tensors='pt')[0][0].numpy().reshape(1, -1)
54
 
55
  scores, indices = faiss_index.search(query, top_k)
56
  scores, indices = scores.tolist(), indices.tolist()