Spaces:
Running
Running
import dataclasses | |
import torch | |
from qdrant_client import QdrantClient, models | |
from config import qdrant_location, qdrant_api_key | |
# Shared Qdrant client used by the search helpers in this module.
# The location and API key come from the project's `config` module.
qdrant = QdrantClient(
    location=qdrant_location,
    api_key=qdrant_api_key,
    timeout=30,  # presumably seconds, per qdrant-client's `timeout` parameter — confirm
    port=443,  # NOTE(review): presumably a hosted HTTPS endpoint — confirm against deployment
)
def search_vector(query_vector: torch.Tensor, limit: int = 20) -> list[models.ScoredPoint]:
    """Search the "kanji" Qdrant collection with ``query_vector``.

    Args:
        query_vector: Query embedding. Assumed to be a single 1-D vector whose
            length matches the collection's vector size — TODO confirm.
            It may require grad or live on any device; it is detached and
            copied to CPU before being sent.
        limit: Maximum number of hits to request.

    Returns:
        The scored points returned by ``qdrant.search``, with payloads included.
    """
    # `Tensor.numpy()` raises for tensors that require grad or are not on the
    # CPU. `.detach().cpu()` is a no-op for a plain CPU tensor, so the array
    # sent to Qdrant is unchanged in that case.
    vector = query_vector.detach().cpu().numpy()
    return qdrant.search(
        collection_name="kanji",
        query_vector=vector,
        limit=limit,
        with_payload=True,  # callers read the "kanji" / "font" payload fields
    )
@dataclasses.dataclass
class SearchResult:
    """One formatted search hit: the matched kanji, its font, and its Qdrant score.

    Declared as a dataclass so it can be built with keyword arguments
    (``SearchResult(kanji=..., font=..., score=...)``). It also gets a
    generated ``__eq__`` and ``__repr__``.
    """

    kanji: str  # value of the hit's "kanji" payload field
    font: str  # value of the hit's "font" payload field
    score: float  # score Qdrant reported for the hit
def format_search_results(hits: list[models.ScoredPoint]) -> list[SearchResult]:
    """Convert raw Qdrant hits into ``SearchResult`` records, preserving order.

    Args:
        hits: Scored points whose payloads contain "kanji" and "font" keys.

    Returns:
        One ``SearchResult`` per hit, in the same order as ``hits``.

    Raises:
        KeyError: If a hit's payload lacks "kanji" or "font".
    """
    return [
        SearchResult(
            kanji=point.payload["kanji"],
            font=point.payload["font"],
            score=point.score,
        )
        for point in hits
    ]