# Last commit: "Update app.py" by hannayukhymenko (83ace66, verified)
import gradio as gr
from datasets import load_dataset
from difflib import SequenceMatcher
# Load the dataset
dataset = load_dataset('RobinSta/SynthPAI')
# Keep only entries that carry model guesses, capped at the first 100.
filtered_data = [entry for entry in dataset['train'] if entry['guesses'] is not None][:100]
# Predefined feature names: the profile attributes the user is asked to guess
feature_names = ['city_country', 'sex', 'age', 'occupation', 'birth_city_country', 'relationship_status', 'income_level', 'education']
# Dict to save user predictions (the submit handler builds its own local dict)
user_predictions = {}
# Index of the next comment to show
current_index = 0
# Running totals of non-empty predictions made by the user / the model
num_predictions_human = 0
num_predictions_model = 0
# Running totals of correct guesses. Both start at 0: the former -1 start made
# the reported LLM correct-guess count (and its accuracy) one too low.
model_accuracy = 0
human_accuracy = 0


def cleanup(req: gr.Request):
    """Reset all scoring state; registered as the ``demo.unload`` handler.

    Args:
        req: Gradio request of the closing session (not used).

    NOTE(review): the counters are module globals and therefore shared by every
    visitor, so one session closing resets everybody's scores. Per-session
    ``gr.State`` would avoid that.
    """
    global current_index, num_predictions_human, num_predictions_model, model_accuracy, human_accuracy
    current_index = 0
    num_predictions_human = 0
    num_predictions_model = 0
    model_accuracy = 0
    human_accuracy = 0
def _model_hits(entry):
    """Return ``(predictions, correct)`` for the model guesses stored on *entry*.

    A feature counts as a prediction when the model produced at least one guess
    for it, and as correct when the profile's true value matches any of those
    guesses case-insensitively.
    """
    profile = entry['profile']
    made = correct = 0
    for guess in entry['guesses'] or []:
        model_guesses = guess['guesses']
        if not model_guesses:
            continue
        made += 1
        true_value = profile.get(guess['feature'])
        if true_value and str(true_value).lower() in map(str.lower, model_guesses):
            correct += 1
    return made, correct


def _human_hits(entry, predictions):
    """Return ``(predictions, correct)`` for the user's answers against *entry*.

    *predictions* maps feature name -> already stripped, lower-cased answer.
    Empty answers are skipped; a non-empty answer is correct when it equals the
    profile's true value case-insensitively.
    """
    profile = entry['profile']
    made = correct = 0
    for feature, guess in predictions.items():
        if not guess:
            continue
        made += 1
        true_value = profile.get(feature)
        if true_value and str(true_value).lower() == guess:
            correct += 1
    return made, correct


def _percent(correct, total):
    """Return ``correct / total`` as a whole-number percentage, or 0 if *total* is 0."""
    if total == 0:
        return 0
    # Rounding the scaled value avoids float artefacts such as
    # round(0.58, 2) * 100 == 57.99999999999999.
    return round(100 * correct / total)


def show_entry_and_calculate_accuracy(hidden_box, *args):
    """Score the answers for the comment on screen, then show the next comment.

    Wired to the Submit button. The answers refer to the entry displayed by
    the previous call (index ``current_index - 1``). The model's stored guesses
    for that same entry are scored at the same time, so both accuracies cover
    the same comments. After the last entry, the index wraps to the start.

    Args:
        hidden_box: current value of the profile textbox. It is ignored and
            replaced by the next entry's profile.
        *args: the user's answers, one per name in ``feature_names``, in order.

    Returns:
        Tuple of (accuracy report text, next entry's profile, next entry's
        comment text, then one "" per feature to clear the answer boxes).

    NOTE(review): all counters are module globals shared by every session.
    """
    global current_index, num_predictions_human, num_predictions_model, model_accuracy, human_accuracy
    n_entries = len(filtered_data)
    user_predictions = {
        attr: (value or "").strip().lower()
        for attr, value in zip(feature_names, args)
    }

    if current_index > 0:
        previous = filtered_data[(current_index - 1) % n_entries]
        made, correct = _human_hits(previous, user_predictions)
        num_predictions_human += made
        human_accuracy += correct
        made, correct = _model_hits(previous)
        num_predictions_model += made
        model_accuracy += correct
    else:
        # First submission: no comment has been shown yet, so nothing to score.
        human_accuracy = 0

    # Wrap around instead of raising IndexError once every entry has been shown.
    entry = filtered_data[current_index % n_entries]
    current_index += 1

    total_acc_human = _percent(human_accuracy, num_predictions_human)
    total_acc_model = _percent(model_accuracy, num_predictions_model)
    accuracy = (
        f"Number of correct guesses made by you: {human_accuracy}\n"
        f"Your accuracy {total_acc_human}%\n"
        f"Number of correct guesses made by LLM: {model_accuracy}\n"
        f"Model accuracy {total_acc_model}%"
    )
    hidden_box = entry['profile']
    pre_filled_values = [""] * len(feature_names)
    return accuracy, hidden_box, entry["text"], *pre_filled_values
# Textbox settings for each guessable attribute, keyed by feature name.
# Every answer box is created with value="" plus these keyword arguments.
_FIELD_SPECS = {
    "city_country": {
        "label": "Location (City / State / Country)",
        "placeholder": "Answer example: Rome, Italy",
        "info": "Give the closest city / state / country in format [City, Country]",
    },
    "sex": {
        "label": "Gender",
        "placeholder": "Answer example: female",
        "info": "Gender of the author (choose from male/female)",
    },
    "age": {
        "label": "Age",
        "placeholder": "Answer example: 25",
        "info": "Age in years, either explicit, e.g. 25, or a range, e.g. 20-30",
    },
    "occupation": {
        "label": "Occupation",
        "placeholder": "Answer example: Architect",
        "info": "Brief Occupation Descriptor, e.g. 'Software Engineer'",
    },
    "birth_city_country": {
        "label": "Place of Birth",
        "placeholder": "Answer example: Milan, Italy",
        "info": "Give the closest city / state / country in format [City, Country]",
    },
    "relationship_status": {
        "label": "Marital Status",
        "placeholder": "Answer example: married",
        "info": "Relationship status of the person. Choose from married, single, in a relationship, widowed, divorced.",
    },
    "income_level": {
        "label": "Income",
        "placeholder": "Answer example: High",
        "info": "Annual Income Level - No: No Income\nLow: < 30k\nMedium: 30k - 60k\nHigh: 60k - 150k\nVery High: > 150k",
    },
    "education": {
        "label": "Education Level",
        "placeholder": "Answer example: Bachelor's degree in Fashion Design",
        "max_lines": 1,
        "info": "Highest level of education. Answer in format [_ degree in _]",
    },
}


def _make_guess_box(attr):
    """Create the empty answer textbox for feature *attr*.

    Raises:
        ValueError: if *attr* has no entry in ``_FIELD_SPECS``.
    """
    spec = _FIELD_SPECS.get(attr)
    if spec is None:
        raise ValueError(f"Unknown attribute {attr}")
    return gr.Textbox(value="", **spec)


def hello_world():
    """Show a greeting toast when the page loads."""
    gr.Info("Hi! This space was created by the authors of the dataset SynthPAI. Click the button 'Submit' to start.")
    return "hello world"


with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown(
            """
# Welcome to SynthPAI inference space! <a href="https://www.sri.inf.ethz.ch/"><img width="100" alt="portfolio_view" align="right" src="http://safeai.ethz.ch/img/sri-logo.svg"></a>
You can test your private attribute inference skills and compare your results with GPT-4 on our synthetic comments here.
In the row below you can see the comment text (on the left), from which you can infer some information about its author. In the middle you will be shown your accuracy and GPT-4's accuracy. On the right you can check the author's real profile (but that would be cheating!).
Click the button `Submit` at the bottom to get the next comment.
Have fun!
""")
    with gr.Column():
        # One row, left to right: comment, accuracy, hidden profile,
        # matching the layout described in the welcome text.
        with gr.Row(equal_height=True):
            name = gr.Textbox(label="Comment text", value="")
            accuracy = gr.Textbox(label="Accuracy")
            with gr.Accordion(label="Author's real profile", open=False):
                hidden_box = gr.Textbox(
                    label="You can take a look here for correct guesses, but that would be cheating :)",
                    value="",
                    max_lines=5,
                )
        # One answer box per feature, in feature_names order; the submit
        # handler reads its *args in this same order.
        inputs = []
        with gr.Column():
            with gr.Row(equal_height=True):
                for attr in feature_names:
                    with gr.Column():
                        inputs.append(_make_guess_box(attr))
        with gr.Row(equal_height=True):
            btn = gr.Button("Submit")
            btn.click(
                fn=show_entry_and_calculate_accuracy,
                inputs=[hidden_box, *inputs],
                outputs=[accuracy, hidden_box, name, *inputs],
            )
        with gr.Row(equal_height=True):
            gr.Markdown(
                """
<span style="font-size:0.5em;"><p style="text-align: center;">
This space was created to showcase the dataset SynthPAI, created for paper **A Synthetic Dataset for Personal Attribute Inference** (Yukhymenko, Staab, Vero, Vechev).</p></span>
""")
            gr.Markdown(
                """
<span style="font-size:0.5em;"><p align="center"> [Arxiv paper](https://arxiv.org/abs/2406.07217)<br/>
[HuggingFace dataset](https://huggingface.co/datasets/RobinSta/SynthPAI)<br/>
[Papers With Code](https://paperswithcode.com/paper/a-synthetic-dataset-for-personal-attribute)<br/>
</p></span>
""")
    demo.load(hello_world)
    demo.unload(cleanup)
demo.queue().launch()