Sleepyriizi commited on
Commit
21a55e5
·
1 Parent(s): 5013d6a

✨ Two-stage AI detector (binary + generator classifier)

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. SuSy.pt +3 -0
  3. app.py +110 -8
  4. requirements.txt +5 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ SuSy.pt filter=lfs diff=lfs merge=lfs -text
SuSy.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa10fae300ee2742c7a373b6c3332c2595b461954b8f5616d2d382ef2751020e
3
+ size 50810392
app.py CHANGED
@@ -1,10 +1,112 @@
 
1
  import gradio as gr
 
 
 
 
 
2
 
3
- with gr.Blocks(fill_height=True) as demo:
4
- with gr.Sidebar():
5
- gr.Markdown("# Inference Provider")
6
- gr.Markdown("This Space showcases the haywoodsloan/ai-image-detector-deploy model, served by the hf-inference API. Sign in with your Hugging Face account to use this API.")
7
- button = gr.LoginButton("Sign in")
8
- gr.load("models/haywoodsloan/ai-image-detector-deploy", accept_token=button, provider="hf-inference")
9
-
10
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ─── app.py ──────────────────────────────────────────────────────────────
2
  import gradio as gr
3
+ import torch, numpy as np
4
+ from PIL import Image
5
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
6
+ from torchvision import transforms
7
+ from skimage.feature import graycomatrix, graycoprops
8
 
9
+ # ───── Binary (real-vs-AI) model ─────────────────────────────────────────
10
+ BIN_CKPT = "haywoodsloan/ai-image-detector-deploy"
11
+ bin_proc = AutoImageProcessor.from_pretrained(BIN_CKPT)
12
+ bin_model = AutoModelForImageClassification.from_pretrained(BIN_CKPT)
13
+ bin_model.eval()
14
+
15
+ # ───── Generator-classifier (“SuSy”) model ───────────────────────────────
16
+ susy_model = torch.jit.load("SuSy.pt")
17
+ susy_model.eval()
18
+ GEN_CLASSES = [
19
+ 'DALL·E 3', 'Stable Diffusion 1.x', 'MJ V5/V6',
20
+ 'MJ V1/V2', 'Stable Diffusion XL'
21
+ ] # “Authentic” intentionally omitted
22
+
23
+ # ─── Helper - classify with SuSy (patch-based) ───────────────────────────
24
+ def susy_predict(image: Image.Image, top_k_patches: int = 5, patch_size: int = 224):
25
+ w, h = image.size
26
+ npx, npy = w // patch_size, h // patch_size
27
+ if npx == 0 or npy == 0: # tiny images → resize once and exit early
28
+ image = image.resize((patch_size, patch_size), Image.LANCZOS)
29
+ npx = npy = 1
30
+
31
+ # split into patches ---------------------------------------------------
32
+ patches = np.zeros((npx * npy, patch_size, patch_size, 3), dtype=np.uint8)
33
+ for i in range(npx):
34
+ for j in range(npy):
35
+ x, y = i * patch_size, j * patch_size
36
+ patches[i * npy + j] = np.array(image.crop((x, y, x + patch_size, y + patch_size)))
37
+
38
+ # pick the most “textured” patches (GLCM contrast) ---------------------
39
+ scores = []
40
+ to_tensor = transforms.Compose([transforms.PILToTensor(), transforms.Grayscale()])
41
+ for p in patches:
42
+ g = to_tensor(Image.fromarray(p)).squeeze(0).numpy()
43
+ glcm = graycomatrix(g, [5], [0], 256, symmetric=True, normed=True)
44
+ scores.append(graycoprops(glcm, "contrast")[0, 0])
45
+
46
+ keep = patches[np.argsort(scores)[::-1][:top_k_patches]]
47
+ keep = torch.from_numpy(keep.transpose(0, 3, 1, 2)).float() / 255.0
48
+
49
+ # predict --------------------------------------------------------------
50
+ susy_model.eval()
51
+ with torch.no_grad():
52
+ probs = susy_model(keep).softmax(dim=-1).mean(dim=0).numpy()
53
+
54
+ # “Authentic” is index 0 – skip it
55
+ out = {cls: float(p) for cls, p in zip(GEN_CLASSES, probs[1:])}
56
+ return dict(sorted(out.items(), key=lambda x: x[1], reverse=True))
57
+
58
+ # ─── Two-stage pipeline ***************************************************
59
+ def pipeline(img):
60
+ if isinstance(img, np.ndarray): # (inference images arrive as ndarrays)
61
+ img = Image.fromarray(img)
62
+
63
+ # Stage 1 – binary -----------------------------------------------------
64
+ with torch.no_grad():
65
+ inp = bin_proc(images=img, return_tensors="pt")
66
+ logits = bin_model(**inp).logits
67
+ p = torch.softmax(logits, dim=-1)[0]
68
+ pred_idx = int(p.argmax())
69
+ label = bin_model.config.id2label[pred_idx]
70
+ conf = float(p[pred_idx])
71
+
72
+ if label.lower() in {"ai", "fake", "synthetic", "generated"}:
73
+ # Stage 2 – generator detection -----------------------------------
74
+ gen_probs = susy_predict(img)
75
+ return {
76
+ "binary": f"AI-generated ({conf*100:.1f} %)",
77
+ "progress": "", # hide progress line
78
+ "generator": gen_probs,
79
+ }
80
+ else:
81
+ return {
82
+ "binary": f"Authentic ({conf*100:.1f} %)",
83
+ "progress": "",
84
+ "generator": None, # hides second label
85
+ }
86
+
87
+ # ─── Gradio UI ************************************************************
88
+ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
89
+ gr.Markdown("## 🖼️ AI Fake Detector")
90
+ gr.Markdown(
91
+ "1. Upload an image → we say **Real** or **AI**.<br>"
92
+ "2. If it’s AI-generated, we guess the **likely generator model**."
93
+ )
94
+
95
+ img_in = gr.Image(type="numpy", label="Input image")
96
+ run_btn = gr.Button("��� Run detection")
97
+ bin_out = gr.Label(label="Step 1 — Real vs AI")
98
+ progress = gr.Markdown(visible=False) # not used in final version
99
+ gen_out = gr.Label(label="Step 2 — Probable generator", visible=False)
100
+
101
+ def _on_click(img):
102
+ result = pipeline(img)
103
+ show_gen = result["generator"] is not None
104
+ return (
105
+ result["binary"],
106
+ gr.update(value=result["progress"], visible=False),
107
+ gr.update(value=result["generator"] or {}, visible=show_gen),
108
+ )
109
+
110
+ run_btn.click(_on_click, inputs=img_in, outputs=[bin_out, progress, gen_out])
111
+
112
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio>=4.26.0
2
+ torch # will pull torchvision too
3
+ transformers
4
+ pillow
5
+ scikit-image