import os
import random
import warnings
from functools import lru_cache

import gradio as gr
import torch
from PIL import Image
from model import Model
from torchvision import transforms
from modelscope import snapshot_download

# Download (or reuse the locally cached copy of) the model repository once,
# at import time.
MODEL_DIR = snapshot_download("MuGemSt/svhn", cache_dir="./__pycache__")

# Preprocessing applied to every input image. It is built once here instead of
# on every request.
_TRANSFORM = transforms.Compose(
    [
        transforms.Resize([64, 64]),
        transforms.CenterCrop([54, 54]),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)


@lru_cache(maxsize=8)
def _load_model(checkpoint_file: str) -> Model:
    """Build a Model, restore ``checkpoint_file`` from MODEL_DIR, and cache it.

    Caching avoids re-reading the checkpoint from disk on every request.
    A failed restore raises, and lru_cache does not cache exceptions.
    """
    model = Model()
    model.restore(f"{MODEL_DIR}/{checkpoint_file}")
    model.eval()
    return model


def infer(input_img: str, checkpoint_file: str) -> str:
    """Recognize the digit sequence in an image.

    Args:
        input_img: Path to the image file (the Gradio input uses
            ``type="filepath"``).
        checkpoint_file: Name of a ``.pth`` file located in MODEL_DIR.

    Returns:
        The predicted digits as a string. If anything fails, the text of the
        exception is returned instead, so it is shown in the output textbox.
    """
    try:
        if input_img is None:
            raise ValueError("No image provided")
        # Only accept checkpoints that actually ship in MODEL_DIR. Otherwise an
        # API caller could pass a path such as "../x.pth" and have it loaded.
        # NOTE(review): model.restore presumably deserializes the file
        # (torch.load / pickle) -- confirm before relaxing this check.
        if checkpoint_file not in get_files():
            raise ValueError(f"Unknown checkpoint: {checkpoint_file}")

        model = _load_model(checkpoint_file)

        # The context manager closes the underlying file handle.
        with Image.open(input_img) as img:
            image = _TRANSFORM(img.convert("RGB"))

        with torch.no_grad():
            # The first output is the length head; the rest are per-digit heads.
            length_logits, *digit_logits = model(image.unsqueeze(dim=0))

        length = length_logits.max(1)[1].item()
        digits = [logits.max(1)[1].item() for logits in digit_logits]
        # Slicing clamps a length prediction that exceeds the number of digit
        # heads (presumably a "more digits than supported" class). Indexing
        # would raise IndexError in that case.
        return "".join(str(d) for d in digits[:length])
    except Exception as e:  # UI boundary: surface the error text to the user
        return f"{e}"


def get_files(dir_path=MODEL_DIR, ext=".pth"):
    """Return the sorted names of entries in ``dir_path`` that end with ``ext``.

    Sorting makes the dropdown order and its default value deterministic;
    ``os.listdir`` order is arbitrary.
    """
    return sorted(name for name in os.listdir(dir_path) if name.endswith(ext))


if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    models = get_files()
    if not models:
        raise SystemExit(f"No .pth checkpoints found in {MODEL_DIR}")

    images = get_files(f"{MODEL_DIR}/examples", ".png")
    # Pair each example image with a randomly chosen checkpoint.
    samples = [
        [f"{MODEL_DIR}/examples/{img}", random.choice(models)] for img in images
    ]

    gr.Interface(
        fn=infer,
        inputs=[
            gr.Image(label="上传图片 Upload an image", type="filepath"),
            gr.Dropdown(
                label="选择权重 Select a model",
                choices=models,
                value=models[0],
            ),
        ],
        outputs=gr.Textbox(label="识别结果 Recognition result", show_copy_button=True),
        examples=samples,
        allow_flagging="never",
        cache_examples=False,
    ).launch()