Spaces:
Running
Running
import os | |
import torch | |
import modelscope | |
import huggingface_hub | |
import gradio as gr | |
from PIL import Image | |
from torchvision.transforms import transforms | |
# The UI is shown in English unless the process locale is exactly zh_CN.UTF-8.
EN_US = os.environ.get("LANG") != "zh_CN.UTF-8"

# Chinese UI strings (used as lookup keys) mapped to their English counterparts.
ZH2EN = {
    "上传细胞图像": "Upload a cell picture",
    "状态栏": "Status",
    "图片名": "Picture name",
    "识别结果": "Recognition result",
    "请上传 PNG 格式的 HEp2 细胞图片": "It is recommended to upload HEp2 cell images in PNG format.",
}
def _L(zh_txt: str):
    """Return the UI text for *zh_txt* in the active interface language.

    When the app runs in Chinese mode the input is returned unchanged;
    otherwise the English text is looked up in ``ZH2EN``.

    Raises:
        KeyError: In English mode, if *zh_txt* has no entry in ``ZH2EN``.
    """
    if not EN_US:
        return zh_txt

    return ZH2EN[zh_txt]
# Download (or reuse a cached copy of) the model repository at import time.
# English mode pulls from the Hugging Face Hub, Chinese mode from ModelScope;
# both use the same repository id and the same local cache directory.
if EN_US:
    MODEL_DIR = huggingface_hub.snapshot_download(
        "Genius-Society/HEp2",
        cache_dir="./__pycache__",
    )
else:
    MODEL_DIR = modelscope.snapshot_download(
        "Genius-Society/HEp2",
        cache_dir="./__pycache__",
    )
# English class names mapped to the Chinese labels shown in Chinese mode.
# The insertion order matters: CLASSES is derived from it and is indexed by
# the model's predicted class id, so do not reorder these entries.
TRANSLATE = {
    "Centromere": "着丝粒",
    "Golgi": "高尔基体",
    "Homogeneous": "同质",
    # "NuMem" abbreviates "nuclear membrane"; the previous label "记忆体"
    # means computer memory and was a mistranslation.
    "NuMem": "核膜",
    "Nucleolar": "核仁",
    "Speckled": "斑核",
}

# Class names in prediction-index order.
CLASSES = list(TRANSLATE)
def embeding(img_path: str):
    """Load an image file and run it through the preprocessing pipeline.

    The image is opened, converted to RGB, resized and center-cropped to
    224, passed through a small random affine transform, converted to a
    tensor and normalized with mean 0.5 / std 0.5 on each of the three
    channels.

    Args:
        img_path: Path of the image file to load.

    Returns:
        The output of the transform pipeline for the image.
    """
    compose = transforms.Compose(
        [
            transforms.Resize(224),
            transforms.CenterCrop(224),
            # NOTE(review): RandomAffine is a random augmentation, so the same
            # image can yield slightly different tensors between calls. It is
            # kept to preserve existing behavior; confirm whether it is
            # intended for inference.
            transforms.RandomAffine(5),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    # Use a context manager so the underlying file handle is closed as soon as
    # the pixel data has been converted, instead of waiting for garbage
    # collection.
    with Image.open(img_path) as img:
        return compose(img.convert("RGB"))
# Lazily-loaded classifier shared by every call to infer().
_MODEL = None


def _load_model():
    """Load the checkpoint from MODEL_DIR on first use and reuse it afterwards."""
    global _MODEL
    if _MODEL is None:
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary Python
        # objects. That is only safe while save.pt comes from a trusted
        # repository; never point MODEL_DIR at untrusted files.
        model = torch.load(
            f"{MODEL_DIR}/save.pt",
            map_location=torch.device("cpu"),
            weights_only=False,
        )
        # Switch layers such as dropout / batch-norm to inference behavior.
        if isinstance(model, torch.nn.Module):
            model.eval()
        _MODEL = model

    return _MODEL


def infer(target: str):
    """Classify the cell image at *target* and report the outcome.

    Args:
        target: Path of the uploaded image; may be empty or None when
            nothing was uploaded.

    Returns:
        A ``(status, filename, result)`` tuple. ``status`` is ``"Success"``
        or the message of the exception that occurred. ``filename`` is the
        base name of *target* and ``result`` is the predicted class name
        (English or Chinese depending on ``EN_US``); both are ``None`` when
        the step producing them was not reached.
    """
    status = "Success"
    filename = result = None
    try:
        # Validate the input first so an empty submission never pays for
        # model loading.
        if not target:
            raise ValueError("请上传细胞图片")

        model = _load_model()
        torch.cuda.empty_cache()

        # The model is called on a batch, so add a leading batch dimension.
        batch = embeding(target).unsqueeze(0)
        with torch.no_grad():  # no gradients are needed for prediction
            output = model(batch)

        # Index of the highest score along dim 1, as a plain int.
        predict = int(torch.max(output, 1)[1].item())
        filename = os.path.basename(target)
        class_name = CLASSES[predict]
        result = class_name if EN_US else TRANSLATE[class_name]

    except Exception as e:
        # UI boundary: surface any failure in the status box instead of
        # crashing the request.
        status = f"{e}"

    return status, filename, result
if __name__ == "__main__":
    # One example image per class, read from the examples folder inside the
    # downloaded model directory.
    example_imgs = [f"{MODEL_DIR}/examples/{cls}.png" for cls in CLASSES]

    image_input = gr.Image(type="filepath", label=_L("上传细胞图像"))

    # Three copyable text boxes matching infer()'s (status, filename, result).
    result_boxes = [
        gr.Textbox(label=_L(zh_label), show_copy_button=True)
        for zh_label in ("状态栏", "图片名", "识别结果")
    ]

    demo = gr.Interface(
        fn=infer,
        inputs=image_input,
        outputs=result_boxes,
        title=_L("请上传 PNG 格式的 HEp2 细胞图片"),
        examples=example_imgs,
        flagging_mode="never",
        cache_examples=False,
    )
    demo.launch(ssr_mode=False)