File size: 3,941 Bytes
2e744f0
 
 
 
 
7835ab6
 
 
a7154cb
2e744f0
 
 
 
 
 
 
 
 
 
 
 
 
da952a9
2e744f0
 
8188da7
da952a9
62a69ae
2e744f0
 
 
 
 
 
 
 
 
 
 
 
 
dd27d00
2e744f0
 
 
 
1f7ec24
 
 
 
13bf592
1f7ec24
 
806f2ad
aeeb5cf
 
 
 
 
36e9d13
7fe632c
 
 
aeeb5cf
806f2ad
6bd4424
1f7ec24
 
2e744f0
 
 
 
806f2ad
da952a9
 
 
1f7ec24
34bd714
 
 
62a69ae
41f95a9
 
34bd714
 
 
 
2e744f0
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import requests
import threading
from typing import Optional, List, Tuple

import gradio as gr


ENDPOINT_URL = "https://austrian-code-wizard--metarlaif-web.modal.run"


def get_feedback_options() -> List[Tuple[str, str]]:
    """Fetch the feedback adapters offered by the backend.

    POSTs the value of the ``C3PO_API_KEY`` environment variable to the
    ``/list_adapters`` endpoint and reads the ``adapters`` list from the
    JSON reply.

    Returns:
        One ``(feedback_name, feedback_id)`` tuple per adapter entry, in
        the order the backend returned them.
    """
    payload = {"C3PO_API_KEY": os.environ.get("C3PO_API_KEY")}
    reply = requests.post(f"{ENDPOINT_URL}/list_adapters", json=payload)

    options = []
    for entry in reply.json()["adapters"]:
        options.append((entry["feedback_name"], entry["feedback_id"]))
    return options


def get_completion(_, prompt: str, adapters: Optional[list[str]], method: str) -> str:
    """Request a completion for ``prompt`` from the backend.

    Args:
        _: Unused; receives the value of the first input component.
        prompt: Text to complete.
        adapters: Feedback adapter ids to apply.
        method: Method identifier. For ``"baseline"`` both ``adapters`` and
            ``method`` are sent as ``None``.

    Returns:
        The ``response`` field of the JSON reply.
    """
    is_baseline = method == "baseline"
    payload = {
        "C3PO_API_KEY": os.environ.get("C3PO_API_KEY"),
        "prompt": prompt,
        "adapters": None if is_baseline else adapters,
        "method": None if is_baseline else method,
    }
    reply = requests.post(f"{ENDPOINT_URL}/completion", json=payload)
    return reply.json()["response"]


def warmup(*args):
    """Send a best-effort warmup request to the backend without blocking.

    Starts a daemon thread that POSTs the ``C3PO_API_KEY`` environment
    value to the ``/warmup`` endpoint. The response is never read.

    Args:
        *args: Ignored; accepted so the function can be wired up as a
            callback regardless of what positional values it is given.
    """
    payload = {"C3PO_API_KEY": os.environ.get("C3PO_API_KEY")}
    # The trailing comma is essential: Thread unpacks ``args`` as positional
    # arguments, so a bare string would be spread character by character and
    # requests.post would fail inside the thread.
    threading.Thread(
        target=requests.post,
        args=(f"{ENDPOINT_URL}/warmup",),
        kwargs={"json": payload},
        daemon=True,
    ).start()

# Fetched once at import time; each entry is a (feedback_name, feedback_id) pair.
dropdown_options = get_feedback_options()

# Interface wiring: the four input components below map positionally onto
# get_completion(_, prompt, adapters, method); the output is plain text.
demo = gr.Interface(
    get_completion,
    [
        gr.Markdown(
        """
        # C3PO Demo

        This is a demo of Contextualized Critiques with Constrained Preference Optimization (C3PO). See the project website [here](https://austrian-code-wizard.github.io/c3po-website/), repo [here](https://github.com/austrian-code-wizard/c3po), and the paper [here](https://arxiv.org/abs/2402.10893).

        Selecting a feedback in the dropdown and enabling the "Use Feedback Adapter" checkbox will add the respective adapter to the model. The model will then use the feedback to generate the completion.

        ### Tl;DR
        This demo lets you apply high-level feedback to the base model. After selecting a feedback, the model completions should be more aligned with the feedback for prompts that are relevant to the feedback. While C3PO is not perfect at preventing overgeneralization, it applies feedback to prompts not relevant to the feedback less frequently than other methods.

        You can select up to 3 feedbacks to apply to the model simultaneously.

        ### Example
        - Selected Feedback: "Always use some kiss or heart emoji when texting my girlfriend Maddie"
        - In-context prompt (feedback should be applied): "Compose a text to my girlfriend Maddie asking her if she wants to go to the movies tonight."
        - Out-of-context prompt (feedback should not be applied): "Compose an email to my boss informing him that my work deliverable will be 2 days late."

        ### Warning
        The model is not hosted on Huggingface but on a 3rd party service. If this HF space has not been used recently, the model container might need to spin up if it's not currently running. This might take up to a minute on the first request.
        """
        ),
        gr.Textbox(
            placeholder="Enter a prompt...", label="Prompt"
        ),
        gr.Dropdown(
            choices=dropdown_options,
            label="Feedback",
            info="Will add the adapter for the respective feedback to the model.",
            # Preselect the first adapter's id. Fall back to no selection when
            # the backend returned no adapters, rather than raising IndexError
            # while the module is being imported.
            value=dropdown_options[0][1] if dropdown_options else None,
            multiselect=True,
            max_choices=3,
        ),
        gr.Radio(
            choices=[
                ("C3PO", "c3po"),
                ("DPO", "dpo_after_sft"),
                ("SCD + Negatives", "sft_negatives"),
                ("SCD", "sft"),
                ("Baseline", "baseline")
            ],
            value="c3po",
            label="Select which method to use. 'Baseline' is the Mistal-instruct-v0.2 model without any adapter.",
        )
    ],
    "text",
    concurrency_limit=8
)

# Script entry point: only start the app when this file is run directly.
if __name__ == "__main__":
    # Configure the request queue (max_size=32) before launching.
    demo.queue(max_size=32)
    demo.launch()