# import all the libraries
import math
import numpy as np
import scipy.stats  # import the submodule explicitly; "import scipy" alone doesn't guarantee scipy.stats is loaded
from PIL import Image
import torch
import torchvision.transforms as tforms
from diffusers import DiffusionPipeline, UNet2DConditionModel, DDIMScheduler, DDIMInverseScheduler
from diffusers.models import AutoencoderKL
import gradio as gr
# load SDXL pipeline
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
unet = UNet2DConditionModel.from_pretrained("mhdang/dpo-sdxl-text2image-v1", subfolder="unet", torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", unet=unet, vae=vae, torch_dtype=torch.float16)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
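# note: swapping in a plain DDIMScheduler here isn't just cosmetic. detection below
# relies on DDIM inversion (DDIMInverseScheduler), and the inversion is only
# meaningful if generation used the matching deterministic DDIM sampler.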
# watermarking helper functions. paraphrased from the reference impl of arXiv:2305.20030
def circle_mask(size=128, r=16, x_offset=0, y_offset=0):
    x0 = y0 = size // 2
    x0 += x_offset
    y0 += y_offset
    y, x = np.ogrid[:size, :size]
    y = y[::-1]
    return ((x - x0) ** 2 + (y - y0) ** 2) <= r ** 2
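# circle_mask gives a boolean (size, size) array that is True inside a disk of
# radius r around the center. the y-flip matches the orientation convention of the
# reference implementation; concentric disks of shrinking radius carve the "rings".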
def get_pattern(shape, w_seed=999999):
    g = torch.Generator(device=pipe.device)
    g.manual_seed(w_seed)
    gt_init = pipe.prepare_latents(1, pipe.unet.in_channels,
                                   1024, 1024,
                                   pipe.unet.dtype, pipe.device, g)
    gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2))
    # ring pattern. paper found this to be effective
    gt_patch_tmp = gt_patch.clone().detach()
    for i in range(shape[-1] // 2, 0, -1):
        tmp_mask = circle_mask(gt_init.shape[-1], r=i)
        tmp_mask = torch.tensor(tmp_mask)
        for j in range(gt_patch.shape[1]):
            gt_patch[:, j, tmp_mask] = gt_patch_tmp[0, j, 0, i].item()
    return gt_patch
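# the loop overwrites every ring of radius i (in all channels) with a single
# constant sampled from the seeded FFT, so the key is a set of concentric rings,
# which the paper found robust to rotations and crops. a toy illustration of the
# same construction (purely illustrative, not used by the app):
#   patch = torch.fft.fftshift(torch.fft.fft2(torch.randn(1, 1, 8, 8)))
#   ref = patch.clone()
#   for i in range(4, 0, -1):
#       m = torch.tensor(circle_mask(8, r=i))
#       patch[:, 0, m] = ref[0, 0, 0, i].item()  # one constant per ring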
def transform_img(image):
    tform = tforms.Compose([tforms.Resize(1024), tforms.CenterCrop(1024), tforms.ToTensor()])
    image = tform(image)
    return 2.0 * image - 1.0
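# transform_img reproduces the SDXL VAE preprocessing: resize/center-crop to
# 1024x1024, convert to a tensor in [0, 1], then rescale to [-1, 1].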
# hyperparameters
shape = (1, 4, 128, 128)
w_seed = 7433 # TREE :)
w_channel = 0
w_radius = 16 # the r suggested in section 4.4 of the paper
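# writing into a single latent channel keeps the watermark's visible footprint
# small, and r=16 confines it to low frequencies (the center of the fftshifted
# spectrum), which tend to survive common edits like compression and blurring.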
# get w_key and w_mask
np_mask = circle_mask(shape[-1], r=w_radius)
torch_mask = torch.tensor(np_mask).to(pipe.device)
w_mask = torch.zeros(shape, dtype=torch.bool).to(pipe.device)
w_mask[:, w_channel] = torch_mask
w_key = get_pattern(shape, w_seed=w_seed).to(pipe.device)
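# w_mask marks where the watermark lives (the radius-16 disk in channel 0 of the
# fftshifted 128x128 latent spectrum); w_key holds the ring values written there.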
def get_noise():
    # moved w_key and w_mask to globals
    # inject watermark
    init_latents = pipe.prepare_latents(1, pipe.unet.in_channels,
                                        1024, 1024,
                                        pipe.unet.dtype, pipe.device, None)
    init_latents_fft = torch.fft.fftshift(torch.fft.fft2(init_latents), dim=(-1, -2))
    init_latents_fft[w_mask] = w_key[w_mask].clone()
    init_latents = torch.fft.ifft2(torch.fft.ifftshift(init_latents_fft, dim=(-1, -2))).real
    # hot fix to prevent out of bounds values. will "properly" fix this later
    init_latents[init_latents == float("Inf")] = 4
    init_latents[init_latents == float("-Inf")] = -4
    return init_latents
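# this is the whole embedding step: sample fresh latents, FFT, stamp the key into
# the masked region, inverse FFT. taking .real discards a small imaginary residue,
# so the embedded key is slightly distorted; detection still works in practice.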
def detect(image):
    # invert scheduler
    curr_scheduler = pipe.scheduler
    pipe.scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
    # ddim inversion
    img = transform_img(image).unsqueeze(0).to(pipe.unet.dtype).to(pipe.device)
    image_latents = pipe.vae.encode(img).latent_dist.mode() * 0.13025
    inverted_latents = pipe(prompt="", latents=image_latents, guidance_scale=1, num_inference_steps=50, output_type="latent")
    inverted_latents = inverted_latents.images
    # calculate p-value instead of detection threshold. more rigorous, plus we can do a non-boolean output
    inverted_latents_fft = torch.fft.fftshift(torch.fft.fft2(inverted_latents), dim=(-1, -2))[w_mask].flatten()
    target = w_key[w_mask].flatten()
    inverted_latents_fft = torch.concatenate([inverted_latents_fft.real, inverted_latents_fft.imag])
    target = torch.concatenate([target.real, target.imag])
    sigma = inverted_latents_fft.std()
    lamda = (target ** 2 / sigma ** 2).sum().item()
    x = (((inverted_latents_fft - target) / sigma) ** 2).sum().item()
    p_value = scipy.stats.ncx2.cdf(x=x, df=len(target), nc=lamda)
    # revert scheduler
    pipe.scheduler = curr_scheduler
    if p_value == 0:
        return 1.0
    else:
        return max(0.0, 1 - 1 / math.log(5 / p_value, 10))
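# the statistics, spelled out: treat the recovered coefficients y as noisy
# observations of the key k with noise scale sigma. under the null hypothesis
# (no watermark), x = sum(((y - k) / sigma)**2) follows a noncentral chi-squared
# distribution with df = len(k) and noncentrality lamda = sum(k**2 / sigma**2)
# ("lamda" is spelled that way because "lambda" is a python keyword). a
# watermarked image yields y close to k, hence a small x, hence a tiny
# p_value = ncx2.cdf(x, df, lamda). the 1 - 1/log10(5/p) mapping at the end is
# just a cosmetic squash of the p-value into a 0-to-1 score for the gradio label.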
def generate(prompt):
    return pipe(prompt=prompt, negative_prompt="monochrome", num_inference_steps=50, latents=get_noise()).images[0]
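# e.g. generate("a majestic tree at sunset, oil painting") returns a watermarked
# PIL image; negative_prompt="monochrome" presumably nudges outputs away from
# grayscale results.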
# optimize for speed
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
print(detect(generate("an astronaut riding a green horse"))) # warmup after jit
# actual gradio demo
def manager(data, progress=gr.Progress(track_tqdm=True)): # single queued handler for both buttons, to prevent the queue from overloading
    if isinstance(data, str):
        return generate(data)
    elif isinstance(data, np.ndarray):
        image = Image.fromarray(data)
        percent = detect(image)
        return {"watermarked": percent, "not_watermarked": 1.0 - percent}
with gr.Blocks(theme=gr.themes.Soft(primary_hue="green", secondary_hue="green", font=gr.themes.GoogleFont("Fira Sans"))) as app:
    with gr.Row():
        gr.HTML('<center><p>Bad actors are using generative AI to destroy the livelihoods of real artists. We need transparency now.</p><h1><span style="font-size:1.5em">Introducing Dendrokronos 🌳</span></h1></center>')
    with gr.Row():
        with gr.Column():
            gr.Markdown("# Generate\nType a prompt and hit Go. Dendrokronos will generate an invisibly-watermarked image. \nYou can click the download button to save the finished image. Try it with the detector.")
            with gr.Group():
                with gr.Row():
                    gen_in = gr.Textbox(max_lines=1, placeholder='try "a majestic tree at sunset, oil painting"', show_label=False, scale=4)
                    gen_btn = gr.Button("Go", variant="primary", scale=0)
            gen_out = gr.Image(interactive=False, show_label=False)
            gen_btn.click(fn=manager, inputs=gen_in, outputs=gen_out)
        with gr.Column():
            gr.Markdown("# Detect\nUpload an image and hit Detect. Dendrokronos will predict the probability it was watermarked. \nNote: Dendrokronos can only detect its own watermark. It won't detect other AIs, such as DALL-E.")
            det_out = gr.Label(show_label=False)
            with gr.Group():
                det_btn = gr.Button("Detect", variant="primary")
                det_in = gr.Image(interactive=True, sources=["upload", "clipboard"], show_label=False)
            det_btn.click(fn=manager, inputs=det_in, outputs=det_out)
    with gr.Row():
        gr.HTML('<center><h1> </h1>Acknowledgements: Dendrokronos uses <a href="https://huggingface.co/mhdang/dpo-sdxl-text2image-v1">SDXL DPO 1.0</a> for the underlying image generation and <a href="https://arxiv.org/abs/2305.20030">an algorithm by UMD researchers</a> for the watermark technology.<br />Dendrokronos is a project by Devin Gulliver.</center>')
app.queue()
app.launch(show_api=False)