fadindashfr commited on
Commit
17a1b09
·
1 Parent(s): 4c5329d

add device = 'cpu'

Browse files
Files changed (1) hide show
  1. app.py +15 -1
app.py CHANGED
@@ -1,9 +1,19 @@
1
  import torch
2
  from monai.bundle import ConfigParser
3
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
4
 
5
  parser = ConfigParser() # load configuration files that specify various parameters for running the MONAI workflow.
6
- parser.read_config(f="configs/inference.json") # read the config from specified JSON file
7
  parser.read_meta(f="configs/metadata.json") # read the metadata from specified JSON file
8
 
9
  inference = parser.get_parsed_content("inferer")
@@ -11,6 +21,9 @@ network = parser.get_parsed_content("network_def")
11
  preprocess = parser.get_parsed_content("preprocessing")
12
  state_dict = torch.load("models/model.pt")
13
  network.load_state_dict(state_dict, strict=True) # Loads a model’s parameter dictionary
 
 
 
14
  class_names = {
15
  0: "Other",
16
  1: "Inflammatory",
@@ -21,6 +34,7 @@ class_names = {
21
  def classify_image(image_file, label_file):
22
  data = {"image":image_file, "label":label_file}
23
  batch = preprocess(data)
 
24
  network.eval()
25
  with torch.no_grad():
26
  pred = inference(batch['image'].unsqueeze(dim=0), network) # expect 4 channels input (3 RGB, 1 Label mask)
 
1
  import torch
2
  from monai.bundle import ConfigParser
3
  import gradio as gr
4
+ import json
5
+
6
+ with open("configs/inference.json") as f:
7
+ inference_config = json.load(f)
8
+
9
+ device = torch.device('cpu')
10
+ if torch.cuda.is_available():
11
+ device = torch.device('cuda:0')
12
+
13
+ inference_config["device"] = device
14
 
15
  parser = ConfigParser() # load configuration files that specify various parameters for running the MONAI workflow.
16
+ parser.read_config(f=inference_config) # read the config from specified JSON file
17
  parser.read_meta(f="configs/metadata.json") # read the metadata from specified JSON file
18
 
19
  inference = parser.get_parsed_content("inferer")
 
21
  preprocess = parser.get_parsed_content("preprocessing")
22
  state_dict = torch.load("models/model.pt")
23
  network.load_state_dict(state_dict, strict=True) # Loads a model’s parameter dictionary
24
+ network = network.to(device)
25
+ network.eval()
26
+
27
  class_names = {
28
  0: "Other",
29
  1: "Inflammatory",
 
34
  def classify_image(image_file, label_file):
35
  data = {"image":image_file, "label":label_file}
36
  batch = preprocess(data)
37
+ batch['image'] = batch['image'].to(device)
38
  network.eval()
39
  with torch.no_grad():
40
  pred = inference(batch['image'].unsqueeze(dim=0), network) # expect 4 channels input (3 RGB, 1 Label mask)