File size: 3,768 Bytes
c6034c4 8997b7a c6034c4 8997b7a c6034c4 11171c1 9bb8f6a c6034c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import gradio as gr
from PIL import Image
from ultralytics import YOLO
import torchvision.transforms.functional as TVF
from transformers import Owlv2VisionModel
from torch import nn
import torch
import torch.nn.functional as F
# OWLv2 classification head
class DetectorModelOwl(nn.Module):
    """Two-class classification head on top of a frozen OWLv2 vision encoder.

    The encoder's token embeddings go through dropout, layer norm, a linear
    expansion and GELU, are max-pooled over the token dimension, normalized
    again and finally projected to two logits.
    """

    owl: Owlv2VisionModel

    def __init__(self, model_path: str, dropout: float, n_hidden: int = 768):
        """Build the encoder and the classification head.

        Args:
            model_path: Identifier or path handed to
                ``Owlv2VisionModel.from_pretrained``.
            dropout: Dropout probability used by both dropout layers.
            n_hidden: Width of the encoder hidden states consumed by the
                head (the first layer norm and linear layer are sized by it).
        """
        super().__init__()

        owl = Owlv2VisionModel.from_pretrained(model_path)
        assert isinstance(owl, Owlv2VisionModel)  # narrows the type for checkers
        self.owl = owl
        # The backbone is frozen; only the head below has trainable weights.
        self.owl.requires_grad_(False)
        self.transforms = None

        self.dropout1 = nn.Dropout(dropout)
        self.ln1 = nn.LayerNorm(n_hidden, eps=1e-5)
        self.linear1 = nn.Linear(n_hidden, n_hidden * 2)
        self.act1 = nn.GELU()
        self.dropout2 = nn.Dropout(dropout)
        self.ln2 = nn.LayerNorm(n_hidden * 2, eps=1e-5)
        self.linear2 = nn.Linear(n_hidden * 2, 2)

    def forward(self, pixel_values: torch.Tensor, labels: torch.Tensor | None = None):
        """Compute class logits and, when labels are given, the loss.

        Args:
            pixel_values: Batch of preprocessed images fed to the encoder.
            labels: Optional target class indices for ``F.cross_entropy``.

        Returns:
            ``(logits,)`` when ``labels`` is None, otherwise ``(logits, loss)``.
        """
        # NOTE(review): the extent of this autocast block was reconstructed
        # (the file's indentation was lost) -- confirm it should cover the
        # whole head and not only the encoder call.
        with torch.autocast("cpu", dtype=torch.bfloat16):
            # Embed the image. Only the final hidden state is consumed, so the
            # per-layer hidden states are not requested (saves memory).
            outputs = self.owl(pixel_values=pixel_values)
            x = outputs.last_hidden_state  # B, N, C

            # Linear
            x = self.dropout1(x)
            x = self.ln1(x)
            x = self.linear1(x)
            x = self.act1(x)

            # Max-pool over the token dimension, then normalize
            x = self.dropout2(x)
            x, _ = x.max(dim=1)
            x = self.ln2(x)

            # Linear
            x = self.linear2(x)

        if labels is not None:
            loss = F.cross_entropy(x, labels)
            return (x, loss)

        return (x,)
def owl_predict(image: Image.Image) -> bool:
    """Classify *image* with the OWLv2 head and report whether class 1 wins.

    The image is pasted into the top-left corner of a gray square canvas,
    resized to 960x960, scaled to [0, 1], normalized per channel and run
    through the module-level ``model``.

    Args:
        image: Input picture as a PIL image.

    Returns:
        True when the highest-scoring class index is 1, otherwise False.
    """
    # Pad to square on a neutral gray background (image anchored top-left)
    longest_side = max(image.size)
    canvas = Image.new("RGB", (longest_side, longest_side), (128, 128, 128))
    canvas.paste(image, (0, 0))

    # Resize to 960x960
    # Bicubic performed best in the author's tests (even compared to Lanczos)
    resized = canvas.resize((960, 960), Image.BICUBIC)

    # Convert to tensor and normalize
    pixels = TVF.pil_to_tensor(resized) / 255.0
    input_image = TVF.normalize(
        pixels,
        [0.48145466, 0.4578275, 0.40821073],
        [0.26862954, 0.26130258, 0.27577711],
    )

    # Run without autograd: this is pure inference, so there is no reason to
    # record a graph through the trainable head on every request.
    with torch.inference_mode():
        logits, = model(input_image.unsqueeze(0), None)

    # Softmax is monotonic, so the argmax of the raw logits is the prediction.
    prediction = torch.argmax(logits, dim=1)
    return prediction.item() == 1
def yolo_predict(image: Image.Image, conf_threshold: float | None = None) -> Image.Image:
    """Run the YOLO detector on *image* and return the annotated picture.

    Args:
        image: Input picture as a PIL image.
        conf_threshold: Optional minimum detection confidence forwarded to
            the detector as ``conf``. When None, the detector's own default
            is used, which matches the previous behavior of this function.

    Returns:
        A PIL image with the detections drawn on it.
    """
    options = {"imgsz": 1024, "augment": True, "iou": 0.5}
    if conf_threshold is not None:
        # Only pass ``conf`` when set, so the library default is not replaced by None.
        options["conf"] = conf_threshold

    results = yolo_model(image, **options)
    assert len(results) == 1  # one input image -> exactly one result
    result = results[0]

    im_array = result.plot()
    # Reverse the channel axis (Ultralytics documents plot() as returning BGR)
    # so that PIL receives RGB.
    return Image.fromarray(im_array[..., ::-1])


def predict(image: Image.Image, conf_threshold: float):
    """Gradio callback: run both detectors on the uploaded image.

    Args:
        image: Picture supplied by the image component; None when the user
            submitted without uploading anything.
        conf_threshold: Value of the confidence slider, applied to the YOLO
            detections.

    Returns:
        A tuple of the YOLO-annotated image and the OWLv2 verdict text.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # OWLv2
    owl_prediction = owl_predict(image)
    label_owl = "Watermarked" if owl_prediction else "Not Watermarked"

    # YOLO -- the slider value was previously accepted but never used
    yolo_image = yolo_predict(image, conf_threshold)

    return yolo_image, f"OWLv2 Prediction: {label_owl}"
# Load OWLv2 classification model (dropout disabled for inference)
model = DetectorModelOwl("google/owlv2-base-patch16-ensemble", dropout=0.0)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
model.load_state_dict(
    torch.load("far5y1y5-8000.pt", map_location="cpu", weights_only=True)
)
model.eval()

# Load YOLO model
yolo_model = YOLO("yolo11x-train28-best.pt")
# NOTE(review): `gradio_app` is not referenced anywhere in this file; it is
# kept only so the module still exposes the name. The UI that is actually
# launched is `app` below.
gradio_app = gr.Blocks()

# Two-column layout: inputs on the left, detector outputs on the right.
with gr.Blocks() as app:
    gr.HTML(
        """
        <h1>Watermark Detection</h1>
        """
    )

    with gr.Row():
        with gr.Column():
            image = gr.Image(type="pil", label="Image")
            conf_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Confidence Threshold")
            btn_submit = gr.Button(value="Detect Watermarks")

        with gr.Column():
            image_yolo = gr.Image(type="pil", label="YOLO Detections")
            label_owl = gr.Label(label="OWLv2 Prediction: N/A")

    # Wire the button to the prediction callback; output order matches the
    # (image, text) tuple returned by `predict`.
    btn_submit.click(fn=predict, inputs=[image, conf_threshold], outputs=[image_yolo, label_owl])


if __name__ == "__main__":
    app.launch()