import gradio as gr
import spaces
import torch
from transformers import AutoModel

# Class labels for the model's "density" output, matched to its scores by position.
DENSITY_LABELS = ["A", "B", "C", "D"]


def crop_mammo(img, model, device):
    """Crop a mammogram to the box predicted by the cropping model.

    Args:
        img: Image array from the ``gr.Image`` inputs below; only its first two
            dimensions (height, width) are read for the shape passed to the model.
        model: Cropping model exposing ``preprocess(img)`` (returning a NumPy
            array) and, when called with the preprocessed tensor and the
            original ``[[height, width]]`` shape, returning box coordinates
            ``(x, y, w, h)`` as the first row of its output.
        device: Torch device string (``"cuda"`` or ``"cpu"``) the model lives on.

    Returns:
        ``img[y:y + h, x:x + w]`` with the box edges clipped at 0. If the
        clipped box is empty, the uncropped ``img`` is returned instead.
    """
    img_shape = torch.tensor([img.shape[:2]]).to(device)
    # expand(1, 1, -1, -1) adds batch/channel dims; presumably preprocess()
    # returns a 2-D array — TODO confirm against the mammo-crop model code.
    batch = torch.from_numpy(model.preprocess(img)).expand(1, 1, -1, -1).float().to(device)
    with torch.inference_mode():
        boxes = model(batch, img_shape)
    # Slice bounds must be Python ints; NumPy float coordinates would raise.
    x, y, w, h = (int(v) for v in boxes[0].cpu().numpy())
    # Clip every edge at 0 so a slightly negative prediction cannot wrap around
    # to the opposite side of the image through negative indexing. For valid
    # (non-negative) boxes this is exactly img[y:y + h, x:x + w].
    top, bottom = max(y, 0), max(y + h, 0)
    left, right = max(x, 0), max(x + w, 0)
    crop = img[top:bottom, left:right]
    # A degenerate box would feed an empty array to the classifier; fall back
    # to the full image rather than failing downstream.
    if 0 in crop.shape[:2]:
        return img
    return crop


@spaces.GPU
def predict(cc, mlo):
    """Score a screening study from its CC and/or MLO view.

    Reads the module-level ``crop_model``, ``model`` and ``device`` globals,
    which are only bound in the ``__main__`` block at the bottom of this file.

    Args:
        cc: CC-view image from the ``cc_view`` input, or ``None`` if not uploaded.
        mlo: MLO-view image from the ``mlo_view`` input, or ``None`` if not uploaded.

    Returns:
        A ``(cancer_pred, density_pred)`` pair of dicts for the two ``gr.Label``
        outputs: ``{"Cancer Score": float}`` and ``{label: float}`` over
        ``DENSITY_LABELS``.

    Raises:
        gr.Error: If neither view was provided.
    """
    views = {"cc": cc, "mlo": mlo}
    input_dict = {
        name: crop_mammo(image, crop_model, device)
        for name, image in views.items()
        if image is not None
    }
    # Without this guard the model would be called with an empty dict and fail
    # with an opaque error; surface a clear message in the UI instead.
    if not input_dict:
        raise gr.Error("Please upload at least one view (CC or MLO).")
    with torch.inference_mode():
        output = model(input_dict, device=device)
    cancer_pred = {"Cancer Score": output["cancer"][0].item()}
    # tolist() yields plain Python floats rather than NumPy scalars for the
    # gr.Label payload.
    density_scores = output["density"][0].cpu().numpy().tolist()
    density_pred = dict(zip(DENSITY_LABELS, density_scores))
    return cancer_pred, density_pred


cc_view = gr.Image(label="CC View", image_mode="L")
mlo_view = gr.Image(label="MLO View", image_mode="L")
cancer_label = gr.Label(label="Cancer", show_label=True, show_heading=False)
density_label = gr.Label(label="Density", show_label=True, show_heading=True)

with gr.Blocks() as demo:
    # NOTE(review): "Read more about the model here:" has no link after it —
    # the URL appears to be missing; restore it.
    gr.Markdown(
        """
        # Deep Learning Model for Screening Mammography

        This model predicts the likelihood of breast cancer from a standard two-view 2D screening mammography study, as well as breast density.

        Read more about the model here:

        This model was trained on pathology results (cancer versus no cancer) and does not produce a BI-RADS score.

        Supplying both CC and MLO views will result in the best prediction. However, the model will still work if only 1 view is provided.

        The example mammogram is taken from: Mohammad Niknejad, Radiopaedia.org, from the case rID: 147729.

        This model is for demonstration purposes only and has NOT been approved by any regulatory agency for clinical use. The user assumes any and all responsibility regarding their own use of this model and its outputs.

        Do NOT upload any images containing protected health information, as this demonstration is not compliant with patient privacy laws.

        Created by: Ian Pan, Last updated: January 20, 2025
        """
    )
    # cache_examples runs predict() on the example pair, which needs the models
    # loaded in the __main__ block below; this presumes the installed Gradio
    # caches at launch rather than at construction — TODO confirm.
    gr.Interface(
        fn=predict,
        inputs=[cc_view, mlo_view],
        outputs=[cancer_label, density_label],
        examples=[["examples/cc.jpg", "examples/mlo.jpg"]],
        cache_examples=True,
    )

if __name__ == "__main__":
    # These module-level globals are what predict() reads; importing this file
    # without running it as a script leaves them unbound.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device `{device}` ...")
    # NOTE(review): trust_remote_code executes Python shipped in these Hub
    # repos; consider pinning `revision=` to an audited commit.
    crop_model = AutoModel.from_pretrained("ianpan/mammo-crop", trust_remote_code=True)
    model = AutoModel.from_pretrained("ianpan/mammoscreen", trust_remote_code=True)
    crop_model, model = crop_model.eval().to(device), model.eval().to(device)
    demo.launch(share=True)