Shiwanni commited on
Commit
87bc3c0
·
verified ·
1 Parent(s): 26fa5ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +227 -70
app.py CHANGED
@@ -1,31 +1,34 @@
1
  import gradio as gr
2
  import torch
 
 
3
  from transformers import AutoImageProcessor, AutoModelForImageClassification
4
  from PIL import Image
 
 
 
5
 
6
- # Models specialized in fake detection
7
  MODEL_NAMES = {
8
- "Face Fake Detector": "elgeish/cvpr2023-deepfake-detection",
9
- "General Fake Detector": "dima806/deepfake_vs_real_image_detection",
10
- "AI-Generated Detector": "rizvandwiki/gansfake-detector"
11
  }
12
 
13
  # Initialize models
14
  models = {}
15
  processors = {}
16
 
17
- print("Loading models...")
18
  for name, path in MODEL_NAMES.items():
19
  try:
20
  processors[name] = AutoImageProcessor.from_pretrained(path)
21
  models[name] = AutoModelForImageClassification.from_pretrained(path)
22
- print(f"Loaded: {name}")
23
- except Exception as e:
24
- print(f"Error loading {name}: {str(e)}")
25
 
26
  def analyze_image(image, selected_model):
27
  if image is None:
28
- return "Please upload an image first", None
29
 
30
  try:
31
  # Convert to RGB if needed
@@ -37,104 +40,258 @@ def analyze_image(image, selected_model):
37
  processor = processors.get(selected_model)
38
 
39
  if not model or not processor:
40
- return "Selected model not available", None
41
 
42
- # Preprocess and predict
43
  inputs = processor(images=image, return_tensors="pt")
 
 
44
  with torch.no_grad():
45
  outputs = model(**inputs)
46
 
47
- # Get predicted class
48
- predicted_class = model.config.id2label[torch.argmax(outputs.logits).item()]
 
 
 
 
49
 
50
- # Simple fake/real determination
51
- if any(word in predicted_class.lower() for word in ['fake', 'generated', 'ai', 'synthetic']):
52
- return "🛑 FAKE IMAGE DETECTED", "fake"
53
- else:
54
- return "✅ REAL IMAGE", "real"
55
 
56
  except Exception as e:
57
- return f"Error: {str(e)}", None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- # Custom CSS for simple interface
60
  custom_css = """
61
  :root {
62
- --real-color: #4CAF50;
63
- --fake-color: #F44336;
 
 
 
64
  }
65
 
66
  #main-container {
67
- max-width: 800px;
68
  margin: auto;
69
- padding: 20px;
 
 
 
70
  }
71
 
72
- .real-result {
73
- color: var(--real-color);
74
- font-size: 24px;
75
- font-weight: bold;
76
  text-align: center;
77
- padding: 20px;
78
- border: 3px solid var(--real-color);
79
- border-radius: 10px;
 
 
80
  }
81
 
82
- .fake-result {
83
- color: var(--fake-color);
84
- font-size: 24px;
85
- font-weight: bold;
86
- text-align: center;
 
 
 
 
 
 
 
 
87
  padding: 20px;
88
- border: 3px solid var(--fake-color);
89
- border-radius: 10px;
90
- animation: pulse 1s infinite;
 
 
 
91
  }
92
 
93
- @keyframes pulse {
94
- 0% { opacity: 0.8; }
95
- 50% { opacity: 1; }
96
- 100% { opacity: 0.8; }
 
 
 
 
 
 
 
 
 
 
 
97
  }
98
 
99
- .upload-box {
100
- border: 2px dashed #666;
101
- padding: 30px;
 
 
 
 
 
102
  text-align: center;
103
- border-radius: 10px;
104
- margin-bottom: 20px;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  }
106
  """
107
 
108
- with gr.Blocks(css=custom_css) as demo:
109
  with gr.Column(elem_id="main-container"):
110
- gr.Markdown("# 🔍 Fake Image Detector")
111
- gr.Markdown("Upload an image to check if it's real or fake")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
- with gr.Column(elem_classes=["upload-box"]):
114
- image_input = gr.Image(type="pil", label="Upload Image")
115
- model_selector = gr.Dropdown(
116
- choices=list(MODEL_NAMES.keys()),
117
- value=list(MODEL_NAMES.keys())[0],
118
- label="Select Detection Model"
119
  )
120
- analyze_btn = gr.Button("Check Image", variant="primary")
121
 
122
- # Will be updated with the appropriate class
123
- result_output = gr.Label(label="Result", elem_classes=["result-placeholder"])
124
-
125
- def update_result_class(result, status):
126
- if status == "fake":
127
- return gr.Label.update(value=result, elem_classes=["fake-result"])
128
- elif status == "real":
129
- return gr.Label.update(value=result, elem_classes=["real-result"])
130
- else:
131
- return gr.Label.update(value=result)
 
 
 
132
 
 
133
  analyze_btn.click(
134
  fn=analyze_image,
135
  inputs=[image_input, model_selector],
136
- outputs=[result_output]
137
  )
138
 
139
  if __name__ == "__main__":
140
- demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
  from transformers import AutoImageProcessor, AutoModelForImageClassification
6
  from PIL import Image
7
+ import cv2
8
+ from skimage import exposure
9
+ import time
10
 
11
+ # Load models (using free Hugging Face models)
12
  MODEL_NAMES = {
13
+ "Model 1": "dima806/deepfake_vs_real_image_detection",
14
+ "Model 2": "saltacc/anime-ai-detect",
15
+ "Model 3": "rizvandwiki/gansfake-detector"
16
  }
17
 
18
  # Initialize models
19
  models = {}
20
  processors = {}
21
 
 
22
  for name, path in MODEL_NAMES.items():
23
  try:
24
  processors[name] = AutoImageProcessor.from_pretrained(path)
25
  models[name] = AutoModelForImageClassification.from_pretrained(path)
26
+ except:
27
+ print(f"Could not load model: {name}")
 
28
 
29
  def analyze_image(image, selected_model):
30
  if image is None:
31
+ return None, None, "Please upload an image first", None
32
 
33
  try:
34
  # Convert to RGB if needed
 
40
  processor = processors.get(selected_model)
41
 
42
  if not model or not processor:
43
+ return None, None, "Selected model not available", None
44
 
45
+ # Preprocess image
46
  inputs = processor(images=image, return_tensors="pt")
47
+
48
+ # Predict
49
  with torch.no_grad():
50
  outputs = model(**inputs)
51
 
52
+ # Get probabilities
53
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
54
+
55
+ # Create visualizations
56
+ heatmap = generate_heatmap(image, model, processor)
57
+ chart_fig = create_probability_chart(probs, model.config.id2label)
58
 
59
+ # Format results
60
+ result_text = format_results(probs, model.config.id2label)
61
+
62
+ return heatmap, chart_fig, result_text, create_model_info(selected_model)
 
63
 
64
  except Exception as e:
65
+ return None, None, f"Error: {str(e)}", None
66
+
67
+ def generate_heatmap(image, model, processor):
68
+ """Generate a heatmap showing important regions for the prediction"""
69
+ # Convert to numpy array
70
+ img_array = np.array(image)
71
+
72
+ # Create a saliency map (simple version)
73
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
74
+ blurred = cv2.GaussianBlur(gray, (21, 21), 0)
75
+ heatmap = cv2.applyColorMap(blurred, cv2.COLORMAP_JET)
76
+
77
+ # Blend with original image
78
+ heatmap = cv2.addWeighted(img_array, 0.7, heatmap, 0.3, 0)
79
+
80
+ # Convert back to PIL
81
+ return Image.fromarray(heatmap)
82
+
83
+ def create_probability_chart(probs, id2label):
84
+ """Create a bar chart of class probabilities"""
85
+ labels = [id2label[i] for i in range(len(probs))]
86
+ colors = ['#4CAF50' if 'real' in label.lower() else '#F44336' for label in labels]
87
+
88
+ fig, ax = plt.subplots(figsize=(8, 4))
89
+ bars = ax.barh(labels, probs.numpy(), color=colors)
90
+ ax.set_xlim(0, 1)
91
+ ax.set_title('Detection Probabilities', pad=20)
92
+ ax.set_xlabel('Probability')
93
+
94
+ # Add value labels
95
+ for bar in bars:
96
+ width = bar.get_width()
97
+ ax.text(width + 0.02, bar.get_y() + bar.get_height()/2,
98
+ f'{width:.2f}',
99
+ va='center')
100
+
101
+ plt.tight_layout()
102
+ return fig
103
+
104
+ def format_results(probs, id2label):
105
+ """Format the results as text"""
106
+ results = []
107
+ for i, prob in enumerate(probs):
108
+ results.append(f"{id2label[i]}: {prob*100:.1f}%")
109
+
110
+ max_prob = max(probs)
111
+ max_class = id2label[torch.argmax(probs).item()]
112
+
113
+ if 'real' in max_class.lower():
114
+ conclusion = f"Conclusion: This image appears to be AUTHENTIC with {max_prob*100:.1f}% confidence"
115
+ else:
116
+ conclusion = f"Conclusion: This image appears to be FAKE/GENERATED with {max_prob*100:.1f}% confidence"
117
+
118
+ return "\n".join([conclusion, "", "Detailed probabilities:"] + results)
119
+
120
+ def create_model_info(model_name):
121
+ """Create information about the current model"""
122
+ info = {
123
+ "Model 1": "Trained to detect deepfakes vs real human faces",
124
+ "Model 2": "Specialized in detecting AI-generated anime images",
125
+ "Model 3": "General GAN-generated image detector"
126
+ }
127
+ return info.get(model_name, "No information available about this model")
128
 
129
+ # Custom CSS for the interface
130
  custom_css = """
131
  :root {
132
+ --primary: #4b6cb7;
133
+ --secondary: #182848;
134
+ --authentic: #4CAF50;
135
+ --fake: #F44336;
136
+ --neutral: #2196F3;
137
  }
138
 
139
  #main-container {
140
+ max-width: 1200px;
141
  margin: auto;
142
+ padding: 25px;
143
+ border-radius: 15px;
144
+ background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
145
+ box-shadow: 0 8px 32px rgba(0,0,0,0.1);
146
  }
147
 
148
+ .header {
 
 
 
149
  text-align: center;
150
+ margin-bottom: 25px;
151
+ background: linear-gradient(90deg, var(--primary) 0%, var(--secondary) 100%);
152
+ -webkit-background-clip: text;
153
+ -webkit-text-fill-color: transparent;
154
+ padding: 10px;
155
  }
156
 
157
+ .upload-area {
158
+ border: 3px dashed var(--primary) !important;
159
+ min-height: 300px;
160
+ border-radius: 12px !important;
161
+ transition: all 0.3s ease;
162
+ }
163
+
164
+ .upload-area:hover {
165
+ border-color: var(--secondary) !important;
166
+ transform: translateY(-2px);
167
+ }
168
+
169
+ .result-box {
170
  padding: 20px;
171
+ border-radius: 12px;
172
+ margin-top: 20px;
173
+ font-size: 1.1em;
174
+ transition: all 0.3s ease;
175
+ box-shadow: 0 4px 6px rgba(0,0,0,0.1);
176
+ background: white;
177
  }
178
 
179
+ .visualization-box {
180
+ border-radius: 12px;
181
+ padding: 15px;
182
+ background: white;
183
+ margin-top: 15px;
184
+ box-shadow: 0 4px 6px rgba(0,0,0,0.1);
185
+ }
186
+
187
+ .btn-primary {
188
+ background: linear-gradient(90deg, var(--primary) 0%, var(--secondary) 100%) !important;
189
+ color: white !important;
190
+ border: none !important;
191
+ padding: 12px 24px !important;
192
+ border-radius: 8px !important;
193
+ font-weight: bold !important;
194
  }
195
 
196
+ .model-select {
197
+ background: white !important;
198
+ border: 2px solid var(--primary) !important;
199
+ border-radius: 8px !important;
200
+ padding: 8px 12px !important;
201
+ }
202
+
203
+ .footer {
204
  text-align: center;
205
+ margin-top: 20px;
206
+ font-size: 0.9em;
207
+ color: #666;
208
+ }
209
+
210
+ @keyframes fadeIn {
211
+ from { opacity: 0; transform: translateY(10px); }
212
+ to { opacity: 1; transform: translateY(0); }
213
+ }
214
+
215
+ .animation {
216
+ animation: fadeIn 0.5s ease-in-out;
217
+ }
218
+
219
+ .loading {
220
+ animation: pulse 1.5s infinite;
221
+ }
222
+
223
+ @keyframes pulse {
224
+ 0% { opacity: 0.6; }
225
+ 50% { opacity: 1; }
226
+ 100% { opacity: 0.6; }
227
  }
228
  """
229
 
230
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
231
  with gr.Column(elem_id="main-container"):
232
+ with gr.Column(elem_classes=["header"]):
233
+ gr.Markdown("# 🛡️ DeepGuard AI")
234
+ gr.Markdown("## Advanced Deepfake Detection System")
235
+
236
+ with gr.Row():
237
+ with gr.Column(scale=1.5):
238
+ image_input = gr.Image(
239
+ type="pil",
240
+ label="Upload Image for Analysis",
241
+ elem_classes=["upload-area", "animation"]
242
+ )
243
+
244
+ with gr.Row():
245
+ model_selector = gr.Dropdown(
246
+ choices=list(MODEL_NAMES.keys()),
247
+ value=list(MODEL_NAMES.keys())[0],
248
+ label="Select Detection Model",
249
+ elem_classes=["model-select", "animation"]
250
+ )
251
+ analyze_btn = gr.Button(
252
+ "Analyze Image",
253
+ elem_classes=["btn-primary", "animation"]
254
+ )
255
+
256
+ with gr.Column(scale=1):
257
+ with gr.Column(elem_classes=["visualization-box"]):
258
+ heatmap_output = gr.Image(
259
+ label="Attention Heatmap",
260
+ interactive=False
261
+ )
262
+
263
+ with gr.Column(elem_classes=["visualization-box"]):
264
+ chart_output = gr.Plot(
265
+ label="Detection Probabilities"
266
+ )
267
 
268
+ with gr.Column(elem_classes=["result-box", "animation"]):
269
+ result_output = gr.Textbox(
270
+ label="Analysis Results",
271
+ interactive=False,
272
+ lines=8
 
273
  )
 
274
 
275
+ with gr.Column(elem_classes=["result-box", "animation"]):
276
+ model_info = gr.Textbox(
277
+ label="Model Information",
278
+ interactive=False,
279
+ lines=3
280
+ )
281
+
282
+ gr.Markdown("""
283
+ <div class="footer">
284
+ *Note: This tool provides probabilistic estimates. Always verify important findings with additional methods.<br>
285
+ Models may produce false positives/negatives. Performance varies by image type and quality.*
286
+ </div>
287
+ """)
288
 
289
+ # Event handlers
290
  analyze_btn.click(
291
  fn=analyze_image,
292
  inputs=[image_input, model_selector],
293
+ outputs=[heatmap_output, chart_output, result_output, model_info]
294
  )
295
 
296
  if __name__ == "__main__":
297
+ demo.launch()