"""Gen cmat for de/en text.""" # pylint: disable=invalid-name, too-many-branches from typing import List, Optional import more_itertools as mit import numpy as np from tqdm import tqdm # from model_pool import load_model_s # from hf_model_s_cpu import model_s # load_model_s directly from st_mlbee.load_model_s import load_model_s # from logzero import logger from loguru import logger # from st_mlbee.cos_matrix2 import cos_matrix2 from .cos_matrix2 import cos_matrix2 _ = """ try: model_s = load_model_s() except Exception as exc: logger.erorr(exc) raise """ try: # model = model_s() # model = model_s(alive_bar_on=True) model = load_model_s() except Exception as _: logger.error(_) raise def gen_cmat( text1: List[str], text2: List[str], bsize: int = 50 ) -> np.ndarray: """Gen corr matrix for texts. Args: text1: typically '''...''' splitlines() text2: typically '''...''' splitlines() bsize: batch size, default 50 text1 = 'this is a test' text2 = 'another test' """ bsize = int(bsize) if bsize <= 0: bsize = 50 if isinstance(text1, str): text1 = [text1] if isinstance(text2, str): text1 = [text2] vec1 = [] vec2 = [] len1 = len(text1) len2 = len(text2) tot = len1 // bsize + bool(len1 % bsize) tot += len2 // bsize + bool(len2 % bsize) with tqdm(total=tot) as pbar: for chunk in mit.chunked(text1, bsize): try: vec = model.encode(chunk) except Exception as exc: logger.error(exc) raise vec1.extend(vec) pbar.update() for chunk in mit.chunked(text2, bsize): try: vec = model.encode(chunk) except Exception as exc: logger.error(exc) raise vec2.extend(vec) pbar.update() try: # note the order vec2, vec1 _ = cos_matrix2(np.array(vec2), np.array(vec1)) except Exception as exc: logger.exception(exc) raise return _