import gradio as gr
import spaces
import torch

from ultralytics import YOLO
from PIL import Image
import supervision as sv
import numpy as np

@spaces.GPU
def yolov8_inference(
    image,
    selected_labels_list
):
    """
    Detect NSFW regions with the EraX YOLOv8 model, then pixelate and label them.

    Args:
        image: Input image, as delivered by the ``gr.Image(type="filepath")``
            input (a path to the uploaded file).
        selected_labels_list: Labels whose detections should be censored. Each
            must be one of "anus", "make_love", "nipple", "penis", "vagina".
            ``None`` or an empty list censors nothing.
    Returns:
        The annotated image as a numpy array with the channel axis reversed
        (presumably Ultralytics' BGR ``orig_img`` -> RGB for Gradio).
    Raises:
        ValueError: If ``selected_labels_list`` contains an unknown label.
    """
    # Class ids of the erax_nsfw_v1 weights, keyed by the UI label names.
    label_to_class_id = {
        "anus": 0,
        "make_love": 1,
        "nipple": 2,
        "penis": 3,
        "vagina": 4,
    }
    labels = selected_labels_list or []
    unknown = [label for label in labels if label not in label_to_class_id]
    if unknown:
        # Validate before paying for model loading / GPU inference.
        raise ValueError(f"Unknown label(s): {unknown!r}")
    selected_classes = [label_to_class_id[label] for label in labels]

    # NOTE(review): weights are reloaded on every request; consider caching if
    # load time matters (confirm compatibility with the spaces GPU runtime).
    model = YOLO('erax_nsfw_v1.pt').to('cuda')
    # set model parameters
    model.overrides['conf'] = 0.3  # NMS confidence threshold
    model.overrides['iou'] = 0.2  # NMS IoU threshold
    model.overrides['agnostic_nms'] = False  # NMS class-agnostic
    model.overrides['max_det'] = 1000  # maximum number of detections per image

    # One input image -> exactly one Results object.
    result = model([image])[0]

    annotated_image = result.orig_img.copy()
    h, w = annotated_image.shape[:2]
    # Pixel block size and label text scale both grow with the longer side.
    anchor = max(h, w)

    detections = sv.Detections.from_ultralytics(result)
    detections = detections[np.isin(detections.class_id, selected_classes)]

    # PixelateAnnotator takes an integer pixel size; clamp to >= 1 so very
    # small images do not produce a zero (division by zero when resizing).
    pixelate_annotator = sv.PixelateAnnotator(pixel_size=max(1, round(anchor / 50)))
    label_annotator = sv.LabelAnnotator(
        text_color=sv.Color.BLACK,
        text_position=sv.Position.CENTER,
        text_scale=anchor / 1700,
    )

    # annotated_image is already a private copy, so annotating in place is safe.
    annotated_image = pixelate_annotator.annotate(
        scene=annotated_image,
        detections=detections,
    )
    annotated_image = label_annotator.annotate(
        scene=annotated_image,
        detections=detections,
    )

    return annotated_image[:, :, ::-1]


# Gradio UI wiring: an image plus the set of labels to censor go in, the
# censored image comes out.
_nsfw_labels = ["anus", "make_love", "nipple", "penis", "vagina"]

inputs = [
    gr.Image(type="filepath", label="Input Image"),
    gr.CheckboxGroup(list(_nsfw_labels), label="Input Labels"),
]

outputs = gr.Image(type="filepath", label="Output Image")
title = "EraX NSFW V1.0 Models for NSFW detection"

# Every bundled example censors all labels; each row gets its own list copy.
examples = [
    [example_path, list(_nsfw_labels)]
    for example_path in ("demo/img_1.jpg", "demo/img_2.jpg", "demo/img_3.jpg")
]

# cache_examples=True asks Gradio to cache the example outputs.
demo_app = gr.Interface(
    fn=yolov8_inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    examples=examples,
    cache_examples=True,
)
demo_app.launch(debug=True)