|
from bokeh.events import Tap |
|
from bokeh.io import curdoc |
|
from bokeh.layouts import column |
|
from bokeh.models import Div, TextInput, RadioButtonGroup, TextAreaInput, Span, Button, Panel, Tabs |
|
from bokeh.models.tools import CrosshairTool |
|
|
|
from demo_utils import ( |
|
get_data, |
|
prompt_boolq, |
|
pvp_colors, |
|
ctl_colors, |
|
clf_colors, |
|
reduct, |
|
task_best_pattern, |
|
plot_polygons_bokeh, |
|
advantage_text, |
|
data_difference, |
|
calculate_overlap, |
|
circ_easing, |
|
average_advantage_text, |
|
plot_three_polygons_bokeh, |
|
tasks, |
|
metric_tap, |
|
neutral_tasks, pattern_graph, |
|
) |
|
from text import text1, text2, text3, text4, initial_passage, initial_question, text5 |
|
|
|
|
|
|
|
|
|
|
|
# Layout dimensions shared by the widgets and figures below.
# ``text_width`` is emitted as a CSS pixel width (f"{text_width}px"); the
# others are presumably pixels too -- NOTE(review): confirm with their users.
plot_width, plot_height = 1200, 400
sidebar_width = 400
in_text_plot_height = 300
text_width = 800
widget_size = 400
|
|
|
|
|
|
|
|
|
|
|
# Inputs for the interactive prompt demo: a passage ("篇章"), a question
# ("问题") and a choice between three prompt templates ("模板 1..3").
# All are centred and capped at ``text_width``.
passage = TextAreaInput(
    title="篇章",
    rows=3,
    value=initial_passage,
    max_width=text_width,
    align="center",
)
question = TextInput(
    title="问题",
    value=initial_question,
    max_width=text_width,
    align="center",
)
radio_button_group = RadioButtonGroup(
    labels=["模板 1", "模板 2", "模板 3"],
    active=0,
    max_width=text_width,
    align="center",
)
|
|
|
# Inline CSS shared by the read-only text boxes (the prompt preview here and
# the advantage read-out further down).  Key order is kept as-is.
box_style = {
    # Centred block of fixed width that may shrink on narrow screens.
    "display": "block",
    "margin": "0 auto",
    "width": f"{text_width}px",
    "text-align": "center",
    # Keep the line breaks of the rendered prompt.
    "white-space": "pre-wrap",
    "background": "#f4f4f4",
    "border": "1px solid #ddd",
    "color": "#666",
    "page-break-inside": "avoid",
    "font-size": "15px",
    "line-height": "1.6",
    "max-width": "100%",
    "overflow": "hidden",
    "min-height": "30px",
    "word-wrap": "break-word",
}

# Preview of the prompt produced by ``prompt_boolq`` for the current passage,
# question and selected template index.
prompt_box = Div(
    text=prompt_boolq(passage.value, question.value, radio_button_group.active),
    width=text_width,
    style=box_style,
    sizing_mode="scale_width",
    align="center",
)
|
|
|
|
|
def update_prompt(attrname, old, new):
    """Bokeh ``on_change`` callback that refreshes the prompt preview.

    The callback arguments are ignored; the preview is always rebuilt from the
    current passage text, question text and selected template index.
    """
    rendered = prompt_boolq(passage.value, question.value, radio_button_group.active)
    prompt_box.text = rendered
|
|
|
|
|
# Rebuild the preview whenever any of its three inputs changes.
for _widget, _prop in (
    (passage, "value"),
    (question, "value"),
    (radio_button_group, "active"),
):
    _widget.on_change(_prop, update_prompt)

# Column holding the prompt demo (inputs followed by the preview).
patternification = column(
    passage,
    question,
    radio_button_group,
    prompt_box,
    sizing_mode="scale_width",
)
patternification.align = "center"
|
|
|
|
|
|
|
|
|
|
|
# Per-tab state for the "data advantage" comparison plots.  The loop below
# appends exactly one entry to each list per tab, so index ``k`` of every list
# belongs to tab ``k``; the tap/integrate handlers index them with the active
# tab number.
advantage_plots_per_task = []
overlapping_range_per_task = []
training_points_per_task = []
clf_results_per_task = []
pvp_results_per_task = []
advantage_tabs = []
# NOTE: ``advantage_all_figures`` (the Tabs widget) is created after the loop
# below, once every panel exists; the handlers only look it up at call time,
# so no placeholder Tabs is needed here.

# Read-out box for the data advantage.  Initial text (Chinese UI string):
# "Click a point inside the comparison region to compute the data advantage
# at the corresponding performance level".
advantage_box = Div(
    text="在比较区域内点击某点以计算该点对应的性能点上的数据优势",
    width=text_width,
    style=box_style,
    sizing_mode="scale_width",
)
advantage_box.align = "center"
|
|
|
def _on_advantage_tap(event):
    """Forward a Tap on any comparison plot to ``metric_tap`` for the active tab.

    All per-tab lists are indexed by the currently selected tab, so a single
    handler serves every plot.  ``advantage_all_figures`` is resolved at call
    time, i.e. after the Tabs widget has been built.
    """
    active = advantage_all_figures.active
    metric_tap(
        event,
        overlapping_range_per_task[active],
        training_points_per_task[active],
        clf_results_per_task[active],
        pvp_results_per_task[active],
        advantage_box,
        advantage_plots_per_task[active],
    )


def _add_advantage_tab(task, training_points, classifier_performances,
                       pattern_performances, title, **plot_kwargs):
    """Build one comparison plot for ``task`` and register it as a new tab.

    Appends one entry to each ``*_per_task`` list and to ``advantage_tabs``
    so that all of them stay index-aligned with the tab order.

    Args:
        task: task name, forwarded to ``plot_polygons_bokeh`` and used to look
            up ``task_best_pattern``.
        training_points: training-set sizes for ``task`` (copied into a new
            list for this tab).
        classifier_performances: classifier results from ``get_data(task)``.
        pattern_performances: pattern results from ``get_data(task)``.
        title: tab title.
        **plot_kwargs: extra options forwarded to ``plot_polygons_bokeh``
            (e.g. ``x_log_scale=True``).
    """
    training_points_per_task.append(list(training_points))
    clf_results_per_task.append(reduct(classifier_performances, "accmax"))
    pvp_results_per_task.append(
        reduct(pattern_performances, "accmax", task_best_pattern[task], "normal")
    )

    plot = plot_polygons_bokeh(
        task,
        training_points_per_task[-1],
        clf_results_per_task[-1],
        pvp_results_per_task[-1],
        clf_colors,
        pvp_colors,
        **plot_kwargs,
    )
    advantage_plots_per_task.append(plot)
    plot.align = "center"
    plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))

    overlapping_range_per_task.append(
        calculate_overlap(clf_results_per_task[-1], pvp_results_per_task[-1])
    )
    advantage_tabs.append(Panel(child=plot, title=title))
    plot.on_event(Tap, _on_advantage_tap)


for task in tasks:
    training_points, classifier_performances, pattern_performances = get_data(task)
    # Materialise once: the MNLI branch below needs the points a second time,
    # which would silently yield an empty list for a one-shot iterable.
    training_points = list(training_points)

    _add_advantage_tab(task, training_points, classifier_performances,
                       pattern_performances, title=task)

    if task == "MNLI":
        # MNLI additionally gets a log-scale variant in its own tab.
        _add_advantage_tab(task, training_points, classifier_performances,
                           pattern_performances, title="MNLI (log scale)",
                           x_log_scale=True)
|
|
|
# Assemble the comparison tabs now that every panel has been created.
advantage_all_figures = Tabs(tabs=advantage_tabs, align="center")
|
|
|
|
|
def on_integrate_click():
    """Sweep a marker across the active tab's overlap range and average the advantage.

    For the selected tab, the overlap range ``(low, high)`` is sampled at the
    ``frames - 1`` interior points ``low + (high - low) * k / frames`` for
    ``k = 1 .. frames - 1``; the end points themselves are not sampled.  At
    each sample:

    * ``data_difference`` is evaluated and folded into a running mean;
    * the dashed horizontal ``Span`` marker is moved there and coloured with
      ``clf_colors[0]`` when the value is negative, else ``pvp_colors[0]``;
    * ``advantage_box`` is refreshed via ``average_advantage_text``.

    The plot's last renderer is reused as the marker if it is already a
    ``Span``; otherwise a new ``Span`` is appended.
    """
    frames = 200
    active = advantage_all_figures.active
    plot = advantage_plots_per_task[active]
    overlap = overlapping_range_per_task[active]

    start = overlap[0]
    # NOTE(review): this compares a metric value (not an advantage) with 0;
    # the colour is overwritten by the first sample below anyway.
    start_color = clf_colors[0] if start < 0 else pvp_colors[0]

    marker = plot.renderers[-1]
    if isinstance(marker, Span):
        # Reuse an existing marker (e.g. one left by a previous integration).
        marker.location = start
        marker.line_color = start_color
    else:
        marker = Span(
            location=start,
            line_alpha=0.7,
            dimension="width",
            line_color=start_color,
            line_dash="dashed",
            line_width=1,
        )
        plot.renderers.extend([marker])

    points = training_points_per_task[active]
    clf = clf_results_per_task[active]
    pvp = pvp_results_per_task[active]
    low, high = overlap[0], overlap[1]

    average_advantage = 0
    for step in range(1, frames):
        metric_value = low + (high - low) * (step / frames)
        advantage_value = data_difference(metric_value, overlap, points, clf, pvp)
        # Incremental mean over the ``step`` samples taken so far.
        average_advantage = ((step - 1) * average_advantage + advantage_value) / step

        marker.location = metric_value
        marker.line_color = clf_colors[0] if advantage_value < 0 else pvp_colors[0]
        advantage_box.text = average_advantage_text(average_advantage)
|
|
|
|
|
# Button labelled "Integrate over the whole region!"; runs on_integrate_click.
integrate = Button(
    label="对整个区域进行积分!",
    width=175,
    max_width=175,
    align="center",
)
integrate.on_click(on_integrate_click)
|
|
|
|
|
def on_tab_change(attr, old, new):
    """Bokeh ``on_change`` callback: restore the read-out's instruction text.

    Only the text of ``advantage_box`` is reset; the callback arguments are
    ignored and nothing on the plots is modified.
    """
    # "Click a point inside the comparison region to compute the data
    # advantage at the corresponding performance level".
    placeholder = "在比较区域内点击某点以计算该点对应的性能点上的数据优势"
    advantage_box.text = placeholder
|
|
|
|
|
# Switching tabs resets the advantage read-out.
advantage_all_figures.on_change("active", on_tab_change)

# Comparison tabs, read-out box and integrate button stacked vertically.
advantage_column = column(
    advantage_all_figures,
    advantage_box,
    integrate,
    sizing_mode="scale_width",
)
|
|
|
|
|
|
|
|
|
|
|
# Tabs plotting, for each task in ``neutral_tasks``, the classifier results
# together with the "normal" and "neutral" pattern reductions.
null_tabs = []


def _add_null_tab(task, training_points, clf_results, pvp_results, ctl_results,
                  title, **plot_kwargs):
    """Plot the three result sets for ``task`` and append the plot as a tab.

    Args:
        task: task name forwarded to ``plot_three_polygons_bokeh``.
        training_points: list of training-set sizes for ``task``.
        clf_results: ``reduct`` of the classifier performances.
        pvp_results: ``reduct`` of the pattern performances in "normal" mode.
        ctl_results: ``reduct`` of the pattern performances in "neutral" mode.
        title: tab title.
        **plot_kwargs: extra options forwarded to ``plot_three_polygons_bokeh``
            (e.g. ``x_log_scale=True``).
    """
    plot = plot_three_polygons_bokeh(task, training_points, clf_results, pvp_results,
                                     ctl_results, clf_colors, pvp_colors, ctl_colors,
                                     **plot_kwargs)
    plot.align = "center"
    plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
    null_tabs.append(Panel(child=plot, title=title))


for task in neutral_tasks:
    training_points, classifier_performances, pattern_performances = get_data(task)
    training_points = list(training_points)
    clf_results = reduct(classifier_performances, "accmax")
    pvp_results = reduct(pattern_performances, "accmax", task_best_pattern[task], "normal")
    ctl_results = reduct(pattern_performances, "accmax", task_best_pattern[task], "neutral")

    _add_null_tab(task, training_points, clf_results, pvp_results, ctl_results, title=task)

    if task == "MNLI":
        # MNLI additionally gets a log-scale variant in its own tab.
        _add_null_tab(task, training_points, clf_results, pvp_results, ctl_results,
                      title="MNLI (log scale)", x_log_scale=True)

# Build the Tabs widget once, after all panels exist (an empty placeholder
# Tabs created up front would just be discarded).
null_all_figures = Tabs(tabs=null_tabs)
null_all_figures.align = "center"
|
|
|
|
|
|
|
|
|
|
|
# One tab per task showing the figure returned by ``pattern_graph(task)``.
pattern_tabs = []

for task in tasks:
    pattern_plot = pattern_graph(task)
    pattern_plot.align = "center"
    pattern_plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
    pattern_tabs.append(Panel(child=pattern_plot, title=task))

# Build the Tabs widget once, after all panels exist (an empty placeholder
# Tabs created up front would just be discarded).
pattern_all_figures = Tabs(tabs=pattern_tabs)
pattern_all_figures.align = "center"
|
|
|
|
|
|
|
|
|
|
|
# Inline CSS for the narrative text blocks placed between the figures.
main_text_style = {
    "min-height": "100px",
    "overflow": "hidden",
    "display": "block",
    "margin": "auto",
    "width": f"{text_width}px",
    "font-size": "18px",
}

# One centred Div per narrative section, created in order text1 .. text5.
textbox1, textbox2, textbox3, textbox4, textbox5 = (
    Div(text=section, style=main_text_style, align="center")
    for section in (text1, text2, text3, text4, text5)
)
|
|
|
|
|
|
|
|
|
|
|
# Page layout: narrative text interleaved with the interactive sections.
main_body = column(
    textbox1,
    patternification,
    textbox2,
    advantage_column,
    textbox3,
    null_all_figures,
    textbox4,
    pattern_all_figures,
    textbox5,
    sizing_mode="scale_width",
)
main_body.align = "center"

# Publish the layout; the title reads "How many data points is a prompt worth?".
doc = curdoc()
doc.add_root(main_body)
doc.title = "一条提示抵得上多少样本数据?"
|
|