patent committed on
Commit
d941617
·
verified ·
1 Parent(s): 42d8bd4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +868 -0
app.py ADDED
@@ -0,0 +1,868 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import json
5
+
6
+ import time
7
+ import requests
8
+
9
+ import os
10
+ import glob
11
+ import re
12
+ import smart_open
13
+ import plotly.express as px
14
+ import random
15
+ import difflib
16
+ import pdb
17
+
18
+ from sentence_transformers import SentenceTransformer, models, util
19
+
20
# Feature switches. Both are only read by the evaluation UI that is currently
# commented out at the end of main().
enable_summary_button = True
dump_pos_data_for_reporting = True

# Only referenced by the commented-out "gs://" path in calc_details().
bucket_name = "paper_n1"

# File-name prefixes of the evaluated checkpoints; result files are named
# "<prefix>_<patent>_<claim>_forward.json". The original GPT-J-6B run
# ("my_gptj_6b_tpu_size_8") is intentionally not part of this list.
prefix_lst = [
    "pgj_d_4096",
    "pgj_d_2048",
    "pgj_d_1024_v2",
    "pgj_d_1024_layer_14",
    "pgj_d_1024_layer_7",
    "pgj_d_1024_layer_2",
    "pgj_d_1024_layer_1",
]

# Display name for each prefix, paired positionally with prefix_lst.
model_names = dict(zip(prefix_lst, (
    'PatentGPT-J-6B',
    'PatentGPT-J-1.6B',
    'PatentGPT-J-456M',
    'PatentGPT-J-279M',
    'PatentGPT-J-191M',
    'PatentGPT-J-128M',
    'PatentGPT-J-115M',
)))
50
+
51
+ # prefix_lst[7]:'GPT-J-6B'
52
+
53
+ # experiment 3
54
+ # folder = os.path.join('experiments', 'non_patent')
55
+ # id_to_scroll = 1 # which of the above to scroll through
56
+ # first_claim_only = True
57
+
58
+ #experiment 2
59
+ # folder = os.path.join('experiments', 'ipg20220104_500')
60
+ # #folder = "device_serve_results"
61
+ # id_to_scroll = 1 # which of the above to scroll through
62
+ # first_claim_only = False
63
+
64
+ # prefix_lst = ["my_gptj_6b_tpu_size_8", "pgj_d_4096", "pgj_d_2048", "pgj_d_1024_layer_14", "pgj_d_1024_layer_7", "pgj_d_1024_layer_2", "pgj_d_1024_layer_1"]
65
+ # #, "pgj_large", "pgj_medium", "pgj_small", ]
66
+ # # "pgj_d_1024_layer_14"
67
+
68
+ # experiment 1
69
# Directory holding the "<prefix>_<num>_forward.json" result files
# (this data set previously lived in "eval_ipg22_500").
folder = os.path.join('experiments', 'ipg22_500')

# Index into prefix_lst of the model whose tokens are scrolled through.
id_to_scroll = 1

# When True, only claim numbers ending in "_1" are selected.
first_claim_only = True

# When True, a token whose rank is outside the top 10 (top_seq >= 10) is
# charged its full text length as keystrokes instead of a pick cost.
ignore_outscope = True
75
+
76
def show_diff(a, b):
    """Print the element-level edits that turn sequence ``a`` into ``b``.

    Iterates over ``difflib.ndiff(a, b)`` and prints one line per deletion or
    insertion, together with that entry's index in the diff stream. Entries
    marked as unchanged are skipped; nothing is returned.
    """
    for position, entry in enumerate(difflib.ndiff(a, b)):
        marker = entry[0]
        if marker == ' ':
            continue
        if marker == '-':
            print(u'Delete "{}" from position {}'.format(entry[-1], position))
        elif marker == '+':
            print(u'Add "{}" to position {}'.format(entry[-1], position))
84
+
85
def handle_char_return(text):
    """Return ``text`` unchanged.

    NOTE(review): the previous body contained the bare expression
    ``text == ''`` for the '(none)' placeholder. That is a comparison, not an
    assignment, so the placeholder has always been returned as-is. show_avg()
    checks for '(none)' right after calling this helper, so the pass-through
    behaviour is kept deliberately; confirm before turning it into a real
    substitution.
    """
    return text
90
+
91
+ #return ch.replace('\n', '\\n')
92
+
93
+ #if ch == '\n':
94
+ # ch = "'\\n'"
95
+ #return ch
96
+
97
def get_remaining(lst, pos):
    """Concatenate token texts starting at index ``pos`` until a token begins with a space.

    Each element of ``lst`` is read through its 'actual_next_token_text' key.
    The token that starts with a space ends the scan and is not included, so
    the result is '' when the token at ``pos`` itself starts with a space or
    when ``pos`` is past the end of ``lst``.
    """
    pieces = []
    idx = pos
    while idx < len(lst):
        piece = lst[idx]['actual_next_token_text']
        if piece.startswith(' '):
            break
        pieces.append(piece)
        idx += 1
    return ''.join(pieces)
107
+
108
def calc_details(base_fn):
    """Load one "*_forward.json" result file and compute its keystroke statistics.

    Args:
        base_fn: file name, resolved relative to the module-level ``folder``.

    Returns:
        An 11-tuple
        ``(result, avg_pick, avg_prob, token_count, sum_pick, sum_prob,
        sum_outscope_count, sum_outscope_len, sum_hit_1, sum_top_10_len,
        full_text)`` where

        * ``result`` is the parsed JSON document,
        * ``sum_pick`` adds ``min(rank + 1, len(token_text))`` for every token
          that is not treated as out of scope,
        * ``sum_outscope_count`` / ``sum_outscope_len`` count the tokens (and
          their total text length) with rank >= 10 when ``ignore_outscope``
          is set,
        * ``sum_hit_1`` counts tokens with rank 0,
        * ``sum_prob`` / ``sum_top_10_len`` accumulate probability and text
          length of the in-scope tokens,
        * ``full_text`` is the concatenation of all scored token texts,
        * ``avg_pick`` / ``avg_prob`` are the sums divided by ``token_count``
          (0 when there are no tokens).

        When the file does not exist the tuple is
        ``(None, -1, -1, None, ..., None)``. It now always has 11 fields; the
        previous 8-field version raised ValueError in every caller, because
        they all unpack 11 names.
    """
    full_fn = os.path.join(folder, base_fn)

    if not os.path.exists(full_fn):
        # Keep the original leading values and pad to the shape callers unpack.
        return None, -1, -1, None, None, None, None, None, None, None, None

    with open(full_fn) as f:
        result = json.load(f)
    print("Loaded: %s" % full_fn)

    lst = result['output']
    sum_pick = 0
    sum_prob = 0
    sum_outscope_count = 0
    sum_outscope_len = 0
    sum_hit_1 = 0
    sum_top_10_len = 0
    token_count = 0
    text_parts = []

    # The last entry of the output list is not scored.
    for tk in lst[:-1]:
        token_text = handle_char_return(tk['actual_next_token_text'])
        next_top_seq = int(tk['actual_next_token_top_seq'])
        next_top_prob = float(tk['actual_next_token_top_prob'])

        text_parts.append(token_text)
        if next_top_seq == 0:
            sum_hit_1 += 1  # press "tab" for the top pick

        if ignore_outscope and next_top_seq >= 10:
            sum_outscope_count += 1
            sum_outscope_len += len(token_text)  # use length as keystrokes
        else:
            # Picking a suggestion never costs more than typing the token.
            sum_pick += min(next_top_seq + 1, len(token_text))
            sum_prob += next_top_prob
            sum_top_10_len += len(token_text)

        # NOTE(review): the indentation of this statement was lost in the
        # source this was recovered from; it is assumed to count every scored
        # token, in scope or not. Confirm against the original file.
        token_count += 1

    full_text = ''.join(text_parts)

    if token_count == 0:
        # Avoid ZeroDivisionError regardless of ignore_outscope.
        avg_pick = 0
        avg_prob = 0
    else:
        avg_pick = float(sum_pick) / token_count
        avg_prob = float(sum_prob) / token_count

    return (result, avg_pick, avg_prob, token_count, sum_pick, sum_prob,
            sum_outscope_count, sum_outscope_len, sum_hit_1, sum_top_10_len,
            full_text)
198
+
199
def show_avg(base_fn, model_name, patent_claim_num, show_pick=False):
    """Render one model's per-token autocomplete result and return rank-bucket counts.

    Writes to the Streamlit page, in this order: a text summary of keystrokes
    saved, a colour legend, and the claim text with every token coloured by
    the rank bucket of the actual next token.

    Args:
        base_fn: result file name passed to calc_details().
        model_name: label printed in the summary.
        patent_claim_num: unused; kept for interface compatibility.
        show_pick: when True, append "[rank]" after every rendered token.

    Returns:
        ``[top-1 count, top-2..10 count, out-of-top-10 count]`` over the
        rendered tokens, or None when the result file could not be loaded.

    Fixes over the previous version: the colour table and the counters were
    re-created inside the token loop (so the returned counts only reflected
    the last token, and both were undefined if the loop stopped on its first
    pass); a leftover ``pdb.set_trace()`` could block the app; and a missing
    file failed while unpacking instead of returning None.
    """
    details = calc_details(base_fn)
    # Check before unpacking so a missing file is handled however many fields
    # the "not found" tuple carries.
    if details[0] is None:
        return None

    (result, _avg_pick, _avg_prob, _token_count, sum_pick, _sum_prob,
     _sum_outscope_count, sum_outscope_len, sum_hit_1, _sum_top_10_len,
     full_text) = details

    # [background colour, foreground colour, bucket label]
    colors = [
        ['00ff00', '000000', '1'],
        ['008800', 'ffffff', '2-10'],
        ['ff0000', 'ffffff', 'out of top 10'],
    ]
    sum_all = {item[2]: 0 for item in colors}

    spans = []
    for tk in result['output']:
        token_text = handle_char_return(tk['actual_next_token_text'])

        # Stop at the end-of-claim marker or the '(none)' placeholder.
        if token_text in ('<|end_of_claim|>', '(none)'):
            break

        pick = int(tk['actual_next_token_top_seq'])

        if pick == 0:
            bg_color, fg_color, tag = colors[0]
        elif 1 <= pick < 10:
            bg_color, fg_color, tag = colors[1]
        else:  # pick >= 10
            bg_color, fg_color, tag = colors[2]
        sum_all[tag] += 1

        pick_label = '[%s]' % pick if show_pick else ''
        spans.append(
            "<span style=background-color:#%s;color:#%s;border-radius:5px;>%s%s</span> "
            % (bg_color, fg_color, token_text, pick_label))
    result = ''.join(spans)

    color_msg = ''.join(
        "<span style=background-color:#%s;color:#%s;border-radius:5px;>&nbsp;%s&nbsp;</span> "
        % (v[0], v[1], v[2])
        for v in colors)

    # sum_pick covers ranks 1~10; out-of-scope tokens are typed in full.
    keys_with_auto = (sum_pick + sum_outscope_len)
    keys_without_auto = len(full_text)
    if keys_without_auto:
        saved_ratio = float(keys_without_auto - keys_with_auto) / keys_without_auto * 100
    else:
        saved_ratio = 0.0  # empty claim text: nothing to save
    s = 'model: %s\n' \
        'Autocomplete Effectiveness: %.1f%% (keystrokes saved)\n' \
        'Total keystrokes: %s (with autocomplete), %s (without autocomplete)\n' \
        'Keystroke distribution: top 1~10: %s (top 1: %s), out of top 10: %s' % (
            model_name, saved_ratio, keys_with_auto, keys_without_auto,
            sum_pick, sum_hit_1, sum_outscope_len)
    st.text(s)

    st.markdown(color_msg, unsafe_allow_html=True)
    st.markdown(result, unsafe_allow_html=True)

    return [sum_all['1'], sum_all['2-10'], sum_all['out of top 10']]
316
+
317
def show_overall_summary(prefix_lst, select_lst):
    """Print aggregated keystroke statistics for every model prefix.

    For each prefix the per-file numbers from calc_details() are summed over
    all entries of ``select_lst`` and written to the page twice: as a readable
    block and as a LaTeX table row.

    Args:
        prefix_lst: model prefixes; each must be a key of ``model_names``.
        select_lst: patent/claim identifiers used to build the file names
            "<prefix>_<num>_forward.json".

    Result files that cannot be loaded are skipped instead of aborting the
    whole summary, and a prefix without any scored token prints nothing.
    """
    for prefix in prefix_lst:
        acc_token_count = 0
        acc_sum_pick = 0
        acc_sum_prob = 0
        acc_sum_outscope_len = 0
        acc_sum_hit_1 = 0
        acc_sum_top_10_len = 0
        acc_full_text_len = 0

        for num in select_lst:
            base_fn = '%s_%s_forward.json' % (prefix, num)
            details = calc_details(base_fn)
            if details[0] is None:
                # Missing file: nothing to accumulate.
                continue

            (_result, _avg_pick, _avg_prob, token_count, sum_pick, sum_prob,
             _sum_outscope_count, sum_outscope_len, sum_hit_1, sum_top_10_len,
             full_text) = details

            acc_token_count += token_count
            acc_sum_pick += sum_pick
            acc_sum_prob += sum_prob
            acc_sum_outscope_len += sum_outscope_len
            acc_sum_hit_1 += sum_hit_1
            acc_sum_top_10_len += sum_top_10_len
            acc_full_text_len += len(full_text)

        if acc_token_count <= 0:
            continue

        # acc_sum_pick --> top 1~10
        keys_with_auto = acc_sum_pick + acc_sum_outscope_len
        keys_without_auto = acc_full_text_len
        if keys_without_auto:
            saved_ratio = float(keys_without_auto - keys_with_auto) / keys_without_auto * 100
        else:
            saved_ratio = 0.0  # only empty token texts were seen

        st.text('[ %s ]\n' \
                'Autocomplete Effectiveness: %.1f%% (ratio of saving keystroke)\n' \
                '(sum) keys_with_auto: %s, top_10_keys: %s, out_of_scope: %s, sum_hit_1: %s\n' \
                'keys_without_auto: %s, top_10_len: %s, prob: %.2f' % (
                    model_names[prefix], saved_ratio,
                    '{:,}'.format(keys_with_auto),
                    '{:,}'.format(acc_sum_pick),
                    '{:,}'.format(acc_sum_outscope_len),
                    '{:,}'.format(acc_sum_hit_1),
                    '{:,}'.format(keys_without_auto),
                    '{:,}'.format(acc_sum_top_10_len),
                    acc_sum_prob,
                ))

        # Same numbers formatted as a LaTeX table row.
        st.text('%s & %.1f\\%% & %s & %s & %s & %s & %s \\\\' % (
            model_names[prefix], saved_ratio,
            '{:,}'.format(keys_with_auto),
            '{:,}'.format(acc_sum_pick),
            '{:,}'.format(acc_sum_outscope_len),
            '{:,}'.format(acc_sum_hit_1),
            '{:,}'.format(keys_without_auto)))
382
+
383
+ # st.text('* acc_token_count =%s --> (avg) hits: %.2f, keys: %.2f, prob: %.2f, outscope: %.2f' % (
384
+ # acc_token_count,
385
+ # float(acc_sum_hit_1)/acc_token_count,
386
+ # float(acc_sum_pick)/acc_token_count,
387
+ # float(acc_sum_prob)/acc_token_count,
388
+ # float(acc_sum_outscope_count)/acc_token_count))
389
+
390
def calc_height(s):
    """Return a widget height for ``s``: 30 plus three units per ten characters (truncated)."""
    base_height = 30
    scaled = len(s) / 10 * 3
    return base_height + int(scaled)
392
+
393
def remove_end_of_claim_text(gen_text):
    """Cut ``gen_text`` right after its first end marker.

    '<|end_of_claim|>' takes precedence; '<|endoftext|>' is only considered
    when the former is absent. The marker itself is kept. Text without either
    marker is returned unchanged.

    The previous ``find(...) > 0`` test ignored a marker located at index 0
    (``str.find`` returns -1, not 0, when nothing is found), so text starting
    with a marker was never truncated.
    """
    for tag in ('<|end_of_claim|>', '<|endoftext|>'):
        pos = gen_text.find(tag)
        if pos != -1:
            return gen_text[:pos + len(tag)]
    return gen_text
406
+
407
def dump_pos_data(prefix_lst, select_lst, target_model='456M', dump_file='dump4.txt'):
    """Write per-position rank-bucket counts for one model to a text file.

    For every token position the function counts, over all selected result
    files, how often the actual next token was ranked first, ranked 2..10, or
    ranked outside the top 10, and writes three lines per position
    ("<pos>, <bucket>, <count>", positions starting at 1).

    Args:
        prefix_lst: model prefixes; each must be a key of ``model_names``.
        select_lst: patent/claim identifiers used to build the file names.
        target_model: only prefixes whose display name, with the
            'PatentGPT-J-' part removed, equals this value are processed.
        dump_file: output path; overwritten on every call.

    The table of counters grows on demand; the previous fixed table of 2048
    rows raised IndexError for longer token sequences.
    """
    statistics = []  # statistics[pos] == [top-1, top-2~10, out-of-scope]
    max_len = -1
    for prefix in prefix_lst:
        model_name = model_names[prefix].replace('PatentGPT-J-', '')
        if model_name != target_model:
            continue

        for num in select_lst:
            base_fn = '%s_%s_forward.json' % (prefix, num)
            full_fn = os.path.join(folder, base_fn)
            if not os.path.exists(full_fn):
                continue

            with open(full_fn) as f:
                result = json.load(f)
            print("Loaded: %s" % full_fn)

            # The last entry of the output list is not scored.
            for j, tk in enumerate(result['output'][:-1]):
                max_len = max(j, max_len)
                while len(statistics) <= j:
                    statistics.append([0, 0, 0])

                next_top_seq = int(tk['actual_next_token_top_seq'])
                if next_top_seq == 0:
                    statistics[j][0] += 1
                elif 0 < next_top_seq < 10:
                    statistics[j][1] += 1
                else:
                    statistics[j][2] += 1

    with open(dump_file, 'w') as f:
        for i in range(max_len + 1):
            f.write('%s, top-1, %s\n' % (i + 1, statistics[i][0]))
            f.write('%s, top-2~10, %s\n' % (i + 1, statistics[i][1]))
            f.write('%s, out_of_scope, %s\n' % (i + 1, statistics[i][2]))
    print('saved: %s' % dump_file)
469
+
470
+
471
+ # dump_file = 'dump2.txt'
472
+ # with open(dump_file, 'w') as f:
473
+ # for line in results:
474
+ # f.write('%s\n' % ', '.join(line))
475
+ # print('saved: %s' % dump_file)
476
+
477
+
478
def calc_sentence_similarity(sent_model, sent1, sent2):
    """Return the cosine similarity of two sentences under ``sent_model``.

    Both inputs are encoded with ``sent_model.encode(..., convert_to_tensor=True)``
    and the result is element ``[0][0]`` of ``util.cos_sim`` applied to the
    two embeddings.
    """
    first_vec, second_vec = [
        sent_model.encode(sentence, convert_to_tensor=True)
        for sentence in (sent1, sent2)
    ]
    return util.cos_sim(first_vec, second_vec)[0][0]
487
+
488
# Sentence-embedding model, instantiated once when the module is imported.
# NOTE(review): in the visible code `sent_aipd` is only referenced from the
# commented-out load_data_pre() helper; confirm it is still needed before
# paying the load cost on every start.
sent_model = 'patent/st-aipd-nlp-g'
print('loading SentenceTransformer: {}'.format(sent_model))
sent_aipd = SentenceTransformer(sent_model)
491
+
492
def load_data(demo, fn='ppo_output/ppo_open_llama_3b_v2.run.12.delta.txt'):
    """Load the demo rows and keep those that belong to the requested demo.

    Args:
        demo: 'demo1' keeps rows whose 'instruction' mentions 'child',
            'demo2' keeps rows whose 'instruction' mentions 'parent'; any
            other value yields an empty list.
        fn: path of the JSON file holding a list of row dicts. Defaults to
            the file the app has always read.

    Returns:
        The matching rows, in file order.

    The keyword test uses ``in``; the previous ``find(...) > 0`` silently
    dropped rows whose instruction starts with the keyword, because
    ``str.find`` returns 0 for a match at the beginning.
    """
    keyword_by_demo = {'demo1': 'child', 'demo2': 'parent'}
    keyword = keyword_by_demo.get(demo)
    if keyword is None:
        return []

    with open(fn, 'r', encoding='utf-8') as f:
        rows = json.load(f)

    return [row for row in rows if keyword in row['instruction']]
505
+
506
+ container_style = """
507
+ <style>
508
+ .container1 {
509
+ border: 2px solid #3498db;
510
+ border-radius: 8px;
511
+ padding: 10px;
512
+ margin-bottom: 20px;
513
+ }
514
+ .container2 {
515
+ /* Add styles for Container 2 if needed */
516
+ }
517
+ </style>
518
+ """
519
+
520
def main():
    """Render the Streamlit page that compares actual and generated claims.

    The user picks a demo direction and a row number; the page then shows the
    row's instruction, its input claim, the actual and the generated claim
    side by side, and the two stored similarity scores.

    Changes over the previous version: the selected row is read directly
    instead of re-assigning every field for rows ``0..pos-1``; ``delta`` can no
    longer be referenced before assignment; an empty data set ends with a
    message, and a single-row data set skips the slider rather than building
    one whose minimum equals its maximum.
    """
    st.set_page_config(
        layout="wide",                 # "centered" or "wide"
        initial_sidebar_state="auto",  # "auto", "expanded" or "collapsed"
        page_title="Demo 1",
        page_icon=None,
    )

    opt_1 = 'parent --> child'
    opt_2 = 'child --> parent'
    options = [opt_1, opt_2]

    with st.container():
        col1, col2, col3 = st.columns([3, 5, 2])
        with col1:
            selected_option = st.selectbox('Select a demo:', options)
            if selected_option == opt_1:
                rows = load_data('demo1')
                msg = 'novelty = sim1-sim2'
                c1_tag, c2_tag, c3_tag = 'pc', 'cc1', 'cc2'
            elif selected_option == opt_2:
                rows = load_data('demo2')
                msg = 'similarity of<br>(pc1) and (pc2)'
                c1_tag, c2_tag, c3_tag = 'cc', 'pc1', 'pc2'
            else:
                st.text('Unknown option')
                return

            if not rows:
                st.text('No rows available for this demo')
                return

        with col2:
            if len(rows) > 1:
                # The label is required by Streamlit but hidden, as before.
                pos = st.slider("Row", 1, len(rows), label_visibility="collapsed")
            else:
                pos = 1
            row = rows[pos - 1]  # slider positions are 1-based

        with col3:
            st.markdown("<center><h7>%s<br>%s</h7></center>" % (msg, row['delta']),
                        unsafe_allow_html=True)

    patent_num = row['patent_num']
    instruction = row['instruction']
    input_text = row['input']
    output_text = row['output']
    response = row['response']
    score_lst_1 = row['score_lst_1']
    score_lst_2 = row['score_lst_2']

    st.write('(instruction) ', instruction)

    with st.container():
        with st.container(border=True):
            st.write('(%s) [ %s ]\n%s' % (c1_tag, patent_num, input_text))

        # Actual vs. generated claim, side by side.
        col1, col2 = st.columns(2)
        with col1:
            with st.container(border=True):
                st.write('(%s) (actual)' % c2_tag)
                st.write(output_text)
        with col2:
            with st.container(border=True):
                st.write('(%s) (generated)' % c3_tag)
                st.write(response)

        # Stored similarity scores for each of the two claims above.
        col1, col2 = st.columns(2)
        with col1:
            with st.container(border=True):
                st.write('(sim1) similarity between %s and %s+%s: %s' % (c1_tag, c1_tag, c2_tag, str(score_lst_1)))
        with col2:
            with st.container(border=True):
                st.write('(sim2) similarity between %s and %s+%s: %s' % (c1_tag, c1_tag, c3_tag, str(score_lst_2)))
639
+
640
+ #container1.markdown("</div>", unsafe_allow_html=True)
641
+
642
+ # st.write("In Container 1")
643
+ # table_name = st.radio("Please Select Table", list_of_tables)
644
+
645
+ # st.write('output:')
646
+ # st.write(output_text)
647
+ # st.write('response:')
648
+ # st.write(response)
649
+ #st.write('query:', query)
650
+ # st.write('score_lst_1:', score_lst_1)
651
+ # st.write('score_lst_2:', score_lst_2)
652
+ # st.write('rewards:', rewards)
653
+ # st.text('hello')
654
+
655
+ # dict_keys(['patent_num', 'claim_num1', 'claim_num2', 'instruction', 'input', 'output', 'query', 'response', 'score_lst_1', 'score_lst_2', 'rewards'])
656
+
657
+ # st.subheader("Inspecting PatentGPT-J Model Evaluation")
658
+
659
+
660
+
661
+ # num_set = set()
662
+ # fn_lst = glob.glob(os.path.join(folder, '*'))
663
+ # for i, fn in enumerate(fn_lst):
664
+ # for prefix in prefix_lst:
665
+ # v = re.search('(.*?)%s\_(\d+\_\d+)\_(.*?)' % prefix, fn)
666
+ # if v is None:
667
+ # v = re.search('(.*?)%s\_(\w+\_\d+)\_(.*?)' % prefix, fn)
668
+
669
+ # #pdb.set_trace()
670
+ # if v is None:
671
+ # #pdb.set_trace()
672
+ # continue
673
+
674
+ # v = v.group(2)
675
+ # if first_claim_only:
676
+ # if v.endswith('_1'):
677
+ # num_set.add(v)
678
+ # else:
679
+ # num_set.add(v)
680
+
681
+ # num_lst = list(num_set)
682
+ # num_lst.sort()
683
+
684
+ # select_lst = []
685
+ # for i, num in enumerate(num_lst):
686
+ # all_existed = True
687
+ # for prefix in prefix_lst:
688
+ # fn = os.path.join(folder, '%s_%s_forward.json' % (prefix, num))
689
+ # if os.path.exists(fn) == False:
690
+ # all_existed = False
691
+ # break
692
+ # if all_existed:
693
+ # select_lst.append(num)
694
+ # select_lst.sort()
695
+
696
+ # if len(select_lst) == 0:
697
+ # st.text('select_lst is empty')
698
+ # return
699
+
700
+ # if dump_pos_data_for_reporting:
701
+ # dump_pos_data(prefix_lst, select_lst)
702
+ # st.text('Dump data: done')
703
+ # return
704
+
705
+ # # debug
706
+ # #base_fn = 'my_gptj_6b_tpu_size_8_11212952_1_forward.json'
707
+ # #base_fn = 'pgj_small_text-1_1_forward.json'
708
+ # #_ = show_avg(base_fn)
709
+
710
+ # if enable_summary_button:
711
+ # if st.button('Show Summary'):
712
+ # st.text('len(select_lst) = %s' % len(select_lst))
713
+ # show_overall_summary(prefix_lst, select_lst)
714
+
715
+ # # if 'num' not in st.session_state:
716
+ # # num = random.choice(select_lst)
717
+ # # st.session_state['num'] = num
718
+
719
+ # # set_state('num', num)
720
+ # # def set_state(k, v):
721
+ # # if k not in st.session_state:
722
+ # # st.session_state[ k ] = v
723
+
724
+ # show_patent_lst = [ s.replace('_', ' (claim ') + ')' for s in select_lst]
725
+ # selected = st.selectbox("Choose a patent claim", show_patent_lst)
726
+ # num = selected.replace(')', '').replace(' (claim ', '_')
727
+ # if st.button('Random pick'):
728
+ # num = random.choice(select_lst)
729
+
730
+ # st.text('Selected: %s' % num)
731
+ # st.session_state['num'] = num
732
+
733
+ # avgs = []
734
+ # for prefix in prefix_lst:
735
+ # base_fn = '%s_%s_forward.json' % (prefix, num)
736
+ # one_avg = show_avg(base_fn, model_names[prefix], num)
737
+ # if one_avg is not None:
738
+ # avgs.append(one_avg)
739
+
740
+ # # debug
741
+ # #pdb.set_trace()
742
+ # #return
743
+ # #
744
+
745
+ # data_lst = []
746
+ # for i in range(len(avgs[0])):
747
+ # row = []
748
+ # for j, prefix in enumerate(prefix_lst):
749
+ # row.append(avgs[j][i])
750
+ # data_lst.append(row)
751
+
752
+ # df = pd.DataFrame(data_lst, index=['1','2-10','out of top 10'])
753
+ # #df = pd.DataFrame(data_lst, index=['1','2-10','11-100','101~'])
754
+
755
+ # # ], index=['(a) 1','(b) 2-10','(c) 11-100','(d) 101~'])
756
+ # # [avgs[0][0], avgs[1][0], avgs[2][0]],
757
+ # # [avgs[0][1], avgs[1][1], avgs[2][1]],
758
+ # # [avgs[0][2], avgs[1][2], avgs[2][2]],
759
+ # # [avgs[0][3], avgs[1][3], avgs[2][3]],
760
+
761
+ # #df = pd.DataFrame([[1,2],[3,1]], columns=['a', 'b'])
762
+ # #df = pd.DataFrame([
763
+ # # [sum1[0], sum1[1], sum1[2], sum1[3]],
764
+ # # [sum2[0], sum2[1], sum2[2], sum2[3]],
765
+ # # [sum3[0], sum3[1], sum3[2], sum3[3]],
766
+ # # ]) #, index=['(a) 1','(b) 2-10','(c) 11-100','(d) 101~'])
767
+ # #df = pd.DataFrame.from_dict(sum_all, orient='index')
768
+ # #st.line_chart(df)
769
+
770
+ # #data_canada = px.data.gapminder().query("country == 'Canada'")
771
+ # #fig = px.bar(data_canada, x='year', y='pop')
772
+
773
+ # if st.button('Show chart'):
774
+ # fig = px.bar(df, barmode='group')
775
+ # st.plotly_chart(fig, use_container_width=True)
776
+ # #fig.show()
777
+ # #st.area_chart(df)
778
+ # #st.bar_chart(df)
779
+
780
+ # #
781
+ # base_fn = '%s_%s_forward.json' % (prefix_lst[ id_to_scroll ], st.session_state['num'])
782
+ # result, avg_pick, avg_prob, _, _, _, _, _, _, _, _ = calc_details(base_fn)
783
+ # recv = result['recv']
784
+ # lst = result['output']
785
+ # input_tokens = result['input']
786
+
787
+ # # (Pdb) print(token_pos_lst[0].keys())
788
+ # #dict_keys(['idx', 'gen_text', 'actual_next_token_text', 'actual_next_token_top_seq', 'actual_next_token_top_prob', 'top_n_lst'])
789
+
790
+ # height = calc_height(recv['context'])
791
+ # st.text_area('context:', recv['context'], height=height)
792
+
793
+ # pos = st.slider("Token position", 0, len(lst))
794
+ # prompt = ''
795
+ # for i in range(pos+1):
796
+ # prompt += input_tokens[i]['text']
797
+ # height = calc_height(prompt)
798
+ # st.text_area('prompt:', prompt, height=height)
799
+
800
+ # ch = handle_char_return(lst[pos]['actual_next_token_text'])
801
+ # st.text('actual_next_token_text: %s --> pick seq: %s (prob: %.2f)' % (ch, int(lst[pos]['actual_next_token_top_seq'])+1,
802
+ # float(lst[pos]['actual_next_token_top_prob'])))
803
+
804
+ # st.text('top 10 tokens:')
805
+ # for i, v in enumerate(lst[pos]['top_n_lst']):
806
+ # ch = handle_char_return(v['top_n_text'])
807
+ # st.text('[ %s ][ %s ]( %.2f )' % (i+1, ch, float(v['top_n_prob'])))
808
+
809
+ # gen_text = lst[pos]['gen_text']
810
+ # gen_text = remove_end_of_claim_text(gen_text)
811
+
812
+ # st.text('gen_text: %s' % gen_text)
813
+ # #st.text("done. ok.")
814
+ # #st.text('result:\n%s' % result)
815
+
816
+ if __name__ == "__main__":
817
+ main()
818
+
819
+ #def load_data_pre(demo):
820
+ # fn = 'ppo_output/ppo_open_llama_3b_v2.run.12.keep.txt'
821
+ # with open(fn, 'r') as f:
822
+ # rows = json.load(f)
823
+
824
+ # new_rows = []
825
+ # for i, row in enumerate(rows):
826
+ # item1 = {}
827
+ # item2 = {}
828
+ # if demo == 'demo1':
829
+ # item1[ 'delta' ] = abs(row['score_lst_1'][0] - row['score_lst_2'][0])
830
+ # item2[ 'delta' ] = abs(row['score_lst_1'][1] - row['score_lst_2'][1])
831
+ # elif demo == 'demo2':
832
+ # #pdb.set_trace()
833
+ # item1[ 'delta' ] = calc_sentence_similarity(sent_aipd, row['output'][0], row['response'][0])
834
+ # item2[ 'delta' ] = calc_sentence_similarity(sent_aipd, row['output'][1], row['response'][1])
835
+
836
+ # print('[ %s ] detla = %s' % (i, item1[ 'delta' ]))
837
+
838
+ # for k in row.keys():
839
+ # item1[ k ] = row[ k ][0]
840
+ # item2[ k ] = row[ k ][1]
841
+
842
+ # if demo == 'demo1':
843
+ # if item1['instruction'].find('child') > 0:
844
+ # new_rows.append(item1)
845
+ # if item2['instruction'].find('child') > 0:
846
+ # new_rows.append(item2)
847
+ # elif demo == 'demo2':
848
+ # if item1['instruction'].find('parent') > 0:
849
+ # new_rows.append(item1)
850
+ # if item2['instruction'].find('parent') > 0:
851
+ # new_rows.append(item2)
852
+
853
+ # # Assuming new_rows is your list of dictionaries
854
+ # sorted_rows = sorted(new_rows, key=lambda x: x['delta'])
855
+
856
+ # # kv = {}
857
+ # # for i, row in enumerate(new_rows):
858
+ # # if diff > 0.0001:
859
+ # # kv[i] = round(diff, 4)
860
+
861
+ # # sorted_rows = []
862
+ # # sorted_kv = sorted(kv.items(), key=lambda x:x[1])
863
+ # # for k, v in sorted_kv:
864
+ # # sorted_rows.append(new_rows[k])
865
+
866
+ # #pdb.set_trace()
867
+
868
+ # return sorted_rows