from fastapi import FastAPI, UploadFile, Form, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse
from moviepy.editor import VideoFileClip
from PIL import Image
import torch
import numpy as np
import time
from typing import List
import base64
from io import BytesIO
from fastapi.middleware.cors import CORSMiddleware
import json
import os
import sys

# Add your local CLIP folder to the Python path before importing clip, so the
# local copy takes precedence (raw string, otherwise "\n" is a newline escape)
sys.path.insert(0, os.path.abspath(r"C:\new_ai\clip"))

import clip

# Import from the local CLIP folder
# from clip import load as clip_load
# from clip.simple_tokenizer import tokenize as clip_tokenize

# Initialize FastAPI app
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load CLIP model from local folder
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)
print("CUDA version:", torch.version.cuda)
model, preprocess = clip.load("ViT-B/32", device=device)
print("Model params device:", next(model.parameters()).device)  # Should print 'cuda:0' on GPU


def image_to_base64(img: Image.Image) -> str:
    """Encode a PIL image as a base64 JPEG string."""
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def format_time(seconds: float) -> str:
    """Format a duration in seconds as '1h 2m 3s', '2m 3s', or '3s'."""
    if seconds >= 3600:
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = int(seconds % 60)
        return f"{hours}h {minutes}m {secs}s"
    elif seconds >= 60:
        minutes = int(seconds // 60)
        secs = int(seconds % 60)
        return f"{minutes}m {secs}s"
    else:
        return f"{int(seconds)}s"


async def extract_frames(video_path: str, interval: int):
    """Sample one frame every `interval` seconds, returned as (timestamp, PIL image) pairs."""
    video = VideoFileClip(video_path)
    frames = []
    for t in np.arange(0, video.duration, interval):
        frame = Image.fromarray(video.get_frame(t))
        frames.append((t, frame))
    video.close()  # release the file handle so the temp file can be deleted later
    return frames


async def process_frames_in_batches(frames):
    if not frames:
        print("No frames were extracted.")
        return []
    try:
        # Use `t` for timestamps so the `time` module is not shadowed
        times, preprocessed_frames = zip(
            *[(t, preprocess(frame).unsqueeze(0)) for t, frame in frames]
        )
    except Exception as e:
        print(f"Error preprocessing frames: {e}")
        return []

    images_tensor = torch.cat(preprocessed_frames).to(device)
    print("Image tensor device:", images_tensor.device)  # Should print 'cuda:0' on GPU

    with torch.no_grad():
        start_time = time.time()
        image_features = model.encode_image(images_tensor).float()
        end_time = time.time()
    print(f"Time taken to encode one batch of images: {end_time - start_time:.4f} seconds")

    frame_features = [(t, image_features[i], frames[i][1]) for i, t in enumerate(times)]
    print(f"Processed {len(frame_features)} frames.")
    return frame_features


async def process_text(text: str):
    text_input = clip.tokenize([text]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_input).float()
    return text_features


async def find_all_matching_frames(frame_features, text_features, initial_threshold: float, min_threshold: float = 0.1):
    matched_frames = []
    threshold = initial_threshold
    # Relax the threshold in 0.02 steps until at least one frame matches
    # (or the floor `min_threshold` is reached)
    while not matched_frames and threshold >= min_threshold:
        matched_frames = []
        for t, image_features, frame in frame_features:
            similarity = torch.cosine_similarity(text_features, image_features).item()
            if similarity >= threshold:
                matched_frames.append({
                    "image_base64": image_to_base64(frame),
                    "time": format_time(t),
                    "accuracy": f"{similarity * 100:.0f}%",
                })
        if not matched_frames:
            threshold -= 0.02
    # Sort by the numeric accuracy, not the string, so "100%" outranks "85%"
    matched_frames.sort(key=lambda x: float(x["accuracy"].rstrip("%")), reverse=True)
    return matched_frames, threshold
@app.post("/process_video/")
async def process_video(
    video: UploadFile,
    prompt: str = Form(...),
    interval: int = Form(...),
    threshold: float = Form(...),
):
    try:
        print(f"Received file: {video.filename}")
        print(f"Prompt: {prompt}, Interval: {interval}, Threshold: {threshold}")

        if not video.filename:
            raise HTTPException(status_code=400, detail="No video file uploaded")

        start_time = time.time()

        # Persist the upload to disk so moviepy can open it
        temp_video_path = "temp_video.mp4"
        with open(temp_video_path, "wb") as buffer:
            buffer.write(await video.read())

        print("Extracting frames")
        frames = await extract_frames(temp_video_path, interval)
        print("Processing frames")
        frame_features = await process_frames_in_batches(frames)
        print("Encoding prompt")
        text_features = await process_text(prompt)
        print("Finding similarity")
        matched_frames, final_threshold = await find_all_matching_frames(
            frame_features, text_features, threshold
        )
        print("Completed!!!")

        total_time = time.time() - start_time

        response_data = {
            "number_of_frames": len(frames),
            "matches": matched_frames,
            "final_threshold_used": final_threshold,
            "total_time_taken": format_time(total_time),
        }

        # with open("output.json", "w") as outfile:
        #     outfile.write(json.dumps(response_data, indent=2))

        # Delete the temporary file after processing
        if os.path.exists(temp_video_path):
            os.remove(temp_video_path)
            print(f"{temp_video_path} has been deleted.")
        else:
            print(f"{temp_video_path} does not exist.")

        return response_data
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 400 above) pass through unwrapped
        raise
    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/")
async def read_root():
    return {"message": "Hello World"}
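
# A minimal entry-point sketch so the script can be run directly. This assumes
# uvicorn is installed (`pip install uvicorn`) and that this file is saved as,
# e.g., main.py; neither is stated in the original script.
#
# Example request once the server is up (file name and form values are
# illustrative):
#   curl -X POST http://localhost:8000/process_video/ \
#        -F "video=@sample.mp4" -F "prompt=a dog running" \
#        -F "interval=2" -F "threshold=0.25"
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)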