# Hugging Face Space status: Running on L40S
from collections.abc import Sequence
import random

import gradio as gr
import immutabledict
import spaces
import torch
#### Version 1: Baseline
# Step 1: Select and load your model
# Step 2: Load the test dataset (4-5 examples)
# Step 3: Run generation with and without watermarking, display the outputs
# Step 4: User clicks the reveal button to see the watermarked vs. non-watermarked generations
#### Version 2: Gamification
# Steps 1-3 are the same
# Step 4: User marks specific generations as watermarked
# Step 5: User clicks the reveal button to see the watermarked vs. non-watermarked generations
# If the watermark is not detected, consider the use case. It could be because of
# the nature of the task (e.g., factual responses are lower entropy), or it could
# be some other cause.
# Hub-style model id ("org/name") of the base model this demo is meant to load.
GEMMA_2B = "google/gemma-2b"
# Default text for the prompt textboxes in the demo UI (one textbox per entry).
# `tuple[str, ...]` = a tuple of any length whose items are all str;
# the previous `tuple[str]` described a tuple holding exactly ONE str.
PROMPTS: tuple[str, ...] = (
    'prompt 1',
    'prompt 2',
    'prompt 3',
    'prompt 4',
)
# Read-only watermarking parameters for the demo.
# NOTE(review): nothing in this file consumes this mapping yet; the meaning of
# each field (presumably the settings of a SynthID-style watermarking logits
# processor) is inferred from the names only -- confirm against the library
# that receives it before changing any value.
# NOTE(review): immutabledict is shallow -- `keys` is still a mutable list.
WATERMARKING_CONFIG = immutabledict.immutabledict(
    ngram_len=5,
    keys=[
        654, 400, 836, 123, 340, 443, 597, 160, 57, 29,
        590, 639, 13, 715, 468, 990, 966, 226, 324, 585,
        118, 504, 421, 521, 129, 669, 732, 225, 90, 960,
    ],
    sampling_table_size=2**16,
    sampling_table_seed=0,
    context_history_size=1024,
    # First CUDA GPU when one is visible at import time, otherwise the CPU.
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
)
# Demo UI. Flow: edit the prompts -> Generate -> tick the generations you think
# are watermarked -> Reveal to score your picks against the ground truth.
# The Detect button becomes visible after Reveal; no handler is attached to it
# in this file yet.
with gr.Blocks() as demo:
    prompt_inputs = [
        gr.Textbox(value=prompt, lines=4, label='Prompt')
        for prompt in PROMPTS
    ]
    generate_btn = gr.Button('Generate')

    # Ground truth for the current round, keyed by checkbox value:
    #   {choice_id: (response_text, is_watermarked)}, in display order.
    # gr.State keeps a separate copy per browser session. The previous
    # module-level dict was shared by every visitor and never cleared, so one
    # user's Reveal listed (and scored) other sessions' generations.
    answers_state = gr.State({})

    with gr.Column(visible=False) as generations_col:
        generations_grp = gr.CheckboxGroup(
            label='All generations, in random order',
            info='Select the generations you think are watermarked!',
        )
        reveal_btn = gr.Button('Reveal', visible=False)

    with gr.Column(visible=False) as detections_col:
        revealed_grp = gr.CheckboxGroup(
            label='Ground truth for all generations',
            info=(
                'Watermarked generations are checked, and your selections are '
                'marked as correct or incorrect in the text.'
            ),
        )
        detect_btn = gr.Button('Detect', visible=False)

    def generate(*prompts: str) -> dict:
        """Creates paired standard/watermarked responses and shows them shuffled.

        NOTE: generation is still a placeholder -- each "response" is the
        prompt text plus a fixed suffix; no model is loaded or called yet.

        Args:
          *prompts: Current text of each prompt textbox, in order.

        Returns:
          A Gradio update dict that hides the Generate button, shows the
          generations column with the shuffled responses as checkbox choices
          plus the Reveal button, and stores this round's ground truth in the
          session's `answers_state`.
        """
        # TODO: load GEMMA_2B and generate with and without WATERMARKING_CONFIG.
        standard = [f'{prompt} response' for prompt in prompts]
        watermarked = [f'{prompt} watermarked response' for prompt in prompts]
        # Attach the ground-truth label BEFORE shuffling so it travels with the
        # response. Recovering it afterwards by text lookup mislabels a standard
        # response whose text happens to equal a watermarked one.
        labelled = [(text, False) for text in standard]
        labelled += [(text, True) for text in watermarked]
        random.shuffle(labelled)
        # The shuffled position is each checkbox's value, so responses with
        # identical text stay distinct, independently selectable choices.
        answers = {str(index): entry for index, entry in enumerate(labelled)}
        return {
            generate_btn: gr.Button(visible=False),
            generations_col: gr.Column(visible=True),
            generations_grp: gr.CheckboxGroup(
                choices=[(text, key) for key, (text, _) in answers.items()],
            ),
            reveal_btn: gr.Button(visible=True),
            answers_state: answers,
        }

    generate_btn.click(
        generate,
        inputs=prompt_inputs,
        outputs=[
            generate_btn,
            generations_col,
            generations_grp,
            reveal_btn,
            answers_state,
        ],
    )

    def reveal(user_selections: Sequence[str], answers: dict) -> dict:
        """Scores the user's picks and shows the ground truth.

        Args:
          user_selections: Checkbox values (choice ids) the user ticked as
            watermarked; `None` is treated as "nothing selected".
          answers: This session's `{choice_id: (text, is_watermarked)}` map,
            as stored by `generate`.

        Returns:
          A Gradio update dict that hides the Reveal button and shows the
          detections column plus the Detect button. The column's checkbox group
          lists every generation prefixed with 'Correct!' or 'Incorrect.', with
          exactly the watermarked ones checked.
        """
        selected = set(user_selections or ())
        choices = []
        watermarked_keys = []
        for key, (text, is_watermarked) in answers.items():
            # A pick is correct when "ticked by the user" agrees with
            # "actually watermarked" (ticked+watermarked or unticked+plain).
            is_correct = (key in selected) == is_watermarked
            verdict = 'Correct!' if is_correct else 'Incorrect.'
            choices.append((f'{verdict} {text}', key))
            if is_watermarked:
                watermarked_keys.append(key)
        return {
            reveal_btn: gr.Button(visible=False),
            detections_col: gr.Column(visible=True),
            revealed_grp: gr.CheckboxGroup(
                choices=choices,
                value=watermarked_keys,
            ),
            detect_btn: gr.Button(visible=True),
        }

    reveal_btn.click(
        reveal,
        inputs=[generations_grp, answers_state],
        outputs=[
            reveal_btn,
            detections_col,
            revealed_grp,
            detect_btn,
        ],
    )
if __name__ == "__main__":
    # Serve the Gradio app only when run as a script, not when imported.
    demo.launch()