import dataclasses

import torch
from qdrant_client import QdrantClient, models

from config import qdrant_location, qdrant_api_key

# Shared client, created once at import time and reused by every search.
qdrant = QdrantClient(
    qdrant_location,
    api_key=qdrant_api_key,
    port=443,
    timeout=30,
)


def search_vector(
    query_vector: torch.Tensor, limit: int = 20
) -> list[models.ScoredPoint]:
    """Search the "kanji" collection with *query_vector*.

    Args:
        query_vector: Query embedding as a torch tensor. It is converted to
            a NumPy array before being handed to the client.
        limit: Value passed as the search ``limit`` (maximum number of hits
            requested).

    Returns:
        The scored points returned by ``qdrant.search``, with payloads
        included.
    """
    # Tensor.numpy() raises for tensors that require grad or that live on a
    # non-CPU device, so detach and move to CPU first. For a plain CPU tensor
    # both calls are no-ops and the result is the same as before.
    vector = query_vector.detach().cpu().numpy()
    return qdrant.search(
        collection_name="kanji",
        query_vector=vector,
        limit=limit,
        with_payload=True,
    )


@dataclasses.dataclass
class SearchResult:
    """One search hit: the payload's kanji and font plus the hit's score."""

    kanji: str
    font: str
    score: float


def format_search_results(hits: list[models.ScoredPoint]) -> list[SearchResult]:
    """Convert scored points into ``SearchResult`` records.

    Args:
        hits: Scored points whose payload holds ``"kanji"`` and ``"font"``
            entries.

    Returns:
        One ``SearchResult`` per hit, in the same order as *hits*.

    Raises:
        KeyError: If a payload lacks the ``"kanji"`` or ``"font"`` key.
    """
    return [
        SearchResult(
            kanji=point.payload["kanji"],
            font=point.payload["font"],
            score=point.score,
        )
        for point in hits
    ]