# Hugging Face Spaces status: Running on L40S.
from collections.abc import Sequence | |
import random | |
from typing import Optional | |
import gradio as gr | |
import spaces | |
import torch | |
import transformers | |
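# If the watermark is not detected, consider the use case. It could be because
# of the nature of the task (e.g., factual responses are lower entropy, so
# sampling has less freedom in which to embed the watermark), or it could be
# another cause. TODO: this note was cut off in the original; complete it.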
# Hugging Face Hub identifiers for the generating model and for the detector
# that was trained for this demonstration.
_MODEL_IDENTIFIER = 'google/gemma-2b'
_DETECTOR_IDENTIFIER = 'gg-hf/detector_2b_1.0_demo'

# Default texts pre-filled into the demo's prompt textboxes. Annotated as
# `tuple[str, ...]` (a tuple of any length); `tuple[str]` would declare a
# tuple of exactly one string, which contradicts the three values below.
_PROMPTS: tuple[str, ...] = (
    'prompt 1',
    'prompt 2',
    'prompt 3',
)

# Map from generated response text to whether that response was watermarked.
# As a module-level dict it is shared by every session this process serves.
_CORRECT_ANSWERS: dict[str, bool] = {}

# Use the first CUDA GPU when one is available, otherwise the CPU.
_TORCH_DEVICE = (
    torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
)
# SynthID Text watermarking parameters. The same dict is unpacked into both the
# generation-time watermarking config and the detector's logits processor, so
# the two cannot drift apart. They are hard-coded here only because this is a
# demo; per the app's own guidance, a real deployment should keep its
# watermarking configuration (notably `keys`) secret.
_WATERMARK_CONFIG_DICT = {
    'ngram_len': 5,
    'keys': [
        654, 400, 836, 123, 340, 443, 597, 160, 57, 29,
        590, 639, 13, 715, 468, 990, 966, 226, 324, 585,
        118, 504, 421, 521, 129, 669, 732, 225, 90, 960,
    ],
    'sampling_table_size': 2**16,
    'sampling_table_seed': 0,
    'context_history_size': 1024,
}
# Generation-time watermarking configuration; it is passed as
# `watermarking_config` to `model.generate()` for the watermarked outputs.
_WATERMARK_CONFIG = transformers.generation.SynthIDTextWatermarkingConfig(
    **_WATERMARK_CONFIG_DICT,
)
# Tokenizer for the generating model. Its pad token is replaced by EOS. Left
# padding is set explicitly because batched prompts of different lengths are
# padded before `generate()`, and a decoder-only model must continue directly
# from the last real prompt token.
tokenizer = transformers.AutoTokenizer.from_pretrained(_MODEL_IDENTIFIER)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = 'left'

# The causal language model that produces both plain and watermarked text.
model = transformers.AutoModelForCausalLM.from_pretrained(_MODEL_IDENTIFIER)
model.to(_TORCH_DEVICE)

# Logits processor built from the same parameters as `_WATERMARK_CONFIG`;
# handed to the detector below.
logits_processor = transformers.generation.SynthIDTextWatermarkLogitsProcessor(
    **_WATERMARK_CONFIG_DICT,
    device=_TORCH_DEVICE,
)

# Trained Bayesian scoring network, downloaded from the Hub.
detector_module = transformers.generation.BayesianDetectorModel.from_pretrained(
    _DETECTOR_IDENTIFIER,
)
detector_module.to(_TORCH_DEVICE)

# `BayesianDetectorModel` is only the scoring network (its constructor takes a
# config, not these keyword arguments). `SynthIDTextWatermarkDetector` is the
# wrapper that combines scoring network, logits processor and tokenizer into a
# callable applied to generated token ids.
detector = transformers.generation.watermarking.SynthIDTextWatermarkDetector(
    detector_module=detector_module,
    logits_processor=logits_processor,
    tokenizer=tokenizer,
)
def generate_outputs(
    prompts: Sequence[str],
    watermarking_config: Optional[
        transformers.generation.SynthIDTextWatermarkingConfig
    ] = None,
) -> Sequence[str]:
    """Sample one continuation per prompt and print the detector's scores.

    Args:
        prompts: Prompt texts, generated together as one batch.
        watermarking_config: SynthID Text configuration forwarded to
            `model.generate()`; `None` produces unwatermarked text.

    Returns:
        The decoded continuations, one per prompt in input order. The prompt
        text and special tokens (BOS, EOS, padding) are removed.
    """
    # User-edited prompts differ in length. Without padding the tokenizer
    # cannot build a rectangular 'pt' tensor and raises.
    tokenized_prompts = tokenizer(
        prompts, return_tensors='pt', padding=True
    ).to(_TORCH_DEVICE)
    output_sequences = model.generate(
        **tokenized_prompts,
        watermarking_config=watermarking_config,
        do_sample=True,
        max_length=500,
        top_k=40,
    )
    # `generate()` returns prompt + continuation for a causal LM. Keep only the
    # newly generated tokens. The prompt was never watermarked, and its left
    # padding is EOS (the pad token is set to EOS), so leaving it in would skew
    # the detector's score and clutter the text shown to the user.
    prompt_length = tokenized_prompts['input_ids'].shape[1]
    generated = output_sequences[:, prompt_length:]
    # Detection is inference only; avoid building an autograd graph.
    with torch.no_grad():
        detections = detector(generated)
    print(detections)
    return tokenizer.batch_decode(generated, skip_special_tokens=True)
# Introductory text shown at the top of the demo. Blank lines matter here:
# Markdown link reference definitions (the `[name]: url` lines) only parse
# when they are not glued to the preceding paragraph.
_INTRO_MARKDOWN = '''
# Using SynthID Text in your Generative AI projects

[SynthID][synthid] is a Google DeepMind technology that watermarks and
identifies AI-generated content by embedding digital watermarks directly
into AI-generated images, audio, text or video.

SynthID Text is an open source implementation of this technology available
in Hugging Face Transformers that has two major components:

* A [logits processor][synthid-hf-logits-processor] that is
  [configured][synthid-hf-config] on a per-model basis and activated when
  calling `.generate()`; and
* A [detector][synthid-hf-detector] trained to recognize watermarked text
  generated by a specific model with a specific configuration.

This Space demonstrates:

1. How to use SynthID Text to apply a watermark to text generated by your
   model; and
1. How to identify that text using a ready-made detector.

Note that this detector is trained specifically for this demonstration. You
should maintain a specific watermarking configuration for every model you
use and protect that configuration as you would any other secret. See the
[end-to-end guide][synthid-hf-detector-e2e] for more on training your own
detectors, and the [SynthID Text documentation][raitk-synthid] for more on
how this technology works.

[raitk-synthid]: https://ai.google.dev/responsible/docs/safeguards/synthid
[synthid]: https://deepmind.google/technologies/synthid/
[synthid-hf-config]: https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/generation/configuration_utils.py
[synthid-hf-detector]: https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/generation/watermarking.py
[synthid-hf-detector-e2e]: https://github.com/huggingface/transformers/blob/v4.46.0/examples/research_projects/synthid_text/detector_bayesian.py
[synthid-hf-logits-processor]: https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/generation/logits_process.py
'''

with gr.Blocks() as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    prompt_inputs = [
        gr.Textbox(value=prompt, lines=4, label='Prompt')
        for prompt in _PROMPTS
    ]
    generate_btn = gr.Button('Generate')

    # Per-session ground truth: maps each displayed response to whether it
    # was watermarked. Gradio copies the initial value for every session, so
    # concurrent visitors and earlier runs never leak into each other's
    # answers (unlike the module-level `_CORRECT_ANSWERS`, which was never
    # cleared).
    answers_state = gr.State({})

    with gr.Column(visible=False) as generations_col:
        gr.Markdown('# SynthID: Tool')
        generations_grp = gr.CheckboxGroup(
            label='All generations, in random order',
            info='Select the generations you think are watermarked!',
        )
        reveal_btn = gr.Button('Reveal', visible=False)

    with gr.Column(visible=False) as detections_col:
        gr.Markdown('# SynthID: Tool')
        revealed_grp = gr.CheckboxGroup(
            label='Ground truth for all generations',
            info=(
                'Watermarked generations are checked, and your selections are '
                'marked as correct or incorrect in the text.'
            ),
        )
        # TODO: no click handler is registered for this button yet; detector
        # scores are currently only printed by `generate_outputs`.
        detect_btn = gr.Button('Detect', visible=False)

    def generate(*prompts):
        """Generate plain and watermarked outputs and show them shuffled.

        Args:
            *prompts: The current text of each prompt textbox.

        Returns:
            Component updates: hide Generate, show the generations panel and
            the Reveal button, list every response in random order, and store
            this session's ground truth in `answers_state`.
        """
        standard = generate_outputs(prompts=prompts)
        watermarked = generate_outputs(
            prompts=prompts,
            watermarking_config=_WATERMARK_CONFIG,
        )
        responses = list(standard) + list(watermarked)
        random.shuffle(responses)
        watermarked_set = set(watermarked)
        answers = {
            response: response in watermarked_set for response in responses
        }
        return {
            generate_btn: gr.Button(visible=False),
            generations_col: gr.Column(visible=True),
            generations_grp: gr.CheckboxGroup(choices=responses),
            reveal_btn: gr.Button(visible=True),
            answers_state: answers,
        }

    generate_btn.click(
        generate,
        inputs=prompt_inputs,
        outputs=[
            generate_btn,
            generations_col,
            generations_grp,
            reveal_btn,
            answers_state,
        ],
    )

    def reveal(user_selections: list[str], answers: dict[str, bool]):
        """Show the ground truth and grade the user's guesses.

        Args:
            user_selections: Responses the user checked as watermarked.
            answers: This session's map from response text to whether it was
                watermarked, as stored by `generate`.

        Returns:
            Component updates: each response is prefixed with 'Correct!' or
            'Incorrect.', and the watermarked ones are checked.
        """
        # Guard against a missing selection so membership tests cannot fail.
        selected = set(user_selections or ())
        choices: list[str] = []
        value: list[str] = []
        for response, is_watermarked in answers.items():
            # A guess is correct when "checked" matches "watermarked".
            verdict = (
                'Correct!' if (response in selected) == is_watermarked
                else 'Incorrect.'
            )
            choice = f'{verdict} {response}'
            choices.append(choice)
            if is_watermarked:
                value.append(choice)
        return {
            reveal_btn: gr.Button(visible=False),
            detections_col: gr.Column(visible=True),
            revealed_grp: gr.CheckboxGroup(choices=choices, value=value),
            detect_btn: gr.Button(visible=True),
        }

    reveal_btn.click(
        reveal,
        inputs=[generations_grp, answers_state],
        outputs=[
            reveal_btn,
            detections_col,
            revealed_grp,
            detect_btn,
        ],
    )
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script,
    # not when it is imported as a module.
    demo.launch()