# pip uninstall nvidia_cublas_cu11
import sys
sys.path.append('..')
import os
# Install dlib at runtime in case it is missing from the environment (e.g. a hosted demo).
os.system('pip install dlib')
import torch
import numpy as np
from PIL import Image
from torch.nn import functional as F
import gradio as gr
import models_vit
from util.datasets import build_dataset
import argparse
from engine_finetune import test_all
import dlib
import cv2
from huggingface_hub import hf_hub_download
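# Frames extracted from user uploads and downloaded checkpoints are stored next to
# this script, in ./frame and ./checkpoints respectively.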
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
FRAME_SAVE_PATH = os.path.join(BASE_DIR, 'frame')
CKPT_SAVE_PATH = os.path.join(BASE_DIR, 'checkpoints')
CKPT_LIST = ['DfD Fine-tuned Checkpoint',
'DiFF(FE) Fine-tuned Checkpoint',
'DiFF(FS) Fine-tuned Checkpoint',
'DiFF(I2I) Fine-tuned Checkpoint',
'DiFF(T2I) Fine-tuned Checkpoint']
CKPT_NAME = ['DfD-checkpoint-min_val_loss.pth',
'DiFF-checkpoint-min_val_loss_FE.pth',
'DiFF-checkpoint-min_val_loss_FS.pth',
'DiFF-checkpoint-min_val_loss_I2I.pth',
'DiFF-checkpoint-min_val_loss_T2I.pth']
os.makedirs(FRAME_SAVE_PATH, exist_ok=True)
os.makedirs(CKPT_SAVE_PATH, exist_ok=True)
def get_args_parser():
parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
parser.add_argument('--batch_size', default=64, type=int,
help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
parser.add_argument('--epochs', default=50, type=int)
parser.add_argument('--accum_iter', default=1, type=int,
help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
# Model parameters
parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
help='Name of model to train')
parser.add_argument('--input_size', default=224, type=int,
help='images input size')
parser.add_argument('--normalize_from_IMN', action='store_true',
help='calculate mean and std from ImageNet; otherwise use statistics from the pretraining datasets')
parser.set_defaults(normalize_from_IMN=True)
parser.add_argument('--apply_simple_augment', action='store_true',
help='apply simple data augment')
parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
help='Drop path rate (default: 0.1)')
# Optimizer parameters
parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--weight_decay', type=float, default=0.05,
help='weight decay (default: 0.05)')
parser.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate (absolute lr)')
parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
parser.add_argument('--layer_decay', type=float, default=0.75,
help='layer-wise lr decay from ELECTRA/BEiT')
parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0')
parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR')
# Augmentation parameters
parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
help='Color jitter factor (enabled only when not using Auto/RandAug)')
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
parser.add_argument('--smoothing', type=float, default=0.1,
help='Label smoothing (default: 0.1)')
# * Random Erase params
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
# * Mixup params
parser.add_argument('--mixup', type=float, default=0,
help='mixup alpha, mixup enabled if > 0.')
parser.add_argument('--cutmix', type=float, default=0,
help='cutmix alpha, cutmix enabled if > 0.')
parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup_prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup_mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
# * Finetuning params
parser.add_argument('--finetune', default='',
help='finetune from checkpoint')
parser.add_argument('--global_pool', action='store_true')
parser.set_defaults(global_pool=True)
parser.add_argument('--cls_token', action='store_false', dest='global_pool',
help='Use class token instead of global pool for classification')
# Dataset parameters
parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
help='dataset path')
parser.add_argument('--nb_classes', default=1000, type=int,
help='number of the classification types')
parser.add_argument('--output_dir', default='',
help='path where to save, empty for no saving')
parser.add_argument('--log_dir', default='',
help='path where to tensorboard log')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='',
help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--eval', action='store_true',
help='Perform evaluation only')
parser.set_defaults(eval=True)
parser.add_argument('--dist_eval', action='store_true', default=False,
help='Enable distributed evaluation (recommended during training for faster monitoring)')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin_mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
parser.set_defaults(pin_mem=True)
# distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://',
help='url used to set up distributed training')
return parser
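# The demo reuses the MAE fine-tuning argument parser only for its defaults; a few
# fields (nb_classes, data_path, batch_size, resume) are overridden at runtime below.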
parser = get_args_parser()
args = parser.parse_args()
args.nb_classes = 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models_vit.__dict__['vit_base_patch16'](
num_classes=args.nb_classes,
drop_path_rate=args.drop_path,
global_pool=args.global_pool,
)
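# The demo hard-codes a ViT-Base/16 backbone with two classes (real / fake); weights are
# loaded lazily by load_model() when a checkpoint is selected in the dropdown, downloading
# it from the FSFM-C3/model_v1 Hub repo on first use.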
def load_model(ckpt):
if ckpt == 'None':
return gr.update()
args.resume = os.path.join(CKPT_SAVE_PATH, ckpt)
if not os.path.isfile(args.resume):
hf_hub_download(local_dir=CKPT_SAVE_PATH,
repo_id='FSFM-C3/model_v1',
filename=ckpt)
checkpoint = torch.load(args.resume, map_location='cpu')
model.load_state_dict(checkpoint['model'])
return gr.update()
def get_boundingbox(face, width, height, minsize=None):
"""
From FF++:
https://github.com/ondyari/FaceForensics/blob/master/classification/detect_from_video.py
Expects a dlib face to generate a quadratic bounding box.
:param face: dlib face class
:param width: frame width
:param height: frame height
note: the bounding box side is enlarged by a fixed scale factor of 1.3 to cover a larger face region
:param minsize: set minimum bounding box size
:return: x, y, bounding_box_size in opencv form
"""
x1 = face.left()
y1 = face.top()
x2 = face.right()
y2 = face.bottom()
size_bb = int(max(x2 - x1, y2 - y1) * 1.3)
if minsize:
if size_bb < minsize:
size_bb = minsize
center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
# Check for out of bounds, x-y top left corner
x1 = max(int(center_x - size_bb // 2), 0)
y1 = max(int(center_y - size_bb // 2), 0)
# Check for too big bb size for given x, y
size_bb = min(width - x1, size_bb)
size_bb = min(height - y1, size_bb)
return x1, y1, size_bb
def extract_face(frame):
face_detector = dlib.get_frontal_face_detector()
image = np.array(frame.convert('RGB'))
faces = face_detector(image, 1)
if len(faces) > 0:
# For now only take the first detected face
face = faces[0]
# Face crop and rescale(follow FF++)
x, y, size = get_boundingbox(face, image.shape[1], image.shape[0])
# Get the landmarks/parts for the face in box d only with the five key points
cropped_face = image[y:y + size, x:x + size]
# cropped_face = cv2.resize(cropped_face, (224, 224), interpolation=cv2.INTER_CUBIC)
return Image.fromarray(cropped_face)
else:
return None
def get_frame_index_uniform_sample(total_frame_num, extract_frame_num):
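# e.g. (illustrative): total_frame_num=100, extract_frame_num=5 -> [0, 24, 49, 74, 99]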
interval = np.linspace(0, total_frame_num - 1, num=extract_frame_num, dtype=int)
return interval.tolist()
def extract_face_from_fixed_num_frames(src_video, dst_path, num_frames=None, device='cpu'):
"""
1) extract specific num of frames from videos in [1st(index 0) frame, last frame] with uniform sample interval
2) extract face from frame with specific enlarge size
"""
video_capture = cv2.VideoCapture(src_video)
total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
# extract from the 1st(index 0) frame
if num_frames is not None:
frame_indices = get_frame_index_uniform_sample(total_frames, num_frames)
else:
frame_indices = range(int(total_frames))
for frame_index in frame_indices:
video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
ret, frame = video_capture.read()
if not ret:
continue
image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
img = extract_face(image)
if img is None:
continue
img = img.resize((224, 224), Image.BICUBIC)
save_img_name = f"frame_{frame_index}.png"
img.save(os.path.join(dst_path, '0', save_img_name))
# cv2.imwrite(os.path.join(dst_path, '0', save_img_name), frame)
video_capture.release()
# cv2.destroyAllWindows()
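# Video pipeline: sample 32 frames uniformly, crop and resize the detected face in each,
# write them to a fresh ./frame/<n>/0 folder, then evaluate the model over that folder;
# test_all returns per-frame predictions plus a video-level prediction (the exact
# aggregation of frame scores lives in engine_finetune.test_all).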
def C3_video_detection(video):
model.to(device)
# extract frames
num_frames = 32
files = os.listdir(FRAME_SAVE_PATH)
num_files = len(files)
frame_path = os.path.join(FRAME_SAVE_PATH, str(num_files))
os.makedirs(frame_path, exist_ok=True)
os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
extract_face_from_fixed_num_frames(video, frame_path, num_frames=num_frames, device=device)
args.data_path = frame_path
args.batch_size = 32
dataset_val = build_dataset(is_train=False, args=args)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=False
)
frame_preds_list, video_y_pred_list = test_all(data_loader_val, model, device)
return video_y_pred_list
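# Image pipeline: same as the video path, but with a single cropped face and batch size 1.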
def C3_image_detection(image):
model.to(device)
files = os.listdir(FRAME_SAVE_PATH)
num_files = len(files)
frame_path = os.path.join(FRAME_SAVE_PATH, str(num_files))
os.makedirs(frame_path, exist_ok=True)
os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
save_img_name = f"frame_0.png"
img = extract_face(image)
if img == None:
return ['Invalid Input']
img = img.resize((224, 224), Image.BICUBIC)
img.save(os.path.join(frame_path, '0', save_img_name))
args.data_path = frame_path
args.batch_size = 1
dataset_val = build_dataset(is_train=False, args=args)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=False
)
frame_preds_list, video_y_pred_list = test_all(data_loader_val, model, device)
return video_y_pred_list
# WebUI
with gr.Blocks() as demo:
gr.Markdown("# FSFM-C3 Demo")
gr.Markdown(
"This is a demo for deepfake detection using FSFM-C3 model."
)
gr.Markdown(
"Please provide images or video (<100s) for detection:"
)
with gr.Column():
ckpt_select_dropdown = gr.Dropdown(
label = "Setect Checkpoint to Use",
choices = ['None'] + CKPT_NAME,
multiselect = False,
value = 'None',
interactive = True,
)
with gr.Row():
with gr.Column(scale=5):
gr.Markdown(
"## Image Detection"
)
image = gr.Image(label="Upload your image", type="pil")
image_submit_btn = gr.Button("Submit")
output_results_image = gr.Textbox(label="Detection Result")
with gr.Column(scale=5):
gr.Markdown(
"## Video Detection"
)
video = gr.Video(label="Upload your video")
video_submit_btn = gr.Button("Submit")
output_results_video = gr.Textbox(label="Detection Result")
image_submit_btn.click(
fn=C3_image_detection,
inputs=[image],
outputs=[output_results_image],
)
video_submit_btn.click(
fn=C3_video_detection,
inputs=[video],
outputs=[output_results_video],
)
ckpt_select_dropdown.change(
fn=load_model,
inputs=[ckpt_select_dropdown],
outputs=[ckpt_select_dropdown],
)
if __name__ == "__main__":
gr.close_all()
demo.queue()
demo.launch()