EraX-NSFW-V1.0 / app.py
erax's picture
Update app.py
d2b29d9 verified
raw
history blame
3.39 kB
import gradio as gr
import spaces
import torch
from ultralytics import YOLO
from PIL import Image
import supervision as sv
import numpy as np
@spaces.GPU
def yolov8_inference(
    image,
    image_size,
    conf_threshold,
    iou_threshold,
):
    """Detect NSFW regions with the EraX ONNX model, pixelate and label them.

    Args:
        image: Input image as delivered by the Gradio input component
            (``type="filepath"``), or ``None`` when nothing was uploaded.
        image_size: Value of the "Image Size" slider. NOTE(review): this value
            is not forwarded to the model, so inference runs at the model's own
            input size. Forwarding it (``imgsz=image_size``) only works if the
            ONNX file was exported with dynamic input shapes -- TODO confirm
            before wiring it up.
        conf_threshold: Minimum confidence a detection needs to be kept.
        iou_threshold: IoU threshold used by non-maximum suppression.

    Returns:
        The annotated image with its channel order reversed (BGR -> RGB) for
        Gradio, or ``None`` if no image was provided.
    """
    if image is None:
        # Nothing uploaded: skip inference instead of failing inside Ultralytics.
        return None

    # Note: the ONNX model is loaded on every call.
    model = YOLO('erax_nsfw_v1.onnx')
    # Pass settings as predict() keyword arguments rather than writing them
    # into model.overrides. predict() merges its own defaults (conf=0.25)
    # *after* overrides, so an overridden 'conf' was silently replaced and
    # the confidence slider had no effect.
    results = model(
        [image],
        conf=conf_threshold,  # confidence threshold
        iou=iou_threshold,  # NMS IoU threshold
        agnostic_nms=False,  # per-class (not class-agnostic) NMS
        max_det=1000,  # maximum number of detections per image
    )

    # A single image was submitted, so there is exactly one result.
    result = results[0]
    annotated_image = result.orig_img.copy()
    h, w = annotated_image.shape[:2]
    # Label text size and pixel-block size both scale with the longer side.
    anchor = max(h, w)

    # Class 1 (make_love) is excluded: its box covers the entire scene.
    # All classes would be [0, 1, 2, 3, 4, 5].
    selected_classes = [0, 2, 3, 4, 5]
    detections = sv.Detections.from_ultralytics(result)
    detections = detections[np.isin(detections.class_id, selected_classes)]

    # Pixelate first, then draw the labels. Labels sit at the box CENTER,
    # i.e. inside the pixelated region, so drawing them first erased them.
    pixelate_annotator = sv.PixelateAnnotator(pixel_size=anchor / 50)
    annotated_image = pixelate_annotator.annotate(
        scene=annotated_image.copy(),
        detections=detections,
    )
    label_annotator = sv.LabelAnnotator(
        text_color=sv.Color.BLACK,
        text_position=sv.Position.CENTER,
        text_scale=anchor / 1700,
    )
    annotated_image = label_annotator.annotate(
        annotated_image,
        detections=detections,
    )
    # Reverse the channel order (BGR -> RGB) before handing the array to Gradio.
    return annotated_image[:, :, ::-1]
# Input widgets, in the same order as yolov8_inference's parameters.
inputs = [
    gr.Image(type="filepath", label="Input Image"),
    gr.Slider(minimum=320, maximum=1280, value=640, step=320, label="Image Size"),
    gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="Confidence Threshold"),
    gr.Slider(minimum=0.0, maximum=1.0, value=0.45, step=0.05, label="IOU Threshold"),
]
# yolov8_inference returns a numpy image array, so declare the output as numpy.
outputs = gr.Image(type="numpy", label="Output Image")
title = "State-of-the-Art YOLO Models for Object detection"

demo_app = gr.Interface(
    fn=yolov8_inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
)

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    demo_app.launch(debug=True)