Spaces:
Build error
Build error
pablorodriper
commited on
Commit
·
4bb637b
1
Parent(s):
8fde263
Update predict.py
Browse files- utils/predict.py +31 -1
utils/predict.py
CHANGED
@@ -23,12 +23,15 @@ def get_model():
|
|
23 |
|
24 |
return model
|
25 |
|
|
|
26 |
model = get_model()
|
27 |
labels = ['liver', 'kidney-right', 'kidney-left', 'femur-right', 'femur-left', 'bladder', 'heart', 'lung-right', 'lung-left', 'spleen', 'pancreas']
|
28 |
|
|
|
29 |
def predict_label(path):
|
30 |
frames = load_video(path)
|
31 |
-
|
|
|
32 |
label = np.argmax(prediction, axis=0)
|
33 |
label = labels[label]
|
34 |
|
@@ -52,3 +55,30 @@ def load_video(path):
|
|
52 |
finally:
|
53 |
cap.release()
|
54 |
return np.array(frames)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
return model
|
25 |
|
26 |
+
|
27 |
model = get_model()
|
28 |
labels = ['liver', 'kidney-right', 'kidney-left', 'femur-right', 'femur-left', 'bladder', 'heart', 'lung-right', 'lung-left', 'spleen', 'pancreas']
|
29 |
|
30 |
+
|
31 |
def predict_label(path):
|
32 |
frames = load_video(path)
|
33 |
+
dataloader = prepare_dataloader(frames)
|
34 |
+
prediction = model.predict(dataloader)[0]
|
35 |
label = np.argmax(prediction, axis=0)
|
36 |
label = labels[label]
|
37 |
|
|
|
55 |
finally:
|
56 |
cap.release()
|
57 |
return np.array(frames)
|
58 |
+
|
59 |
+
|
60 |
+
def prepare_dataloader(video):
|
61 |
+
video = tf.expand_dims(video, axis=0)
|
62 |
+
dataset = tf.data.Dataset.from_tensor_slices((video, np.array([0])))
|
63 |
+
|
64 |
+
dataloader = (
|
65 |
+
dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
|
66 |
+
.batch(1)
|
67 |
+
.prefetch(tf.data.AUTOTUNE)
|
68 |
+
)
|
69 |
+
return dataloader
|
70 |
+
|
71 |
+
|
72 |
+
@tf.function
|
73 |
+
def preprocess(frames: tf.Tensor, label: tf.Tensor):
|
74 |
+
"""Preprocess the frames tensors and parse the labels."""
|
75 |
+
# Preprocess images
|
76 |
+
frames = tf.image.convert_image_dtype(
|
77 |
+
frames[
|
78 |
+
..., tf.newaxis
|
79 |
+
], # The new axis is to help for further processing with Conv3D layers
|
80 |
+
tf.float32,
|
81 |
+
)
|
82 |
+
# Parse label
|
83 |
+
label = tf.cast(label, tf.float32)
|
84 |
+
return frames, label
|