Spaces:
Running
Running
File size: 3,014 Bytes
278c80b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import functools
import os
import random
import warnings

import gradio as gr
import torch
from modelscope import snapshot_download
from PIL import Image
from torchvision import transforms

from model import Model
# ModelScope repository that holds the .pth checkpoints and the examples/ images
# used by this demo, and the local directory the download is cached under.
_MODEL_REPO_ID = "MuGeminorum/svhn"
_MODEL_CACHE_DIR = "./__pycache__"

# Local path of the downloaded repository snapshot; every checkpoint and
# example path below is built relative to it.
MODEL_DIR = snapshot_download(_MODEL_REPO_ID, cache_dir=_MODEL_CACHE_DIR)
# Preprocessing applied to every input image: resize to 64x64, center-crop to
# 54x54, convert to a tensor, and normalize each RGB channel with mean 0.5 /
# std 0.5. Built once at import time instead of on every request.
_TRANSFORM = transforms.Compose(
    [
        transforms.Resize([64, 64]),
        transforms.CenterCrop([54, 54]),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)


@functools.lru_cache(maxsize=4)
def _load_model(checkpoint_file: str) -> Model:
    """Build a Model, restore ``MODEL_DIR/checkpoint_file`` into it, and cache it.

    Restoring weights is the expensive part of a request, so restored models
    are reused across calls (bounded to a few checkpoints to cap memory).
    The model is switched to eval mode once, here.
    """
    model = Model()
    model.restore(f"{MODEL_DIR}/{checkpoint_file}")
    model.eval()
    return model


def infer(input_img: str, checkpoint_file: str):
    """Recognize the digit sequence in an image with the chosen checkpoint.

    Args:
        input_img: Filesystem path of the image to classify (the Gradio image
            input is configured with ``type="filepath"``).
        checkpoint_file: File name of a checkpoint inside ``MODEL_DIR``.

    Returns:
        The predicted digits concatenated into a string. On any failure the
        exception's message is returned instead, so the UI shows it as text
        rather than erroring out.
    """
    try:
        model = _load_model(checkpoint_file)
        # Use a context manager so the underlying file handle is closed;
        # convert() returns a fully loaded copy that outlives the file.
        with Image.open(input_img) as img:
            image = _TRANSFORM(img.convert("RGB"))
        batch = image.unsqueeze(dim=0)  # add a batch dimension of size 1

        with torch.no_grad():
            # First output: sequence-length logits; the rest: one head per
            # digit position (five in the original code).
            length_logits, *digit_logits = model(batch)

        length = length_logits.max(dim=1).indices.item()
        digits = [logits.max(dim=1).indices.item() for logits in digit_logits]
        # Slicing (instead of indexing) caps the length at the number of
        # digit heads. If the length head can predict more than that (TODO
        # confirm against model.Model), the old code raised IndexError.
        # NOTE(review): a digit head may also have a "no digit" class
        # (presumably index 10); it is not filtered here — confirm.
        return "".join(str(d) for d in digits[:length])
    except Exception as e:  # UI boundary: surface the error text to the user
        return f"{e}"
def get_files(dir_path=MODEL_DIR, ext=".pth"):
    """Return the sorted names of entries in *dir_path* that end with *ext*.

    The result is sorted because ``os.listdir`` yields entries in arbitrary,
    filesystem-dependent order. Callers take the first element as a default
    (e.g. the dropdown's initial checkpoint), so sorting makes that choice
    deterministic.

    Args:
        dir_path: Directory to scan; defaults to the downloaded model directory.
        ext: Required name suffix, e.g. ".pth" or ".png".

    Returns:
        A list of bare entry names (not full paths), in sorted order.
    """
    return sorted(name for name in os.listdir(dir_path) if name.endswith(ext))
if __name__ == "__main__":
    # Suppress all Python warnings for this demo process.
    warnings.filterwarnings("ignore")

    models = get_files()
    if not models:
        # Fail fast with a clear message instead of an opaque IndexError /
        # ValueError from indexing or sampling an empty list below.
        raise SystemExit(f"No .pth checkpoints found in {MODEL_DIR}")

    images = get_files(f"{MODEL_DIR}/examples", ".png")
    # Pair every bundled example image with a randomly chosen checkpoint.
    samples = [
        [f"{MODEL_DIR}/examples/{img}", random.choice(models)] for img in images
    ]

    gr.Interface(
        fn=infer,
        inputs=[
            gr.Image(label="上传图片 Upload an image", type="filepath"),
            gr.Dropdown(
                label="选择权重 Select a model",
                choices=models,
                value=models[0],
            ),
        ],
        outputs=gr.Textbox(label="识别结果 Recognition result", show_copy_button=True),
        examples=samples,
        allow_flagging="never",
        cache_examples=False,
    ).launch()
|