any0019 committed on
Commit
5fa5cad
·
1 Parent(s): b086847

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -6,7 +6,7 @@ from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassifi
6
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
7
 
8
 
9
- @st.cache
10
  def load_models():
11
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
12
  bert_mlm_positive = BertForMaskedLM.from_pretrained('any0019/text_style_mlm_positive', return_dict=True).to(device).train(True)
@@ -90,14 +90,16 @@ def get_best_hypotheses(sentence, beam_size, max_steps, epsilon=1e-3, pretty_out
90
  used_poss = []
91
 
92
  st.write(f"step #0:")
93
- st.write(f"-- 1: (positive probability ~ {round(current_beam[sentence], 5)})\n {sentence}")
 
94
 
95
  for step in range(max_steps):
96
  current_beam, used_pos = beam_get_replacements(current_beam, beam_size, epsilon, used_poss)
97
 
98
  st.write(f"\nstep #{step+1}:")
99
  for i, (sent, prob) in enumerate(current_beam.items()):
100
- st.write(f"-- {i+1}: (positive probability ~ {round(prob, 5)})\n {highlight_diff(sent, sentence) if pretty_output else sent}")
 
101
 
102
  if used_pos is None:
103
  return current_beam, used_poss
 
6
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
7
 
8
 
9
+ @st.cache(allow_output_mutation=True)
10
  def load_models():
11
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
12
  bert_mlm_positive = BertForMaskedLM.from_pretrained('any0019/text_style_mlm_positive', return_dict=True).to(device).train(True)
 
90
  used_poss = []
91
 
92
  st.write(f"step #0:")
93
+ st.write(f"-- 1: (positive probability ~ {round(current_beam[sentence], 5)})")
94
+ st.write(f" {sentence}")
95
 
96
  for step in range(max_steps):
97
  current_beam, used_pos = beam_get_replacements(current_beam, beam_size, epsilon, used_poss)
98
 
99
  st.write(f"\nstep #{step+1}:")
100
  for i, (sent, prob) in enumerate(current_beam.items()):
101
+ st.write(f"-- {i+1}: (positive probability ~ {round(prob, 5)})")
102
+ st.write(f" {highlight_diff(sent, sentence) if pretty_output else sent}")
103
 
104
  if used_pos is None:
105
  return current_beam, used_poss