Medvira committed
Commit b5f7aea · verified · 1 Parent(s): ea75e4c

Update app.py

Files changed (1): app.py (+26 -6)
app.py CHANGED
@@ -1,31 +1,51 @@

Previous version:

  import gradio as gr
  import torch
  import onnxruntime as ort

  # Initialize the ONNX session
  session = ort.InferenceSession("/content/bone_age_model.onnx")

  # Define the inference function
  def inference(sample_name):
-     sample = torch.load(f"/content/samples/{sample_name}")
      age = sample['boneage'].item()
      outputs = session.run(None, {"input": sample['path'].numpy()})
      predicted_age = (outputs[0]*41.172)+127.329
      return {
          'Bone age': age,
-         'Predicted Bone age': predicted_age[0][0]
      }

  # List of sample file names
- sample_files = [f"sample_{i}.pth" for i in range(1, 11)]
-
  # Create Gradio interface
- dropdown = gr.inputs.Dropdown(choices=sample_files, label="Select a sample")

  iface = gr.Interface(
      fn=inference,
      inputs=dropdown,
-     outputs=[gr.outputs.Textbox(label="Bone Age"), gr.outputs.Textbox(label="Predicted Bone Age")],
      title="Bone Age Prediction",
      description="Select a sample from the dropdown to see the bone age and predicted bone age."
  )
 
Updated version:

  import gradio as gr
  import torch
  import onnxruntime as ort
+ import os
+ import gdown
+
+ # Define the model URL and output path
+ model_url = "https://drive.google.com/file/d/18HYScsRJuRmfzL0E0BW35uaA542Vd5M5/view?usp=sharing"
+ model_path = os.path.join(os.getcwd(),"bone_age_model.onnx")
+
+ # Check if the model file exists and download if it does not
+ if not os.path.exists(model_path):
+     gdown.download(model_url, model_path, quiet=False)

  # Initialize the ONNX session
  session = ort.InferenceSession("/content/bone_age_model.onnx")

  # Define the inference function
  def inference(sample_name):
+     sample_path = os.path.join(os.getcwd(),f'{sample_name}.pth')
+     sample = torch.load(sample_path)
      age = sample['boneage'].item()
      outputs = session.run(None, {"input": sample['path'].numpy()})
      predicted_age = (outputs[0]*41.172)+127.329
+     # Get the image path and load the image
+     image_path = sample['path'][0]
+     image = Image.open(image_path)
+
      return {
          'Bone age': age,
+         'Predicted Bone age': predicted_age[0][0],
+         'Image': image
      }

  # List of sample file names
+ sample_files = sorted(os.listdir(os.path.join(os.getcwd(),'samples','*.pth')))
+ sample_names = [os.path.basename(x).split('.pth')[0] for x in sample_files]
  # Create Gradio interface
+ dropdown = gr.inputs.Dropdown(choices=sample_names, label="Select a sample")

  iface = gr.Interface(
      fn=inference,
      inputs=dropdown,
+     outputs=[
+         gr.outputs.Textbox(label="Bone Age"),
+         gr.outputs.Textbox(label="Predicted Bone Age"),
+         gr.outputs.Image(label="Image")
+     ],
      title="Bone Age Prediction",
      description="Select a sample from the dropdown to see the bone age and predicted bone age."
  )
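
As committed, the updated script has a few issues that will surface at runtime, at least within this hunk: `Image.open` is called without `from PIL import Image`; `os.listdir` is handed a `samples/*.pth` glob pattern it will not expand; the ONNX session is still built from the hard-coded `/content/bone_age_model.onnx` rather than the freshly downloaded `model_path`; the Google Drive URL is the `file/d/<id>/view` share form, which `gdown.download` typically needs `fuzzy=True` (or the `uc?id=<id>` form) to resolve; and `inference` returns a string-keyed dict while `gr.Interface` is given a positional list of output components, which Gradio matches by position, not by those keys. The following is a minimal cleanup sketch, not the committed code. It assumes the sample `.pth` files keep the `boneage`/`path` keys used in the commit (the commit uses `sample['path']` both as the model input tensor and as the source of the image path, and only the sample files themselves can confirm what that key holds), and it uses the current top-level Gradio components (`gr.Dropdown`, `gr.Textbox`, `gr.Image`) instead of the legacy `gr.inputs`/`gr.outputs` API.

import os
import glob

import gdown
import gradio as gr
import onnxruntime as ort
import torch
from PIL import Image  # required for Image.open below

# Download the ONNX model next to the script if it is not already present.
# fuzzy=True lets gdown resolve a ".../file/d/<id>/view" share link.
MODEL_URL = "https://drive.google.com/file/d/18HYScsRJuRmfzL0E0BW35uaA542Vd5M5/view?usp=sharing"
MODEL_PATH = os.path.join(os.getcwd(), "bone_age_model.onnx")
if not os.path.exists(MODEL_PATH):
    gdown.download(MODEL_URL, MODEL_PATH, quiet=False, fuzzy=True)

# Create the session from the file that was actually downloaded.
session = ort.InferenceSession(MODEL_PATH)

SAMPLES_DIR = os.path.join(os.getcwd(), "samples")


def inference(sample_name):
    # Assumed sample layout, taken from the committed code: a dict with
    # 'boneage' holding the ground-truth age and 'path' holding the model
    # input tensor, whose first element is also used as an image path.
    sample = torch.load(os.path.join(SAMPLES_DIR, f"{sample_name}.pth"))
    age = sample["boneage"].item()
    outputs = session.run(None, {"input": sample["path"].numpy()})
    # Undo the label normalisation; the constants presumably are the
    # training set's bone-age std/mean in months.
    predicted_age = outputs[0] * 41.172 + 127.329
    image = Image.open(sample["path"][0])
    # Return a tuple in the same order as the output components below.
    return age, float(predicted_age[0][0]), image


# glob (not os.listdir) expands the *.pth pattern.
sample_files = sorted(glob.glob(os.path.join(SAMPLES_DIR, "*.pth")))
sample_names = [os.path.splitext(os.path.basename(p))[0] for p in sample_files]

iface = gr.Interface(
    fn=inference,
    inputs=gr.Dropdown(choices=sample_names, label="Select a sample"),
    outputs=[
        gr.Textbox(label="Bone Age"),
        gr.Textbox(label="Predicted Bone Age"),
        gr.Image(label="Image"),
    ],
    title="Bone Age Prediction",
    description="Select a sample from the dropdown to see the bone age and predicted bone age.",
)

if __name__ == "__main__":
    iface.launch()

The trailing `iface.launch()` guard is an addition for local runs; the hunk above ends before any launch call, so whether the committed file launches the interface further down is not visible here.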