"""Gradio demo showing how to apply and detect SynthID Text watermarks.

For each prompt the app generates one unwatermarked and one watermarked
response, shuffles all of them, asks the user to select the ones they think
are watermarked, and then reveals the ground truth.
"""

from collections.abc import Sequence
import random
from typing import Optional

import gradio as gr
import spaces
import torch
import transformers

# If the watermark is not detected, consider the use case. It could be because
# of the nature of the task (e.g., factual responses are lower entropy) or it
# could be due to something else.
# TODO(review): the original note was truncated here; complete it.

_MODEL_IDENTIFIER = 'google/gemma-2b'
_DETECTOR_IDENTIFIER = 'gg-hf/detector_2b_1.0_demo'

_PROMPTS: tuple[str, ...] = (
    'prompt 1',
    'prompt 2',
    'prompt 3',
)

_TORCH_DEVICE = (
    torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)

# One set of watermarking parameters, used both for the generation-time config
# and for the detector's logits processor below. Treat these keys as a secret.
_WATERMARK_CONFIG_DICT = dict(
    ngram_len=5,
    keys=[
        654,
        400,
        836,
        123,
        340,
        443,
        597,
        160,
        57,
        29,
        590,
        639,
        13,
        715,
        468,
        990,
        966,
        226,
        324,
        585,
        118,
        504,
        421,
        521,
        129,
        669,
        732,
        225,
        90,
        960,
    ],
    sampling_table_size=2**16,
    sampling_table_seed=0,
    context_history_size=1024,
)

_WATERMARK_CONFIG = transformers.generation.SynthIDTextWatermarkingConfig(
    **_WATERMARK_CONFIG_DICT
)

# Decoder-only models must be left-padded for batched generation so each
# prompt ends exactly where generation begins.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    _MODEL_IDENTIFIER, padding_side='left'
)
# Only fall back to EOS for padding when the tokenizer has no pad token: the
# SynthID detector masks everything after the first EOS, so EOS-valued padding
# would hide the generated text from it.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

model = transformers.AutoModelForCausalLM.from_pretrained(_MODEL_IDENTIFIER)
model.to(_TORCH_DEVICE)

logits_processor = transformers.generation.SynthIDTextWatermarkLogitsProcessor(
    **_WATERMARK_CONFIG_DICT,
    device=_TORCH_DEVICE,
)

detector_module = transformers.generation.BayesianDetectorModel.from_pretrained(
    _DETECTOR_IDENTIFIER,
)
detector_module.to(_TORCH_DEVICE)

# BayesianDetectorModel is only the trained scoring module; the detector
# wrapper combines it with the watermarking logits processor and tokenizer.
detector = transformers.generation.SynthIDTextWatermarkDetector(
    detector_module=detector_module,
    logits_processor=logits_processor,
    tokenizer=tokenizer,
)


@spaces.GPU
def generate_outputs(
    prompts: Sequence[str],
    watermarking_config: Optional[
        transformers.generation.SynthIDTextWatermarkingConfig
    ] = None,
) -> list[str]:
  """Generate one sampled output per prompt, optionally watermarked.

  Args:
    prompts: The prompts to complete, processed together as one batch.
    watermarking_config: When given, it is passed to `model.generate()` to
      apply the SynthID Text watermark; when None, no watermark is requested.

  Returns:
    The decoded output sequences, one per prompt and in the same order, with
    special and padding tokens removed.
  """
  # padding=True is required to batch prompts of different lengths.
  tokenized_prompts = tokenizer(
      list(prompts), return_tensors='pt', padding=True
  ).to(_TORCH_DEVICE)
  output_sequences = model.generate(
      **tokenized_prompts,
      watermarking_config=watermarking_config,
      do_sample=True,
      max_length=500,
      top_k=40,
  )
  # Debug aid: log the detector's scores for this batch to the server console.
  with torch.no_grad():
    detections = detector(output_sequences)
  print(detections)
  return tokenizer.batch_decode(output_sequences, skip_special_tokens=True)


with gr.Blocks() as demo:
  gr.Markdown(
      '''
      # Using SynthID Text in your Generative AI projects

      [SynthID][synthid] is a Google DeepMind technology that watermarks and
      identifies AI-generated content by embedding digital watermarks directly
      into AI-generated images, audio, text or video.

      SynthID Text is an open source implementation of this technology
      available in Hugging Face Transformers that has two major components:

      * A [logits processor][synthid-hf-logits-processor] that is
        [configured][synthid-hf-config] on a per-model basis and activated when
        calling `.generate()`; and
      * A [detector][synthid-hf-detector] trained to recognize watermarked text
        generated by a specific model with a specific configuration.

      This Space demonstrates:

      1. How to use SynthID Text to apply a watermark to text generated by your
         model; and
      1. How to identify that text using a ready-made detector.

      Note that this detector is trained specifically for this demonstration.
      You should maintain a specific watermarking configuration for every model
      you use and protect that configuration as you would any other secret. See
      the [end-to-end guide][synthid-hf-detector-e2e] for more on training your
      own detectors, and the [SynthID Text documentation][raitk-synthid] for
      more on how this technology works.

      [raitk-synthid]: https://ai.google.dev/responsible/docs/safeguards/synthid
      [synthid]: https://deepmind.google/technologies/synthid/
      [synthid-hf-config]: https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/generation/configuration_utils.py
      [synthid-hf-detector]: https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/generation/watermarking.py
      [synthid-hf-detector-e2e]: https://github.com/huggingface/transformers/blob/v4.46.0/examples/research_projects/synthid_text/detector_bayesian.py
      [synthid-hf-logits-processor]: https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/generation/logits_process.py
      '''
  )
  prompt_inputs = [
      gr.Textbox(value=prompt, lines=4, label='Prompt') for prompt in _PROMPTS
  ]
  generate_btn = gr.Button('Generate')

  # Per-session ground truth as (response, is_watermarked) pairs in display
  # order. Kept in gr.State rather than a module global so that concurrent
  # users and repeated generations never see each other's responses.
  answers_state = gr.State([])

  with gr.Column(visible=False) as generations_col:
    gr.Markdown(
        '''
        # SynthID: Tool
        '''
    )
    generations_grp = gr.CheckboxGroup(
        label='All generations, in random order',
        info='Select the generations you think are watermarked!',
    )
    reveal_btn = gr.Button('Reveal', visible=False)

  with gr.Column(visible=False) as detections_col:
    gr.Markdown(
        '''
        # SynthID: Tool
        '''
    )
    revealed_grp = gr.CheckboxGroup(
        label='Ground truth for all generations',
        info=(
            'Watermarked generations are checked, and your selections are '
            'marked as correct or incorrect in the text.'
        ),
    )
    # TODO(review): this button is made visible but has no event handler yet.
    detect_btn = gr.Button('Detect', visible=False)

  def generate(*prompts):
    """Generate standard and watermarked responses and shuffle them together.

    Args:
      *prompts: One string per prompt textbox.

    Returns:
      Component updates that hide the Generate button, show the shuffled
      responses for the user to classify, and store the ground truth in the
      session state.
    """
    standard = generate_outputs(prompts=prompts)
    watermarked = generate_outputs(
        prompts=prompts,
        watermarking_config=_WATERMARK_CONFIG,
    )
    # Label responses before shuffling so the ground truth follows each
    # response instead of being recovered by a text lookup afterwards.
    labeled = [(text, False) for text in standard]
    labeled += [(text, True) for text in watermarked]
    random.shuffle(labeled)
    responses = [text for text, _ in labeled]
    return {
        generate_btn: gr.Button(visible=False),
        generations_col: gr.Column(visible=True),
        generations_grp: gr.CheckboxGroup(choices=responses),
        reveal_btn: gr.Button(visible=True),
        answers_state: labeled,
    }

  generate_btn.click(
      generate,
      inputs=prompt_inputs,
      outputs=[
          generate_btn,
          generations_col,
          generations_grp,
          reveal_btn,
          answers_state,
      ],
  )

  def reveal(
      user_selections: Optional[list[str]],
      answers: list[tuple[str, bool]],
  ):
    """Grade the user's selections against the ground truth.

    Args:
      user_selections: Responses the user selected as watermarked; None is
        treated as no selection.
      answers: (response, is_watermarked) pairs produced by `generate`.

    Returns:
      Component updates that show every response prefixed with 'Correct!' or
      'Incorrect.', with the watermarked responses checked.
    """
    selected = set(user_selections or ())
    choices: list[str] = []
    value: list[str] = []
    for response, is_watermarked in answers:
      # A guess is correct when "selected" agrees with "watermarked".
      if (response in selected) == is_watermarked:
        choice = f'Correct! {response}'
      else:
        choice = f'Incorrect. {response}'
      choices.append(choice)
      if is_watermarked:
        value.append(choice)
    return {
        reveal_btn: gr.Button(visible=False),
        detections_col: gr.Column(visible=True),
        revealed_grp: gr.CheckboxGroup(choices=choices, value=value),
        detect_btn: gr.Button(visible=True),
    }

  reveal_btn.click(
      reveal,
      inputs=[generations_grp, answers_state],
      outputs=[reveal_btn, detections_col, revealed_grp, detect_btn],
  )


if __name__ == '__main__':
  demo.launch()