File size: 3,010 Bytes
278c80b
 
 
 
 
 
 
 
 
 
 
cf8b6d5
278c80b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import torch
import random
import warnings
import gradio as gr
from PIL import Image
from model import Model
from torchvision import transforms
from modelscope import snapshot_download


# Location of the "MuGemSt/svhn" model snapshot, resolved once at import time.
# The rest of this file uses the value as a directory path: checkpoint files
# (*.pth) are looked up directly inside it and example images under
# "<MODEL_DIR>/examples".
# NOTE(review): presumably snapshot_download fetches the repository on first
# use and reuses the copy under ./__pycache__ afterwards - confirm against the
# modelscope documentation.
MODEL_DIR = snapshot_download("MuGemSt/svhn", cache_dir="./__pycache__")


def infer(input_img: str, checkpoint_file: str):
    """Recognize the digit sequence shown in an image.

    Args:
        input_img: Path of the image file to classify.
        checkpoint_file: File name of the checkpoint to load, relative to
            ``MODEL_DIR``.

    Returns:
        The recognized digits as a string (possibly empty). This function
        never raises: when an argument is missing a short hint is returned,
        and when anything fails the exception message is returned instead,
        so the caller always receives text to display.
    """
    # Guard clauses: give the user an actionable hint instead of the opaque
    # library error that a missing path / file name would otherwise produce.
    if not input_img:
        return "请先上传图片 Please upload an image first"

    if not checkpoint_file:
        return "请先选择权重 Please select a model first"

    try:
        model = Model()
        model.restore(f"{MODEL_DIR}/{checkpoint_file}")

        transform = transforms.Compose(
            [
                transforms.Resize([64, 64]),
                transforms.CenterCrop([54, 54]),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )

        # The context manager releases the underlying file handle as soon as
        # the pixels have been converted into a tensor.
        with Image.open(input_img) as image:
            # unsqueeze adds the leading batch dimension (batch of one image).
            batch = transform(image.convert("RGB")).unsqueeze(dim=0)

        with torch.no_grad():
            # First output is the sequence-length head, the remaining ones
            # are the per-position digit heads.
            length_logits, *digit_logits = model.eval()(batch)

        digits = [logits.argmax(dim=1).item() for logits in digit_logits]

        # The length head may predict more positions than there are digit
        # heads; clamp so that such a prediction yields all available digits
        # rather than an IndexError.
        length = min(length_logits.argmax(dim=1).item(), len(digits))

        return "".join(str(digit) for digit in digits[:length])

    except Exception as e:
        # Deliberate best-effort boundary: surface the error text to the UI.
        return f"{e}"


def get_files(dir_path=MODEL_DIR, ext=".pth"):
    """List the entries of a directory whose names end with a given suffix.

    Args:
        dir_path: Directory to scan; sub-directories are not descended into.
            Defaults to ``MODEL_DIR``.
        ext: Suffix the entry name must end with, e.g. ``".pth"``.

    Returns:
        The matching entry names (without the directory prefix), sorted
        alphabetically.

    Raises:
        OSError: Propagated from ``os.listdir`` when ``dir_path`` cannot be
            listed (for example, when it does not exist).
    """
    # os.listdir yields names in arbitrary order; sorting keeps the result,
    # and anything derived from it such as a default selection, stable
    # across runs and platforms.
    return sorted(name for name in os.listdir(dir_path) if name.endswith(ext))


if __name__ == "__main__":
    warnings.filterwarnings("ignore")

    models = get_files()
    if not models:
        # Fail early with an actionable message; otherwise an empty list
        # surfaces later as an opaque ValueError / IndexError.
        raise FileNotFoundError(f"No .pth checkpoint found in {MODEL_DIR}")

    examples_dir = f"{MODEL_DIR}/examples"
    # Each example row pairs one example image with a randomly picked
    # checkpoint, matching the order of the interface inputs below.
    samples = [
        [f"{examples_dir}/{img}", random.choice(models)]
        for img in get_files(examples_dir, ".png")
    ]

    gr.Interface(
        fn=infer,
        inputs=[
            gr.Image(label="上传图片 Upload an image", type="filepath"),
            gr.Dropdown(
                label="选择权重 Select a model",
                choices=models,
                value=models[0],
            ),
        ],
        outputs=gr.Textbox(label="识别结果 Recognition result", show_copy_button=True),
        examples=samples,
        allow_flagging="never",
        cache_examples=False,
    ).launch()