import logging
import os
from tempfile import TemporaryFile

import cv2
import numpy as np
from PIL import Image

import tator
import inference

# Make sure a log handler exists. With no handler configured anywhere,
# logging falls back to its last-resort handler, which only emits WARNING and
# above, so the INFO progress messages of this script would be dropped.
# basicConfig() is a no-op when the root logger already has handlers, and
# calling it without a level leaves the verbosity of other libraries alone.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Job configuration, read from environment variables.
host = os.getenv('HOST')
token = os.getenv('TOKEN')
project_id = int(os.getenv('PROJECT_ID'))
# MEDIA_IDS is a comma-separated list of integer media IDs.
media_ids = [int(id_) for id_ in os.getenv('MEDIA_IDS').split(',')]
# Inference runs on every Nth frame; N defaults to 30 when unset.
frames_per_inference = int(os.getenv('FRAMES_PER_INFERENCE', 30))

# API client used for all Tator calls below.
api = tator.get_api(host, token)

def _detect_objects(frame, media_id, frame_number):
    """Run inference on one decoded video frame.

    Args:
        frame: Image array as returned by ``cv2.VideoCapture.read()``.
        media_id: ID of the media the frame belongs to.
        frame_number: Zero-based index of the frame within the video.

    Returns:
        A list with one localization spec dict per detection row reported
        by the model for this frame (empty when nothing was detected).
    """
    # OpenCV decodes frames in BGR channel order while PIL interprets a
    # 3-channel array as RGB, so convert before encoding the JPEG.
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # The context manager closes (and thereby discards) the temporary file
    # as soon as the detections have been extracted.
    with TemporaryFile(suffix='.jpg') as framefile:
        # The format is given explicitly because a TemporaryFile does not
        # necessarily expose a file name PIL could infer it from.
        Image.fromarray(rgb_frame).save(framefile, format='JPEG')
        # Rewind so the image can be read from the start of the file.
        framefile.seek(0)
        predictions = inference.run_inference(framefile)
        detections = predictions.pandas().xyxy[0]

    specs = []
    for _, row in detections.iterrows():
        left, top = row['xmin'], row['ymin']
        # A fresh dict per detection: reusing one dict would make every
        # list entry alias the last detection of the frame.
        specs.append({
            'media_id': media_id,
            # NOTE(review): left as None as in the original script;
            # presumably this must be a localization type ID -- confirm.
            'type': None,
            'frame': frame_number,
            # NOTE(review): coordinates are passed through in the units of
            # the model's xyxy table; confirm they match what the target
            # localization type expects (e.g. normalized vs. pixels).
            'x': left,
            'y': top,
            'width': row['xmax'] - left,
            'height': row['ymax'] - top,
            'class_category': row['name'],
            'confidence': row['confidence'],
        })
    return specs


# Process each media: download it, run inference on every
# `frames_per_inference`-th frame, then upload the resulting localizations.
for media_id in media_ids:
    media = api.get_media(media_id)
    logger.info(f"Downloading {media.name}...")
    out_path = f"/tmp/{media.name}"
    for progress in tator.util.download_media(api, media, out_path):
        logger.info(f"Download progress: {progress}%")

    logger.info(f"Doing inference on {media.name}...")
    localizations = []
    vid = cv2.VideoCapture(out_path)
    try:
        frame_number = 0
        while True:
            ret, frame = vid.read()
            if not ret:
                break
            # Only frames that are actually sent to the model get encoded.
            if frame_number % frames_per_inference == 0:
                localizations.extend(
                    _detect_objects(frame, media_id, frame_number))
            frame_number += 1
    finally:
        # Release the capture even if inference raised.
        vid.release()

    logger.info(f"Uploading object detections on {media.name}...")
    num_created = 0
    for response in tator.util.chunked_create(api.create_localization_list,
                                              project_id,
                                              localization_spec=localizations):
        num_created += len(response.id)

    logger.info(f"Successfully created {num_created} localizations on "
                f"{media.name}!")
    logger.info("-------------------------------------------------")

logger.info(f"Completed inference on {len(media_ids)} files.")