import cv2
import gradio as gr
import imutils
import numpy as np
import torch
from PIL import Image
from cnn3d_model import  load_model
import torchvision.transforms as transforms


def parse_video(video_file):
    """Decode every frame of a video file into a list of RGB arrays.

    Reference: https://pyimagesearch.com/2018/11/12/yolo-object-detection-with-opencv/

    Args:
        video_file: Source passed straight to ``cv2.VideoCapture``
            (typically a path to a video file).

    Returns:
        list: The decoded frames in stream order, each converted from
        OpenCV's BGR channel order to RGB. The list is empty when no frame
        could be read.
    """
    vs = cv2.VideoCapture(video_file)

    try:
        # Reporting the frame count is purely informational, so any failure
        # here is logged and ignored. ``Exception`` (rather than a bare
        # ``except``) keeps KeyboardInterrupt/SystemExit propagating.
        try:
            prop = (
                cv2.cv.CV_CAP_PROP_FRAME_COUNT
                if imutils.is_cv2()
                else cv2.CAP_PROP_FRAME_COUNT
            )
            total = int(vs.get(prop))
            print("[INFO] {} total frames in video".format(total))
        except Exception:
            print("[INFO] could not determine # of frames in video")
            print("[INFO] no approx. completion time can be provided")

        frames = []

        # loop over frames from the video file stream
        while True:
            # read the next frame from the file
            grabbed, frame = vs.read()
            if frame is not None:
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            # if the frame was not grabbed, then we have reached the end
            # of the stream
            if not grabbed:
                break
    finally:
        # Always free the underlying file/decoder handle, even if decoding
        # or colour conversion raised part-way through.
        vs.release()

    return frames

def pil_parser(video_file):
    """Predict a viscosity value for a video and format it for display.

    Frames 2, 4, ..., 60 of the video are converted to single-channel
    grayscale, resized to 256x342, scaled to [-1, 1] and stacked into one
    tensor of shape (1, 1, 30, 256, 342), which is passed to the model
    returned by ``load_model()``.

    Args:
        video_file: Video source forwarded to ``parse_video`` (presumably the
            file path Gradio hands over for the uploaded video).

    Returns:
        str: ``'viscosity : <value> cp_2'`` where ``<value>`` is the model
        output rounded to one decimal place.

    Raises:
        gr.Error: If no video was supplied, or the video has fewer frames
            than the highest sampled index requires.
    """
    if not video_file:
        raise gr.Error("Please upload a video file.")

    # Every second frame from index 2 up to and including index 60.
    frame_indices = range(2, 62, 2)

    # cv2 parsing
    decoded_frames = parse_video(video_file)

    # Fail with a readable message instead of an IndexError deep in the loop.
    frames_needed = frame_indices[-1] + 1
    if len(decoded_frames) < frames_needed:
        raise gr.Error(
            "Video is too short: at least {} frames are required, "
            "but only {} could be read.".format(frames_needed, len(decoded_frames))
        )

    # NOTE(review): the model is reloaded on every request; consider caching
    # it if load_model() turns out to be expensive.
    model = load_model()
    # Inference only: make sure dropout/batch-norm layers are in eval mode
    # (harmless if load_model() already did this).
    if isinstance(model, torch.nn.Module):
        model.eval()

    preprocess = transforms.Compose(
        [
            transforms.Resize([256, 342]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ]
    )

    # Each preprocessed frame is (1, 256, 342); stacking on dim=1 puts the
    # frame axis after the channel axis, and unsqueeze(0) adds the batch axis.
    clip = torch.stack(
        [
            preprocess(Image.fromarray(decoded_frames[i]).convert('L'))
            for i in frame_indices
        ],
        dim=1,
    ).unsqueeze(0)

    # No gradients are needed for a prediction; skipping autograd saves
    # memory and time.
    with torch.no_grad():
        out = model(clip)

    return 'viscosity : ' + str(round(out.item(), 1)) + ' cp_2'

# Example inputs shown in the UI; each row supplies the single video argument
# (paths are resolved relative to the working directory).
example_list = [["2350.mp4"], ["2300.mp4"]]

# Wire the predictor to a video-upload input and a plain-text output, then
# start serving as soon as this module is executed.
demo = gr.Interface(
    fn=pil_parser,
    inputs=gr.Video(label="Upload a video file"),
    outputs="text",
    examples=example_list,
    title="Viscosity Regression From Video Data",
    description="Gradio demo for Video Regression",
    allow_flagging='never',
)
demo.launch()