# pablorodriper · "Update predict.py" · 4bb637b  (web-page header captured with this file; not code)
import cv2
import numpy as np
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
from tensorflow.keras.optimizers import Adam
from .constants import LEARNING_RATE
def get_model():
    """
    Fetch the pretrained video vision transformer from the Hugging Face Hub
    and return it compiled.

    The model is compiled with an Adam optimizer (learning rate taken from
    ``LEARNING_RATE``) and the sparse categorical cross-entropy loss.
    """
    vivit = from_pretrained_keras("pablorodriper/video-vision-transformer")
    optimizer = Adam(learning_rate=LEARNING_RATE)
    # NOTE: no metrics are attached at compile time; accuracy and top-5
    # accuracy metrics were sketched here previously but are disabled.
    vivit.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy")
    return vivit
# The model is downloaded and compiled once, when this module is imported.
model = get_model()

# Class names; predict_label indexes this list with the argmax of the model
# output, so the order of the entries matters.
labels = [
    'liver',
    'kidney-right',
    'kidney-left',
    'femur-right',
    'femur-left',
    'bladder',
    'heart',
    'lung-right',
    'lung-left',
    'spleen',
    'pancreas',
]
def predict_label(path):
    """
    Predict the class name for the video stored at ``path``.

    The video is loaded with ``load_video``, wrapped into a dataset with
    ``prepare_dataloader`` and passed to the module-level ``model``. The entry
    of ``labels`` at the position of the highest score in the first prediction
    row is returned.
    """
    video_frames = load_video(path)
    batches = prepare_dataloader(video_frames)
    # Only the first row of the prediction output is used.
    scores = model.predict(batches)[0]
    return labels[np.argmax(scores, axis=0)]
def load_video(path):
    """
    Read every frame of the video at ``path`` and return them as a NumPy array.

    Each frame is converted from BGR to grayscale because that is the format
    expected by the model. If no frame can be read, the returned array is
    empty.
    """
    capture = cv2.VideoCapture(path)
    gray_frames = []
    try:
        ok, bgr_frame = capture.read()
        # read() reports failure once no further frame is available.
        while ok:
            gray_frames.append(cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2GRAY))
            ok, bgr_frame = capture.read()
    finally:
        # Release the capture handle even if reading or conversion fails.
        capture.release()
    return np.array(gray_frames)
def prepare_dataloader(video):
    """
    Wrap a single video into a batched, prefetched ``tf.data`` pipeline.

    A leading axis is added so that ``from_tensor_slices`` yields the whole
    video as one element; it is paired with a placeholder label of 0. Each
    element goes through ``preprocess`` before being batched with size 1.
    """
    single_video = tf.expand_dims(video, axis=0)
    placeholder_label = np.array([0])
    pipeline = tf.data.Dataset.from_tensor_slices((single_video, placeholder_label))
    pipeline = pipeline.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    pipeline = pipeline.batch(1)
    return pipeline.prefetch(tf.data.AUTOTUNE)
@tf.function
def preprocess(frames: tf.Tensor, label: tf.Tensor):
    """Convert the frames to float32 with a trailing axis and cast the label."""
    # The extra trailing axis helps further processing with Conv3D layers.
    frames_with_channel = frames[..., tf.newaxis]
    # convert_image_dtype rescales integer inputs into [0, 1] when converting
    # to a float dtype (assumes integer frames here — TODO confirm).
    float_frames = tf.image.convert_image_dtype(frames_with_channel, tf.float32)
    float_label = tf.cast(label, tf.float32)
    return float_frames, float_label