kanji_lookup / database.py
etrotta's picture
Change the vector database used and embed the embeddings within the program
63a1db6
raw
history blame
843 Bytes
import torch
import lancedb
from lancedb.pydantic import LanceModel
import pydantic
# import time
from config import lancedb_location
# Connect to the LanceDB database at the location configured in `config`.
# This runs at import time, so importing this module opens the database.
db = lancedb.connect(lancedb_location)
# Handle to the "kanji" table that `search_vector` queries.
table = db.open_table("kanji")
class SearchResult(LanceModel):
    """One match produced by a vector search over the kanji table."""

    kanji: str
    # Validation accepts the value under either key: 'distance' or
    # '_distance' (presumably the name used in raw search rows).
    distance: float = pydantic.Field(
        validation_alias=pydantic.AliasChoices('distance', '_distance'),
    )
def search_vector(query_vector: torch.Tensor, limit: int = 20) -> list[SearchResult]:
    """Search the kanji table for the entries closest to *query_vector*.

    Args:
        query_vector: Embedding to search with. It may live on any device
            and may be attached to an autograd graph; it is detached and
            copied to the CPU before being handed to LanceDB.
        limit: Maximum number of rows to request from the table.

    Returns:
        The rows returned by the search, each validated into a
        ``SearchResult``.
    """
    # A bare `.numpy()` raises for tensors that require grad or that are
    # stored on a non-CPU device; detach + cpu are no-ops for plain CPU
    # tensors, so the result is unchanged in the previously working case.
    query_array = query_vector.detach().cpu().numpy()

    rows = (
        table
        .search(query_array, vector_column_name="vector", query_type="vector")
        .limit(limit)
        .to_list()
    )
    return [SearchResult.model_validate(row) for row in rows]