"""Gradio demo: predict bone age from stored sample tensors with an ONNX model."""

import os

import gdown
import gradio as gr
import numpy as np
import onnxruntime as ort
import torch
from PIL import Image

# Per-channel normalization constants; used below to undo the normalization of
# the stored images for display. (These are the common ImageNet statistics —
# presumably the preprocessing used them; confirm.)
means = np.array([0.485, 0.456, 0.406])
stds = np.array([0.229, 0.224, 0.225])

# The model emits a standardized value; the prediction is recovered as
# output * BONEAGE_STD + BONEAGE_MEAN. NOTE(review): presumably the training
# set's std/mean of bone age (units not shown in this file) — confirm.
BONEAGE_STD = 41.172
BONEAGE_MEAN = 127.329

# Define the model URL and output path
model_url = "https://drive.google.com/file/d/18HYScsRJuRmfzL0E0BW35uaA542Vd5M5/view?usp=sharing"
model_path = os.path.join(os.getcwd(), "bone_age_model.onnx")
samples_dir = os.path.join(os.getcwd(), "samples")

# Download the model only if it is not already cached next to the app.
if not os.path.exists(model_path):
    gdown.download(model_url, model_path, fuzzy=True, quiet=False)
    # Fail fast with a clear message instead of a confusing ONNX load error.
    if not os.path.exists(model_path):
        raise RuntimeError(f"Failed to download the model from {model_url}")

# Initialize the ONNX session
session = ort.InferenceSession(model_path)

# Offer only *.pth files from the samples directory (any other file would
# otherwise appear as a bogus choice). splitext strips exactly the final
# extension, so the name round-trips back to "<name>.pth" in inference().
sample_names = [
    os.path.splitext(file_name)[0]
    for file_name in sorted(os.listdir(samples_dir))
    if file_name.endswith(".pth")
]


def _to_display_image(image_chw):
    """Undo per-channel normalization of a (3, H, W) array and return a PIL image.

    Values are clipped to [0, 1] before scaling to 8 bits; without the clip,
    out-of-range pixels wrap around when cast to uint8.
    """
    # Broadcasting builds a new array, so the caller's tensor memory (which
    # Tensor.numpy() shares) is not modified in place.
    image = image_chw * stds[:, None, None] + means[:, None, None]
    image = np.clip(image, 0.0, 1.0)
    image = (image * 255).astype(np.uint8)
    # Channels-first (C, H, W) -> channels-last (H, W, C), as PIL expects.
    return Image.fromarray(np.moveaxis(image, 0, -1))


def inference(sample_name):
    """Run the model on one stored sample.

    Args:
        sample_name: File stem chosen in the dropdown; the sample is loaded
            from ``samples/<sample_name>.pth``.

    Returns:
        ``(age, predicted_age, image)``: the sample's stored ``boneage`` value,
        the model prediction rounded to an int, and the denormalized input
        image as a PIL image.

    Raises:
        gr.Error: If no sample is selected or the name is not a listed sample.
    """
    # Only accept names offered in the dropdown; this also stops a crafted
    # request from pointing torch.load (an unpickler) at an arbitrary path.
    if not sample_name:
        raise gr.Error("Please select a sample.")
    if sample_name not in sample_names:
        raise gr.Error(f"Unknown sample: {sample_name}")

    sample_path = os.path.join(samples_dir, f"{sample_name}.pth")
    # weights_only=False: the sample stores a MetaTensor, which the
    # weights_only unpickler (default since PyTorch 2.6) rejects. Acceptable
    # only because these are trusted local files validated above. (The
    # argument requires PyTorch >= 1.13.) map_location="cpu" lets samples
    # saved from a GPU load on CPU-only hosts; .numpy() needs CPU tensors.
    sample = torch.load(sample_path, map_location="cpu", weights_only=False)

    age = sample["boneage"].item()
    inputs = sample["path"]
    outputs = session.run(None, {"input": inputs.numpy()})
    # Take the single scalar prediction and undo the output standardization.
    predicted_age = float(np.ravel(outputs[0])[0]) * BONEAGE_STD + BONEAGE_MEAN

    image = _to_display_image(inputs[0].numpy())
    return age, round(predicted_age), image


# Create Gradio interface
dropdown = gr.Dropdown(choices=sample_names, label="Select a sample")
iface = gr.Interface(
    fn=inference,
    inputs=dropdown,
    outputs=[
        gr.Textbox(label="Bone Age"),
        gr.Textbox(label="Predicted Bone Age"),
        gr.Image(label="Image"),
    ],
    title="Bone Age Prediction",
    description="Select a sample from the dropdown to see the bone age and predicted bone age.",
)

# Launch the app when run as a script (e.g. `python app.py`).
if __name__ == "__main__":
    iface.launch()