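# The next lines appear to be residue from the page this file was copied
# from; they are kept as comments so the module parses.
# steamlit_n7 /
# patent's picture
# 7c7eeeb verified
# history blame
# 27.3 kB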
import difflib
import glob
import json
import os
import pdb
import random
import re
import time

import numpy as np
import pandas as pd
import plotly.express as px
import requests
import smart_open
import streamlit as st
from sentence_transformers import SentenceTransformer, models, util
# Switches and storage settings.
enable_summary_button = True
dump_pos_data_for_reporting = True
bucket_name = "paper_n1"

# Prefixes of the per-model result files ('<prefix>_<num>_forward.json').
prefix_lst = [
    "pgj_d_1024_layer_1",
    # "my_gptj_6b_tpu_size_8",
]

# Display labels, paired with prefix_lst entries *by position*, exactly as the
# former `prefix_lst[k]: label` literal did. zip() stops at the shorter list,
# so a prefix_lst with fewer than 7 entries no longer raises IndexError at
# import time.
# NOTE(review): with the single active prefix "pgj_d_1024_layer_1" this maps
# it to 'PatentGPT-J-6B' (as the original index-0 entry did) - confirm the
# intended label.
_model_label_lst = [
    'PatentGPT-J-6B',
    'PatentGPT-J-1.6B',
    'PatentGPT-J-456M',
    'PatentGPT-J-279M',
    'PatentGPT-J-191M',
    'PatentGPT-J-128M',
    'PatentGPT-J-115M',
    # 'GPT-J-6B',
]
model_names = dict(zip(prefix_lst, _model_label_lst))

# --- Experiment selection: exactly one block below should be active ---

# experiment 3
# folder = os.path.join('experiments', 'non_patent')
# id_to_scroll = 1  # which of the above to scroll through
# first_claim_only = True

# experiment 2
# folder = os.path.join('experiments', 'ipg20220104_500')
# #folder = "device_serve_results"
# id_to_scroll = 1  # which of the above to scroll through
# first_claim_only = False
# prefix_lst = ["my_gptj_6b_tpu_size_8", "pgj_d_4096", "pgj_d_2048", "pgj_d_1024_layer_14", "pgj_d_1024_layer_7", "pgj_d_1024_layer_2", "pgj_d_1024_layer_1"]
# #, "pgj_large", "pgj_medium", "pgj_small", ]
# # "pgj_d_1024_layer_14"

# experiment 1 (active)
folder = os.path.join('experiments', 'ipg22_500')  # (previous) folder = "eval_ipg22_500"
# Index into prefix_lst of the model to scroll through.
# NOTE(review): with the current single-entry prefix_lst, index 1 is out of
# range for any code that uses it.
id_to_scroll = 1
first_claim_only = True
# When True, tokens ranked outside the top 10 are counted as typed in full.
ignore_outscope = True  # ignore pick > 10
def show_diff(a, b):
    """Print the character-level edits that turn *a* into *b*.

    Every non-equal entry produced by ``difflib.ndiff`` is reported as a
    delete or an add. The reported position is the entry's index in the
    ndiff output, not an index into either input.
    """
    for idx, entry in enumerate(difflib.ndiff(a, b)):
        op, ch = entry[0], entry[-1]
        if op == '-':
            print(u'Delete "{}" from position {}'.format(ch, idx))
        elif op == '+':
            print(u'Add "{}" to position {}'.format(ch, idx))
def handle_char_return(text):
    """Normalise a token's text for display and keystroke counting.

    The placeholder '(none)' is mapped to the empty string; every other value
    is returned unchanged.
    """
    if text == '(none)':  # unicorn text
        # Fix: this was written as `text == ''` (a no-op comparison), so the
        # placeholder was never replaced.
        return ''
    return text
def get_remaining(lst, pos):
    """Concatenate token texts from index *pos* onward, skipping space-led tokens.

    Only tokens whose ``actual_next_token_text`` does not start with a space
    are kept. Space-led tokens are dropped, but iteration continues to the
    end of *lst*.
    NOTE(review): if the intent was "the rest of the current word", the loop
    probably should stop at the first space-led token - confirm.
    """
    kept = [
        lst[idx]['actual_next_token_text']
        for idx in range(pos, len(lst))
        if not lst[idx]['actual_next_token_text'].startswith(' ')
    ]
    return ''.join(kept)
def calc_details(base_fn):
    """Load one ``*_forward.json`` result file and compute keystroke statistics.

    Reads ``os.path.join(folder, base_fn)``. The JSON's ``'output'`` list holds
    per-token dicts with ``actual_next_token_text``,
    ``actual_next_token_top_seq`` (0 means the actual token was the top pick)
    and ``actual_next_token_top_prob``. The final entry of ``'output'`` is
    not counted.

    When ``ignore_outscope`` is set, tokens with rank >= 10 are counted only
    in the out-of-scope totals (typed in full, one keystroke per character).
    All other tokens cost ``min(rank + 1, len(token))`` keystrokes.

    Returns:
        An 11-tuple ``(result, avg_pick, avg_prob, token_count, sum_pick,
        sum_prob, sum_outscope_count, sum_outscope_len, sum_hit_1,
        sum_top_10_len, full_text)``. If the file does not exist, the tuple
        is ``(None, -1, -1, None, ..., None)`` with the same length, so callers
        can always unpack 11 values and test ``result is None``.
    """
    full_fn = os.path.join(folder, base_fn)
    if not os.path.exists(full_fn):
        # Fix: previously an 8-tuple, which broke every 11-value unpacking.
        return None, -1, -1, None, None, None, None, None, None, None, None

    with open(full_fn, encoding='utf-8') as f:
        result = json.load(f)
    print("Loaded: %s" % full_fn)

    sum_pick = 0
    sum_prob = 0
    sum_outscope_count = 0
    sum_outscope_len = 0
    sum_hit_1 = 0
    sum_top_10_len = 0
    token_count = 0
    text_parts = []
    for tk in result['output'][:-1]:
        token_text = handle_char_return(tk['actual_next_token_text'])
        next_top_seq = int(tk['actual_next_token_top_seq'])
        next_top_prob = float(tk['actual_next_token_top_prob'])
        text_parts.append(token_text)
        if next_top_seq == 0:
            sum_hit_1 += 1  # press "tab" for the top pick
        if ignore_outscope and next_top_seq >= 10:
            # Not offered in the top 10: the token is typed in full
            # (length used as keystrokes).
            sum_outscope_count += 1
            sum_outscope_len += len(token_text)
        else:
            # "down" x rank + "tab", capped at the cost of typing the token.
            # sum_pick therefore covers top 1~10 only, which is what the
            # keystroke summaries assume.
            sum_pick += min(next_top_seq + 1, len(token_text))
            sum_prob += next_top_prob
            sum_top_10_len += len(token_text)
            token_count += 1
    full_text = ''.join(text_parts)

    # Guard both modes against division by zero.
    if token_count == 0:
        avg_pick = 0
        avg_prob = 0
    else:
        avg_pick = float(sum_pick) / token_count
        avg_prob = float(sum_prob) / token_count

    return (result, avg_pick, avg_prob, token_count, sum_pick, sum_prob,
            sum_outscope_count, sum_outscope_len, sum_hit_1, sum_top_10_len,
            full_text)
def show_avg(base_fn, model_name, patent_claim_num, show_pick=False):
    """Render one result file as colour-coded tokens and return rank counts.

    The tokens of ``result['output']`` (loaded via ``calc_details``) are drawn
    as coloured <span>s. The top pick, ranks 2-10 and everything outside the
    top 10 each get their own colour. Rendering stops at the first
    ``<|end_of_claim|>`` token, and raw ``(none)`` tokens are skipped.

    Args:
        base_fn: result file name, resolved by ``calc_details``.
        model_name: model label; currently not rendered.
        patent_claim_num: unused; kept for call compatibility.
        show_pick: when True, append the 0-based rank in brackets to each token.

    Returns:
        ``[n_top_1, n_top_2_to_10, n_out_of_top_10]`` over the rendered tokens,
        or None if the result file does not exist.
    """
    import html  # token text is model output; escape it before HTML rendering

    # Only the parsed JSON is needed; [0] works whatever the tuple length.
    result = calc_details(base_fn)[0]
    if result is None:
        return None

    # [background, foreground, bucket label]
    colors = [
        ['00ff00', '000000', '1'],
        ['008800', 'ffffff', '2-10'],
        ['ff0000', 'ffffff', 'out of top 10'],
    ]
    # Tallies are initialised once, before the token loop, so they accumulate
    # over all tokens.
    sum_all = {item[2]: 0 for item in colors}
    spans = []
    for tk in result['output']:
        raw_text = tk['actual_next_token_text']
        token_text = handle_char_return(raw_text)
        if token_text == '<|end_of_claim|>':
            break
        # Check the raw text, since handle_char_return may blank the placeholder.
        if raw_text == '(none)':  # for unicorn text
            continue
        pick = int(tk['actual_next_token_top_seq'])
        if pick == 0:
            bucket = colors[0]
        elif 1 <= pick < 10:
            bucket = colors[1]
        else:  # pick >= 10 (or negative)
            bucket = colors[2]
        bg_color, fg_color, tag = bucket
        sum_all[tag] += 1
        pick_label = '[%s]' % pick if show_pick else ''
        spans.append(
            "<span style=background-color:#%s;color:#%s;border-radius:5px;>%s%s</span> "
            % (bg_color, fg_color, html.escape(token_text), pick_label))

    # Colour legend.
    color_msg = ''.join(
        "<span style=background-color:#%s;color:#%s;border-radius:5px;>&nbsp;%s&nbsp;</span> "
        % (v[0], v[1], v[2]) for v in colors)
    st.markdown(color_msg, unsafe_allow_html=True)
    st.markdown(''.join(spans), unsafe_allow_html=True)
    return [sum_all['1'], sum_all['2-10'], sum_all['out of top 10']]
def show_overall_summary(prefix_lst, select_lst):
    """Show aggregate autocomplete statistics for each model prefix.

    For every prefix, ``calc_details`` results are summed over
    ``'<prefix>_<num>_forward.json'`` for each num in *select_lst*. Two
    ``st.text`` lines are then written per prefix: a readable summary and a
    LaTeX table row. Missing result files are skipped. Prefixes with no
    counted tokens or no text are not reported.
    """
    for prefix in prefix_lst:
        acc_token_count = 0
        acc_sum_pick = 0
        acc_sum_prob = 0
        acc_sum_outscope_len = 0
        acc_sum_hit_1 = 0
        acc_sum_top_10_len = 0
        acc_full_text_len = 0
        for num in select_lst:
            base_fn = '%s_%s_forward.json' % (prefix, num)
            (result, _avg_pick, _avg_prob, token_count, sum_pick, sum_prob,
             _sum_outscope_count, sum_outscope_len, sum_hit_1, sum_top_10_len,
             full_text) = calc_details(base_fn)
            if result is None:
                continue  # file missing
            acc_token_count += token_count
            acc_sum_pick += sum_pick
            acc_sum_prob += sum_prob
            acc_sum_outscope_len += sum_outscope_len
            acc_sum_hit_1 += sum_hit_1
            acc_sum_top_10_len += sum_top_10_len
            acc_full_text_len += len(full_text)

        # Guard both divisors below.
        if acc_token_count == 0 or acc_full_text_len == 0:
            continue

        # acc_sum_pick --> top 1~10
        keys_with_auto = acc_sum_pick + acc_sum_outscope_len
        keys_without_auto = acc_full_text_len
        saved_ratio = float(keys_without_auto - keys_with_auto) / keys_without_auto * 100
        # NOTE(review): the trailing format arguments were lost in this copy.
        # They are reconstructed from the placeholder labels; "prob" is taken
        # as the average probability per counted token - confirm.
        avg_prob = float(acc_sum_prob) / acc_token_count
        name = model_names.get(prefix, prefix)
        st.text('[ %s ]\n'
                'Autocomplete Effectiveness: %.1f%% (ratio of saving keystroke)\n'
                '(sum) keys_with_auto: %s, top_10_keys: %s, out_of_scope: %s, sum_hit_1: %s\n'
                'keys_without_auto: %s, top_10_len: %s, prob: %.2f' % (
                    name, saved_ratio,
                    keys_with_auto, acc_sum_pick, acc_sum_outscope_len, acc_sum_hit_1,
                    keys_without_auto, acc_sum_top_10_len, avg_prob))
        # LaTeX table row.
        st.text('%s & %.1f\\%% & %s & %s & %s & %s & %s \\\\' % (
            name, saved_ratio,
            '{:,}'.format(keys_with_auto), '{:,}'.format(acc_sum_pick),
            '{:,}'.format(acc_sum_outscope_len), '{:,}'.format(acc_sum_hit_1),
            '{:,}'.format(keys_without_auto)))
def calc_height(s):
    """Return a text-area height for *s*: 30 plus 3 per 10 characters, truncated."""
    scaled = len(s) / 10 * 3
    return 30 + int(scaled)
def remove_end_of_claim_text(gen_text):
    """Truncate generated text just after the first end marker.

    ``<|end_of_claim|>`` is searched first; only if it is absent is
    ``<|endoftext|>`` tried. The marker itself is kept. Text containing
    neither marker is returned unchanged.
    """
    for tag in ('<|end_of_claim|>', '<|endoftext|>'):
        pos = gen_text.find(tag)
        # str.find returns -1 when absent. A marker at index 0 is still a
        # match; the former `pos > 0` test missed it.
        if pos != -1:
            return gen_text[:pos + len(tag)]
    return gen_text
def dump_pos_data(prefix_lst, select_lst):
    """Write per-position rank counts across result files to ``dump4.txt``.

    Result files are processed for every prefix whose label is not the 456M
    model, and for every claim number in *select_lst*. For each token
    position j, the function counts how often the actual token was the top
    pick, in ranks 2-10, or outside the top 10 (the final token of each file
    is not counted). Counts are summed over all included files. Each position
    is written as three lines of the form ``<pos>, <bucket>, <count>``, with
    pos 1-based. Missing result files are skipped.
    """
    # statistics[j] = [top-1, top-2~10, out-of-scope] counts for position j.
    # Built row by row (separate lists, no aliasing) and grown as needed, so
    # sequences of any length are handled.
    statistics = []
    for prefix in prefix_lst:
        model_name = model_names.get(prefix, prefix).replace('PatentGPT-J-', '')
        if model_name == '456M':
            continue
        for num in select_lst:
            base_fn = '%s_%s_forward.json' % (prefix, num)
            full_fn = os.path.join(folder, base_fn)
            if not os.path.exists(full_fn):
                continue
            with open(full_fn, encoding='utf-8') as f:
                result = json.load(f)
            print("Loaded: %s" % full_fn)
            for j, tk in enumerate(result['output'][:-1]):
                while len(statistics) <= j:
                    statistics.append([0, 0, 0])
                next_top_seq = int(tk['actual_next_token_top_seq'])
                if next_top_seq == 0:
                    statistics[j][0] += 1
                elif 0 < next_top_seq < 10:
                    statistics[j][1] += 1
                else:
                    statistics[j][2] += 1

    dump_file = 'dump4.txt'
    with open(dump_file, 'w', encoding='utf-8') as f:
        for pos, (top_1, top_2_to_10, out_of_scope) in enumerate(statistics, start=1):
            f.write('%s, top-1, %s\n' % (pos, top_1))
            f.write('%s, top-2~10, %s\n' % (pos, top_2_to_10))
            f.write('%s, out_of_scope, %s\n' % (pos, out_of_scope))
    print('saved: %s' % dump_file)
def calc_sentence_similarity(sent_model, sent1, sent2):
    """Return the cosine similarity between the embeddings of two texts.

    Both texts are encoded with ``sent_model.encode(..., convert_to_tensor=True)``
    (``sent1`` first) and compared with ``util.cos_sim``. The ``[0][0]``
    element of the resulting similarity matrix is returned.
    """
    first_vec, second_vec = (
        sent_model.encode(text, convert_to_tensor=True) for text in (sent1, sent2)
    )
    return util.cos_sim(first_vec, second_vec)[0][0]
# Sentence-embedding model id, loaded into `sent_aipd`.
sent_model = 'patent/st-aipd-nlp-g'


@st.cache_resource
def _load_sentence_transformer(model_id):
    """Load a SentenceTransformer once and reuse it across Streamlit reruns.

    Streamlit re-executes the script on every interaction. Without caching,
    the model weights would be reloaded each time.
    """
    print('loading SentenceTransformer: %s' % model_id)
    return SentenceTransformer(model_id)


sent_aipd = _load_sentence_transformer(sent_model)
def load_data(demo, fn=''):
    """Load demo rows from a JSON file and keep those matching *demo*.

    Args:
        demo: 'demo1' keeps rows whose 'instruction' contains 'child';
            'demo2' keeps rows whose 'instruction' contains 'parent';
            any other value yields an empty list.
        fn: path to a JSON file holding a list of row dicts.
            NOTE(review): the hard-coded data path is empty in this copy, so
            the default cannot load anything - supply the real path.

    Returns:
        The matching rows, in file order.

    Raises:
        FileNotFoundError: if *fn* is empty or does not exist.
    """
    if not fn:
        raise FileNotFoundError('load_data: no data file configured (fn is empty)')
    with open(fn, 'r', encoding='utf-8') as f:
        rows = json.load(f)
    keyword = {'demo1': 'child', 'demo2': 'parent'}.get(demo)
    if keyword is None:
        return []
    # Substring test; the former `find(...) > 0` missed a keyword at index 0.
    return [row for row in rows if keyword in row['instruction']]
# CSS for bordered containers. Fix: the triple-quoted string and its rule
# blocks were never closed, which swallowed the rest of the module.
container_style = """
.container1 {
    border: 2px solid #3498db;
    border-radius: 8px;
    padding: 10px;
    margin-bottom: 20px;
}
.container2 {
    /* Add styles for Container 2 if needed */
}
"""
def main():
    """Streamlit entry point: browse parent/child claim-generation results.

    The user picks a demo direction and a row. The page then shows:
    - the row's instruction;
    - the input claim;
    - the actual and generated claims;
    - the two similarity scores and the row's precomputed ``delta``.

    Rows are read with the keys 'patent_num', 'instruction', 'input',
    'output', 'response', 'score_lst_1', 'score_lst_2' and 'delta'.
    """
    st.set_page_config(  # must be the first Streamlit call of the run
        layout="wide",  # Can be "centered" or "wide".
        initial_sidebar_state="auto",  # Can be "auto", "expanded", "collapsed"
        page_title="Demo 1",  # Strings get appended with "• Streamlit".
        page_icon=None,
    )
    opt_1 = 'parent --> child'
    opt_2 = 'child --> parent'
    options = [opt_1, opt_2]
    # option -> (demo name, header message, tags for input/actual/generated claim)
    demo_settings = {
        opt_1: ('demo1', 'novelty = sim1-sim2', 'pc', 'cc1', 'cc2'),
        opt_2: ('demo2', 'similarity of<br>(pc1) and (pc2)', 'cc', 'pc1', 'pc2'),
    }

    with st.container():
        col1, col2, col3 = st.columns([3, 5, 2])
        with col1:
            selected_option = st.selectbox('Select a demo:', options)
            if selected_option not in demo_settings:
                st.text('Unknown option')
                return
            demo, msg, c1_tag, c2_tag, c3_tag = demo_settings[selected_option]
            rows = load_data(demo)
        if not rows:
            st.text('No rows to show')
            return
        with col2:
            # A slider needs min < max; with a single row there is nothing to pick.
            if len(rows) > 1:
                pos = st.slider("Row", 1, len(rows), label_visibility="collapsed")
            else:
                pos = 1
        # The slider is 1-based. Index directly instead of looping up to pos.
        row = rows[pos - 1]
        with col3:
            st.markdown("<center><h7>%s<br>%s</h7></center>" % (msg, row['delta']),
                        unsafe_allow_html=True)

    st.write('(instruction) ', row['instruction'])
    with st.container():
        with st.container(border=True):
            st.write('(%s) [ %s ]\n%s' % (c1_tag, row['patent_num'], row['input']))
        col1, col2 = st.columns(2)
        # NOTE(review): 'output' is shown as the actual claim and 'response'
        # as the generated one; both were read but never displayed before.
        with col1:
            with st.container(border=True):
                st.write('(%s) (actual)' % c2_tag)
                st.write(row['output'])
        with col2:
            with st.container(border=True):
                st.write('(%s) (generated)' % c3_tag)
                st.write(row['response'])
        col1, col2 = st.columns(2)
        with col1:
            with st.container(border=True):
                st.write('(sim1) similarity between %s and %s+%s: %s'
                         % (c1_tag, c1_tag, c2_tag, str(row['score_lst_1'])))
        with col2:
            with st.container(border=True):
                st.write('(sim2) similarity between %s and %s+%s: %s'
                         % (c1_tag, c1_tag, c3_tag, str(row['score_lst_2'])))

    # NOTE: an earlier PatentGPT-J evaluation view used to live here as
    # commented-out code, and was removed. It included the model summary via
    # show_overall_summary, per-claim token colouring via show_avg, the
    # position dump via dump_pos_data, and a token-by-token inspector.
    # Recover it from version control if needed.
if __name__ == "__main__":
    # The rows shown by main() come from an offline preprocessing step that
    # was formerly kept here as commented-out `load_data_pre`. That step:
    # - split each paired record into two rows;
    # - set 'delta' to |score_lst_1 - score_lst_2| for demo1, or to
    #   calc_sentence_similarity(sent_aipd, output, response) for demo2;
    # - kept rows whose instruction contains 'child' (demo1) or 'parent'
    #   (demo2);
    # - sorted the result by 'delta'.
    main()