"""Gradio demo that classifies HEp-2 cell images with a pretrained model."""

import functools
import os

import gradio as gr
import torch
from modelscope import snapshot_download
from PIL import Image
from torchvision.transforms import transforms

# Download (or reuse from the local cache directory) the pretrained model repo.
MODEL_DIR = snapshot_download("Genius-Society/HEp2", cache_dir="./__pycache__")

# Human-readable labels, indexed by the class id the model predicts.
CLASSES = [
    "Centromere",
    "Golgi",
    "Homogeneous",
    "NuMem",
    "Nucleolar",
    "Speckled",
]

# Deterministic inference preprocessing, built once at import time.
# No random augmentation (e.g. RandomAffine) here: it would make repeated
# predictions for the same image disagree.
_PREPROCESS = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)


def embeding(img_path: str) -> torch.Tensor:
    """Load an image file and turn it into a normalized model input tensor.

    Args:
        img_path: Path to an image file readable by PIL.

    Returns:
        A float tensor of shape (3, 224, 224) with values scaled to [-1, 1].
    """
    # ``with`` releases the file handle once the pixels have been decoded;
    # ``convert`` returns an independent image, so it stays valid afterwards.
    with Image.open(img_path) as img:
        rgb = img.convert("RGB")
    return _PREPROCESS(rgb)


@functools.lru_cache(maxsize=1)
def _load_model() -> torch.nn.Module:
    """Load the pickled model from MODEL_DIR once and reuse it for every request."""
    # NOTE(review): torch.load unpickles arbitrary objects; this is acceptable
    # only because save.pt comes from the model repository downloaded above.
    # On torch >= 2.6 (weights_only=True by default) loading a fully pickled
    # module may additionally require weights_only=False — confirm per env.
    model = torch.load(f"{MODEL_DIR}/save.pt", map_location=torch.device("cpu"))
    # Put dropout / batch-norm layers (if any) into inference behaviour.
    model.eval()
    return model


def infer(target: str):
    """Classify an uploaded HEp-2 cell image.

    Args:
        target: File path of the uploaded image (Gradio ``type="filepath"``).
            It is falsy when nothing was uploaded.

    Returns:
        A ``(picture_name, result)`` tuple for the two output textboxes. When
        no image was given, ``picture_name`` is ``None`` and ``result`` asks
        the user to upload a picture.
    """
    # Check the input before doing any model work.
    if not target:
        return None, "Please upload a cell picture!"

    model = _load_model()
    batch = embeding(target).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():  # pure inference: skip autograd bookkeeping
        logits = model(batch)
    predict = int(torch.argmax(logits, dim=1)[0])
    return os.path.basename(target), CLASSES[predict]


if __name__ == "__main__":
    # One bundled example image per class, shipped with the model repository.
    example_imgs = [f"{MODEL_DIR}/examples/{cls}.png" for cls in CLASSES]

    with gr.Blocks() as demo:
        gr.Interface(
            fn=infer,
            inputs=gr.Image(type="filepath", label="Upload a cell picture"),
            outputs=[
                gr.Textbox(label="Picture name", show_copy_button=True),
                gr.Textbox(label="Recognition result", show_copy_button=True),
            ],
            title="It is recommended to upload HEp2 cell images in PNG format.",
            examples=example_imgs,
            flagging_mode="never",
            cache_examples=False,
        )

    demo.launch()