import os
# Fetch the FILM frame-interpolation source code at startup.
# NOTE(review): the exit status of os.system is ignored, so if the directory
# already exists (e.g. on a restart) the clone presumably just fails and the
# existing checkout is used.
os.system("git clone https://github.com/google-research/frame-interpolation")
import sys
# Make the cloned repo's top-level packages (such as `eval`, imported further
# down) importable.
sys.path.append("frame-interpolation")

import cv2
import numpy as np
import tensorflow as tf
import mediapy
from PIL import Image

import gradio as gr

from huggingface_hub import snapshot_download

from image_tools.sizes import resize_and_crop
from moviepy.editor import *


# Download the FILM "style" checkpoint from the Hugging Face Hub; the returned
# value (presumably the local snapshot directory) is handed to Interpolator.
model = snapshot_download(repo_id="akhaliq/frame-interpolation-film-style")
# `eval` here is a package (presumably from the cloned frame-interpolation repo
# on sys.path), not the builtin function.
from eval import interpolator, util
# NOTE(review): this rebinds `interpolator` from the imported module to an
# Interpolator instance; do_interpolation relies on the instance. The meaning
# of the second argument (None) should be confirmed in eval/interpolator.py.
interpolator = interpolator.Interpolator(model, None)

# Point mediapy at the ffmpeg binary located by FILM's util helper.
ffmpeg_path = util.get_ffmpeg_path()
mediapy.set_ffmpeg(ffmpeg_path)


        
def do_interpolation(frame1, frame2, times_to_interpolate):
    """Run FILM between two image files and save the frames as an MP4.

    Args:
        frame1: Path of the first image.
        frame2: Path of the second image.
        times_to_interpolate: Forwarded unchanged to
            ``util.interpolate_recursively_from_files``.

    Returns:
        Path of the written video, ``f"{frame1}_to_{frame2}_out.mp4"``
        (encoded at 12 fps).
    """
    print(frame1, frame2)
    generated = util.interpolate_recursively_from_files(
        [frame1, frame2], times_to_interpolate, interpolator
    )
    out_path = f"{frame1}_to_{frame2}_out.mp4"
    mediapy.write_video(out_path, list(generated), fps=12)
    return out_path
    
def get_frames(video_in, step, name):
    """Resize a video to 512 px height and dump every frame to a JPEG file.

    The input is first re-encoded to ``video_resized.mp4`` (frame rate capped
    at 30 fps), which is then read back with OpenCV frame by frame.

    Args:
        video_in: Path of the input video.
        step: Infix used in the output frame filenames.
        name: Prefix used in the output frame filenames.

    Returns:
        ``(frames, fps)``: the JPEG paths written, named
        ``f"{name}_{step}{i}.jpg"`` in frame order, and the frame rate OpenCV
        reports for the resized video.
    """
    clip = VideoFileClip(video_in)
    try:
        # Same resize either way; only the output frame rate differs.
        if clip.fps > 30:
            print("video rate is over 30, resetting to 30")
            target_fps = 30
        else:
            print("video rate is OK")
            target_fps = clip.fps
        clip_resized = clip.resize(height=512)
        clip_resized.write_videofile("video_resized.mp4", fps=target_fps)
    finally:
        # Release the clip's readers: infer() calls this function once per
        # frame pair, so unclosed clips would pile up over a single request.
        clip.close()

    print("video resized to 512 height")

    frames = []
    cap = cv2.VideoCapture("video_resized.mp4")
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        print("video fps: " + str(fps))
        i = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_path = f"{name}_{step}{i}.jpg"
            cv2.imwrite(frame_path, frame)
            frames.append(frame_path)
            i += 1
    finally:
        # Release the capture even if reading/writing a frame raises.
        cap.release()
    print("broke the video into frames")

    return frames, fps


def create_video(frames, fps, type):
    """Encode an ordered list of image files into an MP4.

    Args:
        frames: Image file paths in playback order.
        fps: Frame rate used for both the clip and the encoded file.
        type: Prefix of the output filename (name shadows the builtin; kept
            for interface compatibility).

    Returns:
        The output path, ``type + "_result.mp4"``.
    """
    print("building video result")
    out_path = type + "_result.mp4"
    ImageSequenceClip(frames, fps=fps).write_videofile(out_path, fps=fps)
    return out_path

    
def infer(video_in, interpolation, fps_output):
    """Interpolate the first 4 seconds of a video with FILM.

    Args:
        video_in: Path of the uploaded video.
        interpolation: Forwarded to ``do_interpolation`` as
            ``times_to_interpolate`` (the "Interpolation Steps" slider).
        fps_output: Frame rate of the assembled result video.

    Returns:
        ``(final_vid, files)``: the result MP4 path twice, once for the video
        component and once for the file download component.

    Raises:
        ValueError: if no frames could be extracted from ``video_in``.
    """
    # 1. Break the video into frames and get its FPS.
    frames_list, fps = get_frames(video_in, "vid_input_frame", "origin")
    print(f"ORIGIN FPS: {fps}")
    n_frame = int(4 * fps)  # limited to 4 seconds

    if n_frame >= len(frames_list):
        print("video is shorter than the cut value")
        n_frame = len(frames_list)
    if n_frame <= 0:
        # Without this guard the final append indexed frames_list[-1] and
        # failed with an opaque IndexError.
        raise ValueError("no frames could be extracted from the input video")

    # 2. Prepare the result frame list.
    result_frames = []
    print("set stop frames to: " + str(n_frame))

    # Interpolate only pairs whose both ends are among the first n_frame
    # frames, so the sequence ends on frames_list[n_frame - 1] (appended
    # below). Previously, for videos longer than the cut, the pair
    # (n_frame - 1, n_frame) was also interpolated and frame n_frame - 1 was
    # then appended after it, making the clip jump backwards at the end.
    for idx in range(n_frame - 1):
        frame = frames_list[idx]
        next_frame = frames_list[idx + 1]
        # do_interpolation returns the path of an MP4, not a list of frames.
        interpolated_video = do_interpolation(frame, next_frame, interpolation)
        pair_frames, _ = get_frames(interpolated_video, "interpol", f"{idx}_")
        print(pair_frames)
        # Skip the last extracted frame; presumably it reproduces next_frame,
        # which the following pair (or the final append) supplies — confirm.
        for j, img in enumerate(pair_frames[:-1]):
            print(f"IMG:{img}")
            renamed = f"{frame}_to_{next_frame}_{j}.jpg"
            os.rename(img, renamed)
            result_frames.append(renamed)

        print("frames " + str(idx) + " & " + str(idx + 1) + "/" + str(n_frame) + ": done;")

    result_frames.append(frames_list[n_frame - 1])
    final_vid = create_video(result_frames, fps_output, "interpolated")

    files = final_vid

    return final_vid, files

# HTML header rendered at the top of the Gradio page via gr.HTML(title).
title="""
<div style="text-align: center; max-width: 500px; margin: 0 auto;">
        <div
        style="
            display: inline-flex;
            align-items: center;
            gap: 0.8rem;
            font-size: 1.75rem;
            margin-bottom: 10px;
        "
        >
        <h1 style="font-weight: 600; margin-bottom: 7px;">
            Video interpolation with FILM
        </h1>
        
        </div>
       <p> This space uses FILM to generate interpolation frames in a video you need to fluidify.<br />
       Generation is limited to 4 seconds, from the beginning of your video input.<br />
       <a style="display:inline-block" href="https://huggingface.co/spaces/fffiloni/video_frame_interpolation?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a> 
       </p>
    </div>
"""

# Gradio UI: the header, then input video and controls on the left, and the
# result video plus a downloadable copy on the right.
with gr.Blocks() as demo:
    with gr.Column():
        gr.HTML(title)
        with gr.Row():
            # Left column: inputs.
            with gr.Column():
                video_input = gr.Video(source="upload", type="filepath")
                with gr.Row():
                    # Passed to infer() as `interpolation`, then on to
                    # do_interpolation() as times_to_interpolate.
                    interpolation = gr.Slider(minimum=1,maximum=4,step=1, value=1, label="Interpolation Steps")
                    # Frame rate of the assembled output video.
                    fps_output = gr.Radio([8, 12, 24], label="FPS output", value=8)
                submit_btn = gr.Button("Submit")

            # Right column: outputs (infer returns the same path for both).
            with gr.Column():
                video_output = gr.Video()
                file_output = gr.File()

    # NOTE(review): with cache_examples=True Gradio presumably runs infer on
    # this example ahead of time to cache its outputs, i.e. a full
    # interpolation run at startup — confirm for the pinned Gradio version.
    gr.Examples(
        examples=[["./examples/yoda-fps2.mp4", 1, 12]],
        fn=infer,
        inputs=[video_input,interpolation,fps_output],
        outputs=[video_output,file_output],
        cache_examples=True
    )

    submit_btn.click(fn=infer, inputs=[video_input,interpolation,fps_output], outputs=[video_output, file_output])

# Launches the app as a module-level side effect (there is no __main__ guard).
demo.launch()