import streamlit as st
import time
import requests

import os
import json
import glob
import re
import random
import difflib

from random import randrange

# Result-file name prefixes, one per model variant, listed from the largest
# model down to the smallest (see the sizes in model_names below).
prefix_lst = [
  "pgj_d_4096",
  "pgj_d_2048",
  "pgj_d_1024_v2",
  "pgj_d_1024_layer_14",
  "pgj_d_1024_layer_7",
  "pgj_d_1024_layer_2",
  "pgj_d_1024_layer_1",
]

# Human-readable model name for each prefix; the names are paired with
# prefix_lst by position, so both sequences must stay in the same order.
model_names = dict(zip(prefix_lst, (
  'PatentGPT-J-6B',
  'PatentGPT-J-1.6B',
  'PatentGPT-J-456M',
  'PatentGPT-J-279M',
  'PatentGPT-J-191M',
  'PatentGPT-J-128M',
  'PatentGPT-J-115M',
)))

# ---- Experiment selection ----
# Only one experiment block below is active at a time; the others are kept
# commented out so they can be switched back on by hand.

# Experiment 3 (disabled)
# folder = os.path.join('experiments', 'non_patent')
# id_to_scroll = 1  # which of the above to scroll through
# first_claim_only = True 

# Experiment 2 (active)
# `folder` is the directory that is globbed for result files and from which
# the '<prefix>_<id>_forward.json' files are loaded.
folder = os.path.join('experiments', 'ipg20220104_500')
#folder = "device_serve_results"
id_to_scroll = 1  # NOTE(review): not read by any function visible in this file
first_claim_only = False  # when True, only ids ending in '_1' are offered

# Earlier prefix list (disabled), kept for reference:
# prefix_lst = ["my_gptj_6b_tpu_size_8", "pgj_d_4096", "pgj_d_2048", "pgj_d_1024_layer_14", "pgj_d_1024_layer_7", "pgj_d_1024_layer_2", "pgj_d_1024_layer_1"]
# #, "pgj_large", "pgj_medium", "pgj_small", ]
# # "pgj_d_1024_layer_14"

# Experiment 1 (disabled)
# folder = os.path.join('experiments', 'ipg22_500')
# # (previous) folder = "eval_ipg22_500"
# id_to_scroll = 1  # which of the above to scroll through
# first_claim_only = True 

# Module-level list of selectable ids; it starts out empty and nothing at
# import time fills it.
select_lst = []

def handle_char_return(text):
  """Return the token text to display, blanking the '(none)' placeholder.

  Args:
    text: token text taken from a result record.

  Returns:
    '' when ``text`` is exactly '(none)', otherwise ``text`` unchanged.
  """
  # The original used `text == ''` here, a comparison whose result was
  # discarded, so the placeholder was never actually replaced.
  if text == '(none)':
    return ''

  return text

def calc_height(s):
  """Return a text-area height for ``s``: 30 plus 3 for every 10 characters.

  The variable part is truncated to an integer before the base of 30 is added.
  """
  base_height = 30
  variable_part = len(s) / 10 * 3
  return int(variable_part) + base_height

def remove_end_of_claim_text(gen_text):
  """Cut generated text right after its first terminator tag.

  '<|end_of_claim|>' takes priority: if it occurs anywhere, the text is cut
  after its first occurrence. Otherwise the text is cut after the first
  '<|endoftext|>'. The tag itself is kept in the result.

  Args:
    gen_text: generated text, possibly containing terminator tags.

  Returns:
    The text up to and including the chosen tag, or ``gen_text`` unchanged
    when neither tag is present.
  """
  for tag in ('<|end_of_claim|>', '<|endoftext|>'):
    pos = gen_text.find(tag)
    # str.find returns -1 when the tag is absent, so index 0 is a valid hit;
    # the previous `pos > 0` test skipped a tag at the very start of the text.
    if pos >= 0:
      return gen_text[:pos + len(tag)]

  return gen_text

def update_content():
  """on_change callback for the token-position slider; deliberately a no-op."""
  # Debug aid, left disabled:
  #st.write("The value of the slider is:", st.session_state.myslider)
  return None

def prepare_select_lst():
  """Collect the ids that have a result file for every model prefix.

  Every path directly under ``folder`` is matched against each prefix in
  ``prefix_lst`` to extract the id that follows the prefix. When
  ``first_claim_only`` is set, only ids ending in '_1' are kept. An id is
  returned only if '<prefix>_<id>_forward.json' exists in ``folder`` for
  all prefixes.

  Returns:
    A sorted list of id strings (possibly empty).
  """
  # Raw strings keep the regex escapes intact; the previous non-raw literals
  # relied on invalid string escapes such as '\_' and '\d'. The digits-only
  # form is tried first, then the looser form whose first part may be any
  # word characters.
  id_patterns = (
    r'(.*?)%s_(\d+_\d+)_(.*?)',
    r'(.*?)%s_(\w+_\d+)_(.*?)',
  )

  num_set = set()
  for fn in glob.glob(os.path.join(folder, '*')):
    for prefix in prefix_lst:
      # Escape the prefix so it is always matched literally.
      escaped = re.escape(prefix)
      match = None
      for pattern in id_patterns:
        match = re.search(pattern % escaped, fn)
        if match is not None:
          break
      if match is None:
        continue

      num = match.group(2)
      if first_claim_only and not num.endswith('_1'):
        continue
      num_set.add(num)

  # Keep only ids for which every model has produced a forward-pass file.
  complete = []
  for num in sorted(num_set):
    expected = (
      os.path.join(folder, '%s_%s_forward.json' % (prefix, num))
      for prefix in prefix_lst)
    if all(os.path.exists(path) for path in expected):
      complete.append(num)

  return complete
  
def update_selected():
  """on_change callback for the claim select box: load the chosen claim.

  Reads the current choice from ``st.session_state.myselectbox`` and hands it
  to pick_and_load().
  """
  # main() stores the real list of ids in the session state; the module-level
  # select_lst is never filled, so it is only used as a fallback.
  if 'select_lst' in st.session_state:
    candidates = st.session_state['select_lst']
  else:
    candidates = select_lst

  selected = st.session_state.myselectbox
  pick_and_load(candidates, selected)

def pick_and_load(select_lst, selected=None):
  """Choose a claim, load its result file and store it in the session state.

  Args:
    select_lst: list of available ids (e.g. as built by prepare_select_lst).
    selected: either an id or its display form '<a> (claim <b>)'; when None
      a random entry of ``select_lst`` is chosen.

  Returns:
    A tuple ``(pick, num, result)``: the index of the claim in
    ``select_lst``, its id, and the parsed JSON content of its result file.

  Side effects:
    Sets ``st.session_state['picked_flag']``, ``['num']`` and ``['result']``.

  Raises:
    ValueError: if ``selected`` is None and ``select_lst`` is empty.
    OSError: if the result file cannot be opened.
  """
  if selected is None:
    pick = random.randrange(len(select_lst))
    selected = select_lst[pick]

  # Convert the display form '<a> (claim <b>)' back to the id '<a>_<b>';
  # a plain id passes through unchanged.
  num = selected.replace(')', '').replace(' (claim ', '_')

  # Previously `pick` was left unbound when `selected` was supplied, so the
  # return statement raised UnboundLocalError. Derive it from the id instead,
  # falling back to the last stored index when the id is not in the list.
  if num in select_lst:
    pick = select_lst.index(num)
  elif 'picked_flag' in st.session_state:
    pick = st.session_state['picked_flag']
  else:
    pick = 0

  st.session_state['picked_flag'] = pick
  st.session_state['num'] = num

  prefix = "pgj_d_1024_v2" # size: 456M
  base_fn = '%s_%s_forward.json' % (prefix, num)
  full_fn = os.path.join(folder, base_fn)
  # Explicit encoding so loading does not depend on the platform default.
  with open(full_fn, encoding='utf-8') as f:
    result = json.load(f)
  print("Loaded: %s" % full_fn)
  st.session_state['result'] = result

  return pick, num, result

def main():
  """Render the demo page.

  Flow: configure the page, build (or reuse from the session state) the list
  of selectable claims, load a claim (random on first run or on button
  press), then show its context, the prompt up to the chosen token position,
  the actual next token with the top candidate tokens, and the generated
  text at that position.
  """
  st.set_page_config(  # Alternate names: setup_page, page, layout
    layout="wide",  # Can be "centered" or "wide". In the future also "dashboard", etc.
    initial_sidebar_state="auto",  # Can be "auto", "expanded", "collapsed"
    page_title="Patent-GPT-J demo",  # String or None. Strings get appended with "• Streamlit".
    page_icon=None,  # String, anything supported by st.image, or None.
  )
  st.subheader("PatentGPT-J Demo 2 (Autocomplete Effectiveness)")
  st.text("Data coverage: ipg220104 (in 2022-01-04)")

  # Scan the result folder only once per session.
  if 'select_lst' not in st.session_state:
    st.session_state['select_lst'] = prepare_select_lst()
  select_lst = st.session_state['select_lst']

  if not select_lst:
    st.text('select_lst is empty')
    return

  show_patent_lst = [s.replace('_', ' (claim ') + ')' for s in select_lst]

  # Reuse the claim already loaded in this session, if any.
  if 'picked_flag' not in st.session_state:
    pick, _, result = pick_and_load(select_lst)
  else:
    pick = st.session_state['picked_flag']
    result = st.session_state['result']

  if st.button('Random pick'):
    pick, _, result = pick_and_load(select_lst)

  # Changing the selection loads the matching file via update_selected().
  st.selectbox("Choose a patent claim", show_patent_lst, index=pick, key='myselectbox', on_change=update_selected)

  recv = result['recv']
  lst = result['output']
  input_tokens = result['input']

  height = calc_height(recv['context'])
  st.text_area('context:', recv['context'], height=height)

  if not lst:
    # Nothing to index into below; stop instead of raising IndexError.
    st.text('output is empty')
    return

  # The slider range is inclusive and its value is used as an index into
  # `lst`, so the upper bound must be len(lst) - 1 (it used to be len(lst),
  # which raised IndexError at the last position).
  last_pos = len(lst) - 1
  if last_pos > 0:
    pos = st.slider("Token position", 0, last_pos, key="myslider", on_change=update_content)
    pos = max(0, min(pos, last_pos))  # defensive clamp
  else:
    # Only one position exists; skip the degenerate single-value slider.
    pos = 0

  # Prompt = text of all input tokens up to and including `pos`.
  prompt = ''.join(token['text'] for token in input_tokens[:pos + 1])
  height = calc_height(prompt)
  st.text_area('prompt:', prompt, height=height)

  entry = lst[pos]
  ch = handle_char_return(entry['actual_next_token_text'])
  st.text('actual_next_token_text: %s --> pick seq: %s (prob: %.2f)    top 10 tokens:' % (ch, int(entry['actual_next_token_top_seq'])+1, 
    float(entry['actual_next_token_top_prob'])))

  # Show the candidates on two lines: the first five, then the remainder.
  msg = ''
  for i, v in enumerate(entry['top_n_lst']):
    ch = handle_char_return(v['top_n_text'])
    msg += '(%s)[%s](%.2f)   ' % (i+1, ch, float(v['top_n_prob']))
    if i == 4:
      st.text(msg)
      msg = ''
  st.text(msg)

  gen_text = remove_end_of_claim_text(entry['gen_text'])

  height = calc_height(gen_text)
  st.text_area('generated:', gen_text, height=height)

# Build the page only when this file is run as a script, not when imported.
if __name__ == "__main__":
  main()