File size: 2,888 Bytes
f5e3203
 
 
 
 
 
 
 
 
ad5bf1a
f5e3203
ad5bf1a
f5e3203
 
 
 
 
 
 
 
 
 
 
ad5bf1a
f5e3203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff97c38
f5e3203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from collections.abc import Sequence
import random

import gradio as gr

# If the watermark is not detected, consider the use case. It could be because
# of the nature of the task (e.g., factual responses are lower entropy) or it
# could be for another reason (TODO: complete this note).

# Model identifier. NOTE(review): not referenced anywhere else in this file;
# presumably intended for the "Load model" step in generate().
_GEMMA_2B = 'google/gemma-2b'

# Default text for each prompt Textbox; one Textbox is created per entry.
# Variable-length tuple of strings, hence `tuple[str, ...]` (`tuple[str]` would
# mean a tuple of exactly one string).
_PROMPTS: tuple[str, ...] = (
    'prompt 1',
    'prompt 2',
    'prompt 3',
    'prompt 4',
)

# Response text -> whether that response is watermarked.
# NOTE(review): module-level, so anything stored here is shared by every
# session served by this process.
_CORRECT_ANSWERS: dict[str, bool] = {}

with gr.Blocks() as demo:
  # Per-session answer key: maps each generated response to whether it is
  # watermarked. Held in gr.State rather than a module-level dict so that
  # concurrent users (and reloads of the page) each get their own copy instead
  # of reading and accumulating each other's responses.
  answers_state = gr.State({})

  prompt_inputs = [
      gr.Textbox(value=prompt, lines=4, label='Prompt')
      for prompt in _PROMPTS
  ]
  generate_btn = gr.Button('Generate')

  # Hidden until generate() makes it visible.
  with gr.Column(visible=False) as generations_col:
    generations_grp = gr.CheckboxGroup(
        label='All generations, in random order',
        info='Select the generations you think are watermarked!',
    )
    reveal_btn = gr.Button('Reveal', visible=False)

  # Hidden until reveal() makes it visible.
  with gr.Column(visible=False) as detections_col:
    revealed_grp = gr.CheckboxGroup(
        label='Ground truth for all generations',
        info=(
            'Watermarked generations are checked, and your selection are '
            'marked as correct or incorrect in the text.'
        ),
    )
    # NOTE(review): reveal() makes this visible, but no click handler is
    # attached to it in this file.
    detect_btn = gr.Button('Detect', visible=False)

  def generate(*prompts: str):
    """Creates a standard and a watermarked response for every prompt.

    The responses are currently placeholder strings derived from the prompt
    text; no model is called yet.

    Args:
      *prompts: Current text of each prompt Textbox, in display order.

    Returns:
      A dict of component updates that hides the Generate button, shows the
      generations column and Reveal button, lists all responses in shuffled
      order, and stores this session's answer key in `answers_state`.
    """
    standard = [f'{prompt} response' for prompt in prompts]
    watermarked = [f'{prompt} watermarked response' for prompt in prompts]
    responses = standard + watermarked
    random.shuffle(responses)

    watermarked_set = set(watermarked)
    # Built fresh on every call (never merged into earlier results) so reveal()
    # grades only this session's latest generations. Insertion order follows
    # the shuffled display order.
    answers = {
        response: response in watermarked_set for response in responses
    }

    # TODO: load a model (presumably _GEMMA_2B) and generate real responses.
    return {
        generate_btn: gr.Button(visible=False),
        generations_col: gr.Column(visible=True),
        generations_grp: gr.CheckboxGroup(responses),
        reveal_btn: gr.Button(visible=True),
        answers_state: answers,
    }

  generate_btn.click(
      generate,
      inputs=prompt_inputs,
      outputs=[
          generate_btn,
          generations_col,
          generations_grp,
          reveal_btn,
          answers_state,
      ],
  )

  def reveal(user_selections: list[str], answers: dict[str, bool]):
    """Grades the user's selections against this session's answer key.

    Args:
      user_selections: Responses the user checked as watermarked. `None` is
        treated as no selection.
      answers: Per-session answer key produced by generate(), mapping
        response text to whether it is watermarked. `None` is treated as
        empty.

    Returns:
      A dict of component updates that hides the Reveal button, shows the
      detections column and Detect button, and fills `revealed_grp` with one
      "Correct!"/"Incorrect." entry per response, with the watermarked
      entries pre-checked.
    """
    selected = set(user_selections or ())
    choices: list[str] = []
    value: list[str] = []

    for response, is_watermarked in (answers or {}).items():
      # A guess is right exactly when "selected" agrees with the ground truth.
      if is_watermarked == (response in selected):
        choice = f'Correct! {response}'
      else:
        choice = f'Incorrect. {response}'

      choices.append(choice)
      # Pre-check the watermarked entries so the group shows the ground truth.
      if is_watermarked:
        value.append(choice)

    return {
        reveal_btn: gr.Button(visible=False),
        detections_col: gr.Column(visible=True),
        revealed_grp: gr.CheckboxGroup(choices=choices, value=value),
        detect_btn: gr.Button(visible=True),
    }

  reveal_btn.click(
      reveal,
      inputs=[generations_grp, answers_state],
      outputs=[
          reveal_btn,
          detections_col,
          revealed_grp,
          detect_btn,
      ],
  )

def _main() -> None:
  """Starts the Gradio server for the `demo` app defined in this module."""
  demo.launch()


if __name__ == '__main__':
  _main()