fadindashfr committed
Commit · 17a1b09 · 1 Parent(s): 4c5329d

add device = 'cpu'
app.py CHANGED
@@ -1,9 +1,19 @@
 import torch
 from monai.bundle import ConfigParser
 import gradio as gr
+import json
+
+with open("configs/inference.json") as f:
+    inference_config = json.load(f)
+
+device = torch.device('cpu')
+if torch.cuda.is_available():
+    device = torch.device('cuda:0')
+
+inference_config["device"] = device
 
 parser = ConfigParser() # load configuration files that specify various parameters for running the MONAI workflow.
-parser.read_config(f="configs/inference.json") # read the config from specified JSON file
+parser.read_config(f=inference_config) # read the config from specified JSON file
 parser.read_meta(f="configs/metadata.json") # read the metadata from specified JSON file
 
 inference = parser.get_parsed_content("inferer")
@@ -11,6 +21,9 @@ network = parser.get_parsed_content("network_def")
 preprocess = parser.get_parsed_content("preprocessing")
 state_dict = torch.load("models/model.pt")
 network.load_state_dict(state_dict, strict=True) # Loads a model’s parameter dictionary
+network = network.to(device)
+network.eval()
+
 class_names = {
     0: "Other",
     1: "Inflammatory",
@@ -21,6 +34,7 @@ class_names = {
 def classify_image(image_file, label_file):
     data = {"image":image_file, "label":label_file}
     batch = preprocess(data)
+    batch['image'] = batch['image'].to(device)
     network.eval()
     with torch.no_grad():
         pred = inference(batch['image'].unsqueeze(dim=0), network) # expect 4 channels input (3 RGB, 1 Label mask)
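Note: the sketch below is not part of this commit. It is one plausible way the device-aware classify_image could turn pred into a result and be served with Gradio, using only objects already defined earlier in app.py (preprocess, network, inference, device, class_names). It assumes pred is a tensor of raw per-class scores with shape (1, num_classes), that class_names covers every index the model can emit, and that both inputs arrive as file paths; the component choices and the name demo are illustrative, not taken from the repository.

# Sketch only, not from app.py: finish classify_image and expose it through Gradio.
def classify_image(image_file, label_file):
    data = {"image": image_file, "label": label_file}
    batch = preprocess(data)
    batch['image'] = batch['image'].to(device)  # keep the input on the same device as the network
    network.eval()
    with torch.no_grad():
        pred = inference(batch['image'].unsqueeze(dim=0), network)
    probs = torch.softmax(pred, dim=1).squeeze(0)  # assumption: pred holds raw class logits
    return {class_names[i]: float(p) for i, p in enumerate(probs)}  # dict format accepted by gr.Label

demo = gr.Interface(
    fn=classify_image,
    inputs=[gr.Image(type="filepath"), gr.Image(type="filepath")],  # assumed: image and label mask as file paths
    outputs=gr.Label(),
)
demo.launch()

Whatever the exact interface looks like, the point of the commit stands on its own: the model and the input batch must live on the same torch.device, so defaulting to 'cpu' and switching to 'cuda:0' only when torch.cuda.is_available() keeps the Space runnable on CPU-only hardware and avoids device-mismatch errors when a GPU is present.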