lorenpe2 committed on
Commit
349d1a2
·
1 Parent(s): 822e1b3

FEAT: Improved diagnostic mode with better output matrix

Browse files
Files changed (1) hide show
  1. app.py +38 -18
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import os
 
 
2
  import glob
3
  import json
4
  from typing import Dict, List, Tuple, Union
@@ -8,10 +10,9 @@ import pandas
8
  import streamlit as st
9
  import matplotlib.pyplot as plt
10
 
11
-
12
  from inference_tokenizer import NextSentencePredictionTokenizer
13
  from models import get_class
14
- from models import OwnBertForNextSentencePrediction
15
 
16
  def get_model(_model_path):
17
  print(f"Getting model at {_model_path}")
@@ -57,7 +58,6 @@ for model_path in models_path:
57
  model_data["path"] = model_path.replace("info.json", "")
58
  models[model_data["model"]] = model_data
59
 
60
-
61
  model_name = st.selectbox('Which model do you want to use?',
62
  (x for x in sorted(models.keys())),
63
  index=0)
@@ -78,15 +78,28 @@ def get_evaluation_data_from_json(_context: List) -> List[Tuple[List, str, str]]
78
  return output_data
79
 
80
 
81
- def get_evaluation_data_from_dialogue(_context: List) -> List[Tuple[List, str, Union[str, None]]]:
 
 
 
 
 
 
 
 
 
 
 
 
82
  output_data = []
 
 
83
  for idx, _line in enumerate(_context):
84
- if idx == 0:
85
- continue
86
  actual_context = _context[max(0, idx - 5):idx]
87
- actual_sentence = _line
88
  for context_idx in range(len(actual_context)):
89
- output_data.append((actual_context[-context_idx:], actual_sentence, None))
 
90
  return output_data
91
 
92
 
@@ -97,7 +110,6 @@ option = st.selectbox("Choose type of input:",
97
  "04 - JSON (example Elysai)",
98
  "05 - Diagnostic mode"])
99
 
100
-
101
  with st.form("input_text"):
102
  if "01" in option:
103
  context = st.text_area("Insert context here (one turn per line):")
@@ -119,6 +131,7 @@ with st.form("input_text"):
119
  st.pyplot(fig)
120
  elif "02" in option or "03" in option or "04" in option:
121
  from data.example_data import ca_ood, elysai
 
122
  choices = [ca_ood, elysai]
123
  option: str
124
  # > Python 3.10
@@ -163,20 +176,27 @@ with st.form("input_text"):
163
  "Probability (not-follow)"])
164
  st.dataframe(df)
165
  elif "05" in option:
 
166
  context = st.text_area("Insert dialogue here (one turn per line):")
167
  submitted = st.form_submit_button("Submit")
168
  if submitted:
169
- aggregated_result = []
170
- data_for_evaluation = get_evaluation_data_from_dialogue(context.split("\n"))
 
171
  for datapoint in data_for_evaluation:
172
- c, s, _ = datapoint
173
- input_tensor = inference_tokenizer.get_item(context=c, actual_sentence=s)
174
- output_model = model(**input_tensor.data).logits
175
- output_model = torch.softmax(output_model, dim=-1).detach().numpy()[0]
176
- prop_follow = output_model[0]
177
- prop_not_follow = output_model[1]
 
 
 
178
 
179
- aggregated_result.append((c, s, prop_follow))
 
 
180
  st.table(aggregated_result)
181
 
182
  st.markdown("## Description of models:")
 
1
  import os
2
+ import re
3
+ import numpy as np
4
  import glob
5
  import json
6
  from typing import Dict, List, Tuple, Union
 
10
  import streamlit as st
11
  import matplotlib.pyplot as plt
12
 
 
13
  from inference_tokenizer import NextSentencePredictionTokenizer
14
  from models import get_class
15
+
16
 
17
  def get_model(_model_path):
18
  print(f"Getting model at {_model_path}")
 
58
  model_data["path"] = model_path.replace("info.json", "")
59
  models[model_data["model"]] = model_data
60
 
 
61
  model_name = st.selectbox('Which model do you want to use?',
62
  (x for x in sorted(models.keys())),
63
  index=0)
 
78
  return output_data
79
 
80
 
81
+ control_sequence_regex_1 = re.compile(r"#.*? ")
82
+ control_sequence_regex_2 = re.compile(r"#.*?\n")
83
+
84
+
85
+ def _clean_conversational_line(_line: str):
86
+ _line = _line.replace("Bot: ", "")
87
+ _line = _line.replace("User: ", "")
88
+ _line = control_sequence_regex_1.sub("", _line)
89
+ _line = control_sequence_regex_2.sub("\n", _line)
90
+ return _line.strip()
91
+
92
+
93
+ def get_evaluation_data_from_dialogue(_context: List[str]) -> List[Dict]:
94
  output_data = []
95
+ _context = list(map(lambda x: x.strip(), _context))
96
+ _context = list(filter(lambda x: len(x), _context))
97
  for idx, _line in enumerate(_context):
 
 
98
  actual_context = _context[max(0, idx - 5):idx]
99
+ gradual_context_dict = {_line: []}
100
  for context_idx in range(len(actual_context)):
101
+ gradual_context_dict[_line].append(actual_context[-context_idx:])
102
+ output_data.append(gradual_context_dict)
103
  return output_data
104
 
105
 
 
110
  "04 - JSON (example Elysai)",
111
  "05 - Diagnostic mode"])
112
 
 
113
  with st.form("input_text"):
114
  if "01" in option:
115
  context = st.text_area("Insert context here (one turn per line):")
 
131
  st.pyplot(fig)
132
  elif "02" in option or "03" in option or "04" in option:
133
  from data.example_data import ca_ood, elysai
134
+
135
  choices = [ca_ood, elysai]
136
  option: str
137
  # > Python 3.10
 
176
  "Probability (not-follow)"])
177
  st.dataframe(df)
178
  elif "05" in option:
179
+ context_size = 5
180
  context = st.text_area("Insert dialogue here (one turn per line):")
181
  submitted = st.form_submit_button("Submit")
182
  if submitted:
183
+ data_for_evaluation = get_evaluation_data_from_dialogue(_clean_conversational_line(context).split("\n"))
184
+ lines = []
185
+ scores = np.zeros(shape=(len(data_for_evaluation), context_size))
186
  for datapoint in data_for_evaluation:
187
+ for actual_sentence, contexts in datapoint.items():
188
+ lines.append(actual_sentence)
189
+ for c in contexts:
190
+ input_tensor = inference_tokenizer.get_item(context=c, actual_sentence=actual_sentence)
191
+ output_model = model(**input_tensor.data).logits
192
+ output_model = torch.softmax(output_model, dim=-1).detach().numpy()[0]
193
+ prop_follow = output_model[0]
194
+ prop_not_follow = output_model[1]
195
+ scores[len(lines) - 1][len(c) - 1] = prop_follow
196
 
197
+ aggregated_result = []
198
+ for idx, line in enumerate(lines):
199
+ aggregated_result.append([line] + scores[idx].tolist())
200
  st.table(aggregated_result)
201
 
202
  st.markdown("## Description of models:")