File size: 7,135 Bytes
d4fa03e
 
 
 
 
 
 
dc03412
d4fa03e
 
 
 
 
70f7306
 
d4fa03e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80c9e4f
d4fa03e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80c9e4f
d4fa03e
11a0843
 
 
 
d4fa03e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6fac16d
d4fa03e
 
 
 
 
 
 
 
 
 
 
70f7306
d4fa03e
 
97699fe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# import all the libraries
import math
import numpy as np
import scipy
from PIL import Image
import torch
import torchvision.transforms as tforms
from diffusers import DiffusionPipeline, UNet2DConditionModel, DDIMScheduler, DDIMInverseScheduler
from diffusers.models import AutoencoderKL
import gradio as gr

# load SDXL pipeline
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
unet = UNet2DConditionModel.from_pretrained("mhdang/dpo-sdxl-text2image-v1", subfolder="unet", torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", unet=unet, vae=vae, torch_dtype=torch.float16)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

# watermarking helper functions. paraphrased from the reference impl of arXiv:2305.20030

def circle_mask(size=128, r=16, x_offset=0, y_offset=0):
    x0 = y0 = size // 2
    x0 += x_offset
    y0 += y_offset
    y, x = np.ogrid[:size, :size]
    y = y[::-1]
    return ((x - x0)**2 + (y-y0)**2)<= r**2

def get_pattern(shape, w_seed=999999):
    g = torch.Generator(device=pipe.device)
    g.manual_seed(w_seed)
    gt_init = pipe.prepare_latents(1, pipe.unet.in_channels,
                                   1024, 1024,
                                   pipe.unet.dtype, pipe.device, g)
    gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2))
    # ring pattern. paper found this to be effective
    gt_patch_tmp = gt_patch.clone().detach()
    for i in range(shape[-1] // 2, 0, -1):
        tmp_mask = circle_mask(gt_init.shape[-1], r=i)
        tmp_mask = torch.tensor(tmp_mask)
        for j in range(gt_patch.shape[1]):
            gt_patch[:, j, tmp_mask] = gt_patch_tmp[0, j, 0, i].item()

    return gt_patch

def transform_img(image):
    tform = tforms.Compose([tforms.Resize(1024),tforms.CenterCrop(1024),tforms.ToTensor()])
    image = tform(image)
    return 2.0 * image - 1.0

# hyperparameters
shape = (1, 4, 128, 128)
w_seed = 7433 # TREE :)
w_channel = 0
w_radius = 16 # the suggested r from section 4.4 of paper

# get w_key and w_mask
np_mask = circle_mask(shape[-1], r=w_radius)
torch_mask = torch.tensor(np_mask).to(pipe.device)
w_mask = torch.zeros(shape, dtype=torch.bool).to(pipe.device)
w_mask[:, w_channel] = torch_mask
w_key = get_pattern(shape, w_seed=w_seed).to(pipe.device)


def get_noise():
    # moved w_key and w_mask to globals

    # inject watermark
    init_latents = pipe.prepare_latents(1, pipe.unet.in_channels,
                                        1024, 1024,
                                        pipe.unet.dtype, pipe.device, None)
    init_latents_fft = torch.fft.fftshift(torch.fft.fft2(init_latents), dim=(-1, -2))
    init_latents_fft[w_mask] = w_key[w_mask].clone()
    init_latents = torch.fft.ifft2(torch.fft.ifftshift(init_latents_fft, dim=(-1, -2))).real
    # hot fix to prevent out of bounds values. will "properly" fix this later
    init_latents[init_latents == float("Inf")] = 4
    init_latents[init_latents == float("-Inf")] = -4

    return init_latents

def detect(image):
    # invert scheduler
    curr_scheduler = pipe.scheduler
    pipe.scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)

    # ddim inversion
    img = transform_img(image).unsqueeze(0).to(pipe.unet.dtype).to(pipe.device)
    image_latents = pipe.vae.encode(img).latent_dist.mode() * 0.13025
    inverted_latents = pipe(prompt="", latents=image_latents, guidance_scale=1, num_inference_steps=50, output_type="latent")
    inverted_latents = inverted_latents.images

    # calculate p-value instead of detection threshold. more rigorous, plus we can do a non-boolean output
    inverted_latents_fft = torch.fft.fftshift(torch.fft.fft2(inverted_latents), dim=(-1, -2))[w_mask].flatten()
    target = w_key[w_mask].flatten()
    inverted_latents_fft = torch.concatenate([inverted_latents_fft.real, inverted_latents_fft.imag])
    target = torch.concatenate([target.real, target.imag])

    sigma = inverted_latents_fft.std()
    lamda = (target ** 2 / sigma ** 2).sum().item()
    x = (((inverted_latents_fft - target) / sigma) ** 2).sum().item()
    p_value = scipy.stats.ncx2.cdf(x=x, df=len(target), nc=lamda)

    # revert scheduler
    pipe.scheduler = curr_scheduler

    if p_value == 0:
        return 1.0
    else:
        return max(0.0, 1-1/math.log(5/p_value,10))

def generate(prompt):
    return pipe(prompt=prompt, negative_prompt="monochrome", num_inference_steps=50, latents=get_noise()).images[0]

# optimize for speed
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
print(detect(generate("an astronaut riding a green horse"))) # warmup after jit

# actual gradio demo

def manager(input, progress=gr.Progress(track_tqdm=True)): # to prevent the queue from overloading
    if type(input) == str:
        return generate(input)
    elif type(input) == np.ndarray:
        image = Image.fromarray(input)
        percent = detect(image)
        return {"watermarked": percent, "not_watermarked": 1.0-percent}

with gr.Blocks(theme=gr.themes.Soft(primary_hue="green",secondary_hue="green", font=gr.themes.GoogleFont("Fira Sans"))) as app:
    with gr.Row():
        gr.HTML('<center><p>Bad actors are using generative AI to destroy the livelihoods of real artists. We need transparency now.</p><h1><span style="font-size:1.5em">Introducing Dendrokronos 🌳</span></h1></center>')
    with gr.Row():
        with gr.Column():
            gr.Markdown("# Generate\nType a prompt and hit Go. Dendrokronos will generate an invisibly-watermarked image.  \nYou can click the download button to save the finished image. Try it with the detector.")
            with gr.Group():
                with gr.Row():
                    gen_in = gr.Textbox(max_lines=1, placeholder='try "a majestic tree at sunset, oil painting"', show_label=False, scale=4)
                    gen_btn = gr.Button("Go", variant="primary", scale=0)
            gen_out = gr.Image(interactive=False, show_label=False)
            gen_btn.click(fn=manager, inputs=gen_in, outputs=gen_out)
        with gr.Column():
            gr.Markdown("# Detect\nUpload an image and hit Detect. Dendrokronos will predict the probability it was watermarked.  \nNote: Dendrokronos can only detect its own watermark. It won't detect other AIs, such as DALL-E.")
            det_out = gr.Label(show_label=False)
            with gr.Group():
                det_btn = gr.Button("Detect", variant="primary")
                det_in = gr.Image(interactive=True, sources=["upload","clipboard"], show_label=False)
            det_btn.click(fn=manager, inputs=det_in, outputs=det_out)
    with gr.Row():
        gr.HTML('<center><h1>&nbsp;</h1>Acknowledgements: Dendrokronos uses <a href="https://huggingface.co/mhdang/dpo-sdxl-text2image-v1">SDXL DPO 1.0</a> for the underlying image generation and <a href="https://arxiv.org/abs/2305.20030">an algorithm by UMD researchers</a> for the watermark technology.<br />Dendrokronos is a project by Devin Gulliver.</center>')

app.queue()
app.launch(show_api=False)