Sleepyriizi commited on
Commit
2898702
Β·
1 Parent(s): 38e5a7e

final CAM layers fixed

Browse files
Files changed (3) hide show
  1. 0.7 +0 -0
  2. app.py +97 -107
  3. requirements.txt +4 -1
0.7 ADDED
File without changes
app.py CHANGED
@@ -1,156 +1,146 @@
1
- # app.py ────────────────────────────────────────────────────────────────
2
  """
3
- Two‑stage local AI‑image detector
4
- 1. haywoodsloan/ai-image-detector-deploy β†’ Real vsΒ AI (Swin‑V2)
5
- 2. SuSy.pt β†’ Likely generator (ResNet‑based)
6
 
7
- Includes Grad‑CAM overlays:
8
- β€’ always show heat‑map for binary decision
9
- β€’ if image is flagged AI, also show heat‑map for SuSy
 
10
  """
11
 
12
- import gradio as gr
13
- import numpy as np, torch, pandas as pd, matplotlib.pyplot as plt
14
  from PIL import Image
15
  from torchvision import transforms
16
- from skimage.feature import graycomatrix, graycoprops
17
  from transformers import AutoImageProcessor, AutoModelForImageClassification
18
  from torchcam.methods import GradCAM
 
 
 
19
 
20
- # ──────────── Stage‑1 model (binary) ────────────────────────────────
21
- BIN_ID = "haywoodsloan/ai-image-detector-deploy"
22
- bin_proc = AutoImageProcessor.from_pretrained(BIN_ID)
23
- bin_model = AutoModelForImageClassification.from_pretrained(BIN_ID)
24
- bin_model.eval()
25
 
26
- CAM_LAYER_BIN = "encoder.layers.3.blocks.1.layernorm_after" # <- from dump
 
 
 
 
27
 
28
- # ──────────── Stage‑2 model (SuSy) ──────────────────────────────────
29
- susy_model = torch.jit.load("SuSy.pt").eval()
30
- CAM_LAYER_SUSY = "feature_extractor.resnet_model.layer4.1.relu" # <- from dump
31
 
32
- GEN_CLASSES = [
33
- "Stable Diffusion 1.x", "DALLΒ·E 3", "MJ V5/V6",
34
- "Stable Diffusion XL", "MJ V1/V2",
35
- ]
36
  PATCH, TOP = 224, 5
37
 
38
- # ──────────── Heat‑map helper ───────────────────────────────────────
39
- def grad_cam_overlay(model, inputs, target_layer, class_idx, orig_pil):
40
- # prepare Grad‑CAM extractor
41
- cam_ex = GradCAM(model, target_layer=target_layer,
42
- input_shape=next(iter(inputs.values()) if isinstance(inputs, dict) else [inputs]).shape)
43
-
44
- # forward & backward
45
- scores = model(**inputs).logits if isinstance(inputs, dict) else model(inputs)
46
- scores[0, class_idx].backward()
47
-
48
- # normalise cam
49
- mask = cam_ex(class_idx)[0].cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-6)
51
- mask = Image.fromarray(np.uint8(plt.cm.jet(mask)[:, :, :3] * 255)).resize(orig_pil.size, Image.BICUBIC)
52
- return Image.blend(orig_pil.convert("RGBA"), mask.convert("RGBA"), alpha=0.45)
 
53
 
54
- # ──────────── SuSy helper ───────────────────────────────────────────
55
- to_tensor = transforms.ToTensor()
56
- to_gray_pil = transforms.Compose([transforms.PILToTensor(), transforms.Grayscale()])
57
 
58
- def susy_predict(img: Image.Image) -> dict:
59
  w, h = img.size
60
  npx, npy = max(1, w // PATCH), max(1, h // PATCH)
61
- patches = np.zeros((npx * npy, PATCH, PATCH, 3), dtype=np.uint8)
62
 
63
  for i in range(npx):
64
  for j in range(npy):
65
  x, y = i * PATCH, j * PATCH
66
- patches[i * npy + j] = np.array(img.crop((x, y, x + PATCH, y + PATCH)).resize((PATCH, PATCH)))
67
 
68
  contrasts = []
69
  for p in patches:
70
- g = to_gray_pil(Image.fromarray(p)).squeeze(0).numpy()
71
  glcm = graycomatrix(g, [5], [0], 256, symmetric=True, normed=True)
72
  contrasts.append(graycoprops(glcm, "contrast")[0, 0])
73
 
74
- idx = np.argsort(contrasts)[::-1][:TOP]
75
- tensor = torch.from_numpy(patches[idx].transpose(0, 3, 1, 2)).float() / 255.0
76
-
77
  with torch.no_grad():
78
- probs = susy_model(tensor).softmax(-1).mean(0).numpy()[1:]
79
  return dict(zip(GEN_CLASSES, probs))
80
 
81
- # ──────────── Main pipeline ─────────────────────────────────────────
82
  def pipeline(img_arr):
83
  img = Image.fromarray(img_arr) if isinstance(img_arr, np.ndarray) else img_arr
84
  heatmaps = []
85
 
86
- # Stage‑1: binary Real/AI
87
  with torch.no_grad():
88
- inp_bin = bin_proc(images=img, return_tensors="pt")
89
- logits = bin_model(**inp_bin).logits
90
- probs = torch.softmax(logits, -1)[0].tolist() # [artificial, real]
91
-
92
- ai_conf, real_conf = probs[0], probs[1]
93
-
94
- # Grad‑CAM for winning class
95
- class_idx = 0 if ai_conf >= real_conf else 1
96
- heatmaps.append(
97
- grad_cam_overlay(
98
- bin_model, inp_bin,
99
- target_layer=CAM_LAYER_BIN,
100
- class_idx=class_idx,
101
- orig_pil=img
102
- )
103
- )
104
-
105
- # defaults
106
- msg, bar_df, bar_vis = f"Authentic ({real_conf*100:.1f} %)", None, False
107
-
108
- # Stage‑2 if AI
109
  if ai_conf > real_conf:
110
- msg = f"AI‑generated ({ai_conf*100:.1f} %)"
111
  gen_probs = susy_predict(img)
112
- bar_df = pd.DataFrame({"class": gen_probs.keys(), "prob": gen_probs.values()})
113
- bar_vis = True
114
 
115
- # SuSy heat‑map: choose most‑probable generator class
116
  with torch.no_grad():
117
- t_inp = to_tensor(img.resize((224, 224))).unsqueeze(0)
118
- logits_susy = susy_model(t_inp)
119
- susy_class = logits_susy[0, 1:].argmax().item() + 1 # skip 'real'
120
-
121
- heatmaps.append(
122
- grad_cam_overlay(
123
- susy_model, t_inp,
124
- target_layer=CAM_LAYER_SUSY,
125
- class_idx=susy_class,
126
- orig_pil=img
127
- )
128
- )
129
-
130
- return msg, gr.update(value=bar_df, visible=bar_vis), heatmaps
131
-
132
- # ──────────── Gradio UI ─────────────────────────────────────────────
133
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
134
- gr.Markdown("## πŸ–ΌοΈ Local AI Fake Detector")
135
 
 
 
 
136
  with gr.Row():
137
  img_in = gr.Image(type="numpy", label="Upload image")
138
  btn = gr.Button("Detect")
139
 
140
- txt_bin = gr.Textbox(label="StepΒ 1β€―β€”β€―Realβ€―vsβ€―AI", interactive=False)
141
-
142
- bar_gen = gr.BarPlot(
143
- x="class", y="prob",
144
- title="Stepβ€―2β€―β€”β€―Likely generator",
145
- y_label="probability",
146
- visible=False
147
- )
148
-
149
- gal_cam = gr.Gallery(
150
- label="Model attention heat‑maps",
151
- columns=2, height=300, visible=True
152
- )
153
 
154
- btn.click(pipeline, inputs=img_in, outputs=[txt_bin, bar_gen, gal_cam])
155
 
156
  demo.launch()
 
1
+ # app.py ───────────────────────────────────────────────────────────────
2
  """
3
+ Two‑stage AI‑image detector with visual explainability
 
 
4
 
5
+ Stage‑1 : haywoodsloan/ai-image-detector-deploy (Swin‑V2) β†’ RealΒ vsΒ AI
6
+ ⟳ Grad‑CAM (torchcam) overlay
7
+ Stage‑2 : SuSy.pt (torchscript ResNet) β†’ Generator
8
+ ⟳ Saliency‑grad overlay (Captum), because hooks are disabled
9
  """
10
 
11
+ # ───────────────────── Imports ────────────────────────────────────────
12
+ import torch, numpy as np, pandas as pd, matplotlib.pyplot as plt
13
  from PIL import Image
14
  from torchvision import transforms
 
15
  from transformers import AutoImageProcessor, AutoModelForImageClassification
16
  from torchcam.methods import GradCAM
17
+ from captum.attr import Saliency
18
+ from skimage.feature import graycomatrix, graycoprops
19
+ import gradio as gr
20
 
21
+ # ─────────────────── Runtime / models ─────────────────────────────────
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ plt.set_loglevel("ERROR")
 
 
24
 
25
+ # Stage‑1 (eager)
26
+ BIN_ID = "haywoodsloan/ai-image-detector-deploy"
27
+ bin_proc = AutoImageProcessor.from_pretrained(BIN_ID)
28
+ bin_mod = AutoModelForImageClassification.from_pretrained(BIN_ID).to(device).eval()
29
+ CAM_LAYER_BIN = "encoder.layers.3.blocks.1.layernorm_after"
30
 
31
+ # Stage‑2 (scripted)
32
+ susy_mod = torch.jit.load("SuSy.pt").to(device).eval() # ScriptModule
33
+ CAM_LAYER_SUSY = "feature_extractor.resnet_model.layer4.1.relu"
34
 
35
+ GEN_CLASSES = ["Stable Diffusion 1.x", "DALLΒ·E 3",
36
+ "MJ V5/V6", "Stable Diffusion XL", "MJ V1/V2"]
 
 
37
  PATCH, TOP = 224, 5
38
 
39
+ # ─────────────── Universal overlay helper ─────────────────────────────
40
+ def overlay_explanation(model, model_inputs, target_layer, class_idx, base_img):
41
+ """
42
+ β€’ If model is eager (supports hooks) β†’ Grad‑CAM via torchcam
43
+ β€’ If model is ScriptModule β†’ absolute‑gradient saliency via Captum
44
+ Returns an RGBA PIL image blended with the heat‑map.
45
+ """
46
+ is_script = isinstance(model, torch.jit.ScriptModule)
47
+
48
+ # Prepare inputs for forward
49
+ forward_inputs = model_inputs if torch.is_tensor(model_inputs) else dict(model_inputs)
50
+
51
+ # ---------- Scripted: Captum Saliency ----------
52
+ if is_script:
53
+ model.zero_grad(set_to_none=True)
54
+ sal = Saliency(model)
55
+ if not torch.is_tensor(forward_inputs):
56
+ forward_inputs = forward_inputs["pixel_values"]
57
+ grads = sal.attribute(forward_inputs, target=class_idx).abs().mean(1, keepdim=True)
58
+ mask = grads.squeeze().detach().cpu().numpy()
59
+ # ---------- Eager: torchcam Grad‑CAM ----------
60
+ else:
61
+ mods = dict(model.named_modules())
62
+ tgt = mods.get(target_layer) or next(m for n, m in mods.items() if n.endswith(target_layer))
63
+ cam = GradCAM(model, target_layer=tgt)
64
+ outputs = (model(forward_inputs) if torch.is_tensor(forward_inputs)
65
+ else model(**forward_inputs))
66
+ logits = outputs.logits if hasattr(outputs, "logits") else outputs
67
+ mask = cam(class_idx, logits)[0].detach().cpu().numpy()
68
+
69
+ # normalise & overlay
70
  mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-6)
71
+ heat = Image.fromarray((plt.cm.jet(mask)[:, :, :3] * 255).astype(np.uint8))\
72
+ .resize(base_img.size, Image.BICUBIC)
73
+ return Image.blend(base_img.convert("RGBA"), heat.convert("RGBA"), alpha=0.45)
74
 
75
+ # ───────────── SuSy patch‑ranking helper ──────────────────────────────
76
+ to_tensor = transforms.ToTensor()
77
+ to_gray = transforms.Compose([transforms.PILToTensor(), transforms.Grayscale()])
78
 
79
+ def susy_predict(img: Image.Image):
80
  w, h = img.size
81
  npx, npy = max(1, w // PATCH), max(1, h // PATCH)
82
+ patches = np.zeros((npx * npy, PATCH, PATCH, 3), dtype=np.uint8)
83
 
84
  for i in range(npx):
85
  for j in range(npy):
86
  x, y = i * PATCH, j * PATCH
87
+ patches[i*npy+j] = np.array(img.crop((x, y, x+PATCH, y+PATCH)).resize((PATCH, PATCH)))
88
 
89
  contrasts = []
90
  for p in patches:
91
+ g = to_gray(Image.fromarray(p)).squeeze(0).numpy()
92
  glcm = graycomatrix(g, [5], [0], 256, symmetric=True, normed=True)
93
  contrasts.append(graycoprops(glcm, "contrast")[0, 0])
94
 
95
+ idx = np.argsort(contrasts)[::-1][:TOP]
96
+ tens = torch.from_numpy(patches[idx].transpose(0,3,1,2)).float()/255.0
 
97
  with torch.no_grad():
98
+ probs = susy_mod(tens.to(device)).softmax(-1).mean(0).cpu().numpy()[1:]
99
  return dict(zip(GEN_CLASSES, probs))
100
 
101
+ # ───────────────────── Pipeline ───────────────────────────────────────
102
  def pipeline(img_arr):
103
  img = Image.fromarray(img_arr) if isinstance(img_arr, np.ndarray) else img_arr
104
  heatmaps = []
105
 
106
+ # Stage‑1
107
  with torch.no_grad():
108
+ inp_bin = bin_proc(images=img, return_tensors="pt").to(device)
109
+ logits = bin_mod(**inp_bin).logits.softmax(-1)[0] # [AI, Real]
110
+
111
+ ai_conf, real_conf = logits
112
+ winner_idx = 0 if ai_conf >= real_conf else 1
113
+ heatmaps.append(overlay_explanation(bin_mod, inp_bin, CAM_LAYER_BIN, winner_idx, img))
114
+
115
+ verdict = f"Authentic ({real_conf*100:.1f} %)"
116
+ bar_df, show_bar = None, False
117
+
118
+ # Stage‑2 (only if AI)
 
 
 
 
 
 
 
 
 
 
119
  if ai_conf > real_conf:
120
+ verdict = f"AI‑generated ({ai_conf*100:.1f} %)"
121
  gen_probs = susy_predict(img)
122
+ bar_df = pd.DataFrame({"class": gen_probs.keys(), "prob": gen_probs.values()})
123
+ show_bar = True
124
 
 
125
  with torch.no_grad():
126
+ susy_in = to_tensor(img.resize((224,224))).unsqueeze(0).to(device)
127
+ g_idx = susy_mod(susy_in)[0,1:].argmax().item() + 1
128
+ heatmaps.append(overlay_explanation(susy_mod, susy_in, CAM_LAYER_SUSY, g_idx, img))
129
+
130
+ return verdict, gr.update(value=bar_df, visible=show_bar), heatmaps
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
+ # ───────────────────────── UI ─────────────────────────────────────────
133
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
134
+ gr.Markdown("## πŸ–ΌοΈ Two‑Stage AI Fake DetectorΒ β€”Β Explained with Heat‑maps")
135
  with gr.Row():
136
  img_in = gr.Image(type="numpy", label="Upload image")
137
  btn = gr.Button("Detect")
138
 
139
+ txt_out = gr.Textbox(label="Verdict", interactive=False)
140
+ bar_out = gr.BarPlot(x="class", y="prob", title="Likely generator",
141
+ y_label="probability", visible=False)
142
+ gal_out = gr.Gallery(label="Heat‑maps", columns=2, height=320)
 
 
 
 
 
 
 
 
 
143
 
144
+ btn.click(pipeline, inputs=img_in, outputs=[txt_out, bar_out, gal_out])
145
 
146
  demo.launch()
requirements.txt CHANGED
@@ -8,4 +8,7 @@ pydantic==2.10.6
8
  wheel
9
  huggingface_hub>=0.22
10
  pandas
11
- torchcam>=0.4
 
 
 
 
8
  wheel
9
  huggingface_hub>=0.22
10
  pandas
11
+ torchcam>=0.4
12
+ matplotlib>=3.8
13
+ timm>=0.9.12
14
+ captum>=0.7