HEp2 / app.py
admin
sync ms
900716e
raw
history blame
2.78 kB
import functools
import os

import torch
import modelscope
import huggingface_hub
import gradio as gr
from PIL import Image
from torchvision.transforms import transforms
EN_US = os.getenv("LANG") != "zh_CN.UTF-8"
ZH2EN = {
"上传细胞图像": "Upload a cell picture",
"状态栏": "Status",
"图片名": "Picture name",
"识别结果": "Recognition result",
"请上传 PNG 格式的 HEp2 细胞图片": "It is recommended to upload HEp2 cell images in PNG format.",
}
def _L(zh_txt: str):
return ZH2EN[zh_txt] if EN_US else zh_txt
# Download the model snapshot (weights and example images) from the hub that
# matches the UI language: Hugging Face in English mode, ModelScope otherwise.
# The snapshot is cached under ./__pycache__ relative to the working directory.
if EN_US:
    MODEL_DIR = huggingface_hub.snapshot_download(
        "Genius-Society/HEp2", cache_dir="./__pycache__"
    )
else:
    MODEL_DIR = modelscope.snapshot_download(
        "Genius-Society/HEp2", cache_dir="./__pycache__"
    )
# English class label -> Chinese display name (HEp-2 staining patterns).
# Insertion order matters: CLASSES below is derived from it and is indexed by
# the model's predicted class index, so it presumably must match the label
# order used at training time — confirm before reordering keys.
TRANSLATE = {
    "Centromere": "着丝粒",
    "Golgi": "高尔基体",
    "Homogeneous": "同质",
    "NuMem": "核膜",  # nuclear membrane (was mistranslated as "记忆体", i.e. "memory")
    "Nucleolar": "核仁",
    "Speckled": "斑核",
}

# Class labels in model-output order.
CLASSES = list(TRANSLATE.keys())
# Deterministic inference preprocessing, built once and reused for every call.
# The previous pipeline also applied RandomAffine(5), a training-time
# augmentation that made predictions for the same image vary between calls.
_EMBED_TRANSFORM = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Maps each channel from [0, 1] to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)


def embeding(img_path: str) -> torch.Tensor:
    """Load the image at ``img_path`` and convert it to a model input tensor.

    The image is converted to RGB, resized so its shorter side is 224,
    center-cropped to 224x224, turned into a tensor and normalized.

    Args:
        img_path: Path to an image file readable by Pillow.

    Returns:
        The preprocessed image tensor (no batch dimension).
    """
    # Close the underlying file handle; convert() yields an independent image.
    with Image.open(img_path) as img:
        rgb = img.convert("RGB")
    return _EMBED_TRANSFORM(rgb)
@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the classifier from ``MODEL_DIR/save.pt`` once and reuse it.

    Exceptions are not cached, so a failed load is retried on the next call.
    NOTE(review): ``weights_only=False`` fully unpickles the file, which can run
    arbitrary code; this is only acceptable because save.pt comes from the
    pinned model repository, never from user uploads.
    """
    model = torch.load(
        f"{MODEL_DIR}/save.pt",
        map_location=torch.device("cpu"),
        weights_only=False,
    )
    # Assumes save.pt holds a pickled nn.Module (it is called directly in
    # infer) — confirm. eval() disables dropout and freezes BatchNorm stats.
    model.eval()
    return model


def infer(target: str):
    """Classify the HEp-2 cell image at ``target`` for the Gradio UI.

    Args:
        target: Filesystem path of the uploaded image (Gradio ``filepath``);
            may be empty/None when nothing was uploaded.

    Returns:
        ``(status, filename, result)``: ``status`` is "Success" or the text of
        the error that occurred; ``filename`` is the image's base name and
        ``result`` the predicted class (English label, or its Chinese name when
        not in English mode). Both are ``None`` if an error occurred first.
    """
    status = "Success"
    filename = result = None
    try:
        # Validate before touching the (expensive) model.
        if not target:
            raise ValueError(
                "Please upload a cell picture" if EN_US else "请上传细胞图片"
            )

        model = _load_model()
        batch = embeding(target).unsqueeze(0)  # add a batch dimension
        with torch.no_grad():
            logits = model(batch)
        class_index = int(logits.argmax(dim=1).item())

        filename = os.path.basename(target)
        label = CLASSES[class_index]
        result = label if EN_US else TRANSLATE[label]
    except Exception as e:
        # UI boundary: surface any failure in the status box instead of crashing.
        status = f"{e}"
    return status, filename, result
if __name__ == "__main__":
    # One bundled example image per class, named after its class label.
    example_imgs = [f"{MODEL_DIR}/examples/{label}.png" for label in CLASSES]

    # Components are created in the same order as before: input, then outputs.
    image_input = gr.Image(type="filepath", label=_L("上传细胞图像"))
    text_outputs = [
        gr.Textbox(label=_L(zh_label), show_copy_button=True)
        for zh_label in ("状态栏", "图片名", "识别结果")
    ]

    demo = gr.Interface(
        fn=infer,
        inputs=image_input,
        outputs=text_outputs,
        title=_L("请上传 PNG 格式的 HEp2 细胞图片"),
        examples=example_imgs,
        flagging_mode="never",
        cache_examples=False,
    )
    demo.launch(ssr_mode=False)