pablorodriper commited on
Commit
4bb637b
·
1 Parent(s): 8fde263

Update predict.py

Browse files
Files changed (1) hide show
  1. utils/predict.py +31 -1
utils/predict.py CHANGED
@@ -23,12 +23,15 @@ def get_model():
23
 
24
  return model
25
 
 
26
  model = get_model()
27
  labels = ['liver', 'kidney-right', 'kidney-left', 'femur-right', 'femur-left', 'bladder', 'heart', 'lung-right', 'lung-left', 'spleen', 'pancreas']
28
 
 
29
  def predict_label(path):
30
  frames = load_video(path)
31
- prediction = model.predict(tf.expand_dims(frames, axis=0))[0]
 
32
  label = np.argmax(prediction, axis=0)
33
  label = labels[label]
34
 
@@ -52,3 +55,30 @@ def load_video(path):
52
  finally:
53
  cap.release()
54
  return np.array(frames)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  return model
25
 
26
+
27
  model = get_model()
28
  labels = ['liver', 'kidney-right', 'kidney-left', 'femur-right', 'femur-left', 'bladder', 'heart', 'lung-right', 'lung-left', 'spleen', 'pancreas']
29
 
30
+
31
  def predict_label(path):
32
  frames = load_video(path)
33
+ dataloader = prepare_dataloader(frames)
34
+ prediction = model.predict(dataloader)[0]
35
  label = np.argmax(prediction, axis=0)
36
  label = labels[label]
37
 
 
55
  finally:
56
  cap.release()
57
  return np.array(frames)
58
+
59
+
60
+ def prepare_dataloader(video):
61
+ video = tf.expand_dims(video, axis=0)
62
+ dataset = tf.data.Dataset.from_tensor_slices((video, np.array([0])))
63
+
64
+ dataloader = (
65
+ dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
66
+ .batch(1)
67
+ .prefetch(tf.data.AUTOTUNE)
68
+ )
69
+ return dataloader
70
+
71
+
72
+ @tf.function
73
+ def preprocess(frames: tf.Tensor, label: tf.Tensor):
74
+ """Preprocess the frames tensors and parse the labels."""
75
+ # Preprocess images
76
+ frames = tf.image.convert_image_dtype(
77
+ frames[
78
+ ..., tf.newaxis
79
+ ], # The new axis is to help for further processing with Conv3D layers
80
+ tf.float32,
81
+ )
82
+ # Parse label
83
+ label = tf.cast(label, tf.float32)
84
+ return frames, label