Spaces:
Running
Running
Commit
·
21a55e5
1
Parent(s):
5013d6a
✨ Two-stage AI detector (binary + generator classifier)
Browse files- .gitattributes +1 -0
- SuSy.pt +3 -0
- app.py +110 -8
- requirements.txt +5 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
SuSy.pt filter=lfs diff=lfs merge=lfs -text
|
SuSy.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fa10fae300ee2742c7a373b6c3332c2595b461954b8f5616d2d382ef2751020e
|
3 |
+
size 50810392
|
app.py
CHANGED
@@ -1,10 +1,112 @@
|
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ─── app.py ──────────────────────────────────────────────────────────────
|
2 |
import gradio as gr
|
3 |
+
import torch, numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
6 |
+
from torchvision import transforms
|
7 |
+
from skimage.feature import graycomatrix, graycoprops
|
8 |
|
9 |
+
# ───── Binary (real-vs-AI) model ─────────────────────────────────────────
|
10 |
+
BIN_CKPT = "haywoodsloan/ai-image-detector-deploy"
|
11 |
+
bin_proc = AutoImageProcessor.from_pretrained(BIN_CKPT)
|
12 |
+
bin_model = AutoModelForImageClassification.from_pretrained(BIN_CKPT)
|
13 |
+
bin_model.eval()
|
14 |
+
|
15 |
+
# ───── Generator-classifier (“SuSy”) model ───────────────────────────────
|
16 |
+
susy_model = torch.jit.load("SuSy.pt")
|
17 |
+
susy_model.eval()
|
18 |
+
GEN_CLASSES = [
|
19 |
+
'DALL·E 3', 'Stable Diffusion 1.x', 'MJ V5/V6',
|
20 |
+
'MJ V1/V2', 'Stable Diffusion XL'
|
21 |
+
] # “Authentic” intentionally omitted
|
22 |
+
|
23 |
+
# ─── Helper - classify with SuSy (patch-based) ───────────────────────────
|
24 |
+
def susy_predict(image: Image.Image, top_k_patches: int = 5, patch_size: int = 224):
|
25 |
+
w, h = image.size
|
26 |
+
npx, npy = w // patch_size, h // patch_size
|
27 |
+
if npx == 0 or npy == 0: # tiny images → resize once and exit early
|
28 |
+
image = image.resize((patch_size, patch_size), Image.LANCZOS)
|
29 |
+
npx = npy = 1
|
30 |
+
|
31 |
+
# split into patches ---------------------------------------------------
|
32 |
+
patches = np.zeros((npx * npy, patch_size, patch_size, 3), dtype=np.uint8)
|
33 |
+
for i in range(npx):
|
34 |
+
for j in range(npy):
|
35 |
+
x, y = i * patch_size, j * patch_size
|
36 |
+
patches[i * npy + j] = np.array(image.crop((x, y, x + patch_size, y + patch_size)))
|
37 |
+
|
38 |
+
# pick the most “textured” patches (GLCM contrast) ---------------------
|
39 |
+
scores = []
|
40 |
+
to_tensor = transforms.Compose([transforms.PILToTensor(), transforms.Grayscale()])
|
41 |
+
for p in patches:
|
42 |
+
g = to_tensor(Image.fromarray(p)).squeeze(0).numpy()
|
43 |
+
glcm = graycomatrix(g, [5], [0], 256, symmetric=True, normed=True)
|
44 |
+
scores.append(graycoprops(glcm, "contrast")[0, 0])
|
45 |
+
|
46 |
+
keep = patches[np.argsort(scores)[::-1][:top_k_patches]]
|
47 |
+
keep = torch.from_numpy(keep.transpose(0, 3, 1, 2)).float() / 255.0
|
48 |
+
|
49 |
+
# predict --------------------------------------------------------------
|
50 |
+
susy_model.eval()
|
51 |
+
with torch.no_grad():
|
52 |
+
probs = susy_model(keep).softmax(dim=-1).mean(dim=0).numpy()
|
53 |
+
|
54 |
+
# “Authentic” is index 0 – skip it
|
55 |
+
out = {cls: float(p) for cls, p in zip(GEN_CLASSES, probs[1:])}
|
56 |
+
return dict(sorted(out.items(), key=lambda x: x[1], reverse=True))
|
57 |
+
|
58 |
+
# ─── Two-stage pipeline ***************************************************
|
59 |
+
def pipeline(img):
|
60 |
+
if isinstance(img, np.ndarray): # (inference images arrive as ndarrays)
|
61 |
+
img = Image.fromarray(img)
|
62 |
+
|
63 |
+
# Stage 1 – binary -----------------------------------------------------
|
64 |
+
with torch.no_grad():
|
65 |
+
inp = bin_proc(images=img, return_tensors="pt")
|
66 |
+
logits = bin_model(**inp).logits
|
67 |
+
p = torch.softmax(logits, dim=-1)[0]
|
68 |
+
pred_idx = int(p.argmax())
|
69 |
+
label = bin_model.config.id2label[pred_idx]
|
70 |
+
conf = float(p[pred_idx])
|
71 |
+
|
72 |
+
if label.lower() in {"ai", "fake", "synthetic", "generated"}:
|
73 |
+
# Stage 2 – generator detection -----------------------------------
|
74 |
+
gen_probs = susy_predict(img)
|
75 |
+
return {
|
76 |
+
"binary": f"AI-generated ({conf*100:.1f} %)",
|
77 |
+
"progress": "", # hide progress line
|
78 |
+
"generator": gen_probs,
|
79 |
+
}
|
80 |
+
else:
|
81 |
+
return {
|
82 |
+
"binary": f"Authentic ({conf*100:.1f} %)",
|
83 |
+
"progress": "",
|
84 |
+
"generator": None, # hides second label
|
85 |
+
}
|
86 |
+
|
87 |
+
# ─── Gradio UI ************************************************************
|
88 |
+
with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
|
89 |
+
gr.Markdown("## 🖼️ AI Fake Detector")
|
90 |
+
gr.Markdown(
|
91 |
+
"1. Upload an image → we say **Real** or **AI**.<br>"
|
92 |
+
"2. If it’s AI-generated, we guess the **likely generator model**."
|
93 |
+
)
|
94 |
+
|
95 |
+
img_in = gr.Image(type="numpy", label="Input image")
|
96 |
+
run_btn = gr.Button("��� Run detection")
|
97 |
+
bin_out = gr.Label(label="Step 1 — Real vs AI")
|
98 |
+
progress = gr.Markdown(visible=False) # not used in final version
|
99 |
+
gen_out = gr.Label(label="Step 2 — Probable generator", visible=False)
|
100 |
+
|
101 |
+
def _on_click(img):
|
102 |
+
result = pipeline(img)
|
103 |
+
show_gen = result["generator"] is not None
|
104 |
+
return (
|
105 |
+
result["binary"],
|
106 |
+
gr.update(value=result["progress"], visible=False),
|
107 |
+
gr.update(value=result["generator"] or {}, visible=show_gen),
|
108 |
+
)
|
109 |
+
|
110 |
+
run_btn.click(_on_click, inputs=img_in, outputs=[bin_out, progress, gen_out])
|
111 |
+
|
112 |
+
demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio>=4.26.0
|
2 |
+
torch # will pull torchvision too
|
3 |
+
transformers
|
4 |
+
pillow
|
5 |
+
scikit-image
|