File size: 974 Bytes
be12cc9
 
 
 
 
 
b6be18b
 
 
 
 
 
be12cc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import dataclasses
import torch
from qdrant_client import QdrantClient, models

from config import qdrant_location, qdrant_api_key

# Module-wide Qdrant client used by search_vector. It connects to the
# configured location on port 443, authenticates with the configured API key,
# and uses a request timeout of 30 (qdrant-client interprets this in seconds).
qdrant = QdrantClient(
    qdrant_location, api_key=qdrant_api_key, port=443, timeout=30
)

def search_vector(
    query_vector: torch.Tensor,
    limit: int = 20,
    collection_name: str = "kanji",
) -> list[models.ScoredPoint]:
    """Return the points nearest to ``query_vector`` in a Qdrant collection.

    Args:
        query_vector: Embedding to search with. Presumably a 1-D tensor whose
            length matches the collection's vector size (TODO: confirm). It
            may live on any device and may require grad, because it is
            detached and moved to CPU before conversion to NumPy.
        limit: Maximum number of hits to return.
        collection_name: Collection to query. Defaults to ``"kanji"``, the
            collection this module was written for.

    Returns:
        The hits returned by ``qdrant.search``, requested with their payloads.
    """
    # Tensor.numpy() refuses CUDA tensors and tensors that require grad.
    # Normalise to a detached CPU tensor first; for a plain CPU tensor both
    # calls return the same data without copying.
    vector = query_vector.detach().cpu().numpy()
    return qdrant.search(
        collection_name=collection_name,
        query_vector=vector,
        limit=limit,
        with_payload=True,
    )

@dataclasses.dataclass()
class SearchResult:
    """A single search hit reduced to the fields this module reports."""

    kanji: str  # value of the hit's "kanji" payload entry
    font: str  # value of the hit's "font" payload entry
    score: float  # score Qdrant attached to the hit

def format_search_results(hits: list[models.ScoredPoint]) -> list[SearchResult]:
    """Convert raw Qdrant hits into ``SearchResult`` records, preserving order.

    Each hit must carry a payload containing ``"kanji"`` and ``"font"`` keys;
    the hit's ``score`` is copied through unchanged.

    Raises:
        ValueError: if a hit has no payload (for example, it was fetched with
            ``with_payload=False``). Without this check the failure surfaces
            as an unhelpful ``TypeError`` from subscripting ``None``.
        KeyError: if a payload lacks ``"kanji"`` or ``"font"``.
    """
    results = []
    for point in hits:
        payload = point.payload
        if payload is None:
            raise ValueError(
                f"search hit {point.id!r} has no payload; "
                "query Qdrant with with_payload=True"
            )
        results.append(
            SearchResult(
                kanji=payload["kanji"],
                font=payload["font"],
                score=point.score,
            )
        )
    return results