any0019 commited on
Commit
93f4a89
·
1 Parent(s): fb2797d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -9,9 +9,9 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
9
  @st.cache
10
  def load_models():
11
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
12
- bert_mlm_positive = BertForMaskedLM.from_pretrained('text_style_mlm_positive', return_dict=True).to(device).train(True)
13
- bert_mlm_negative = BertForMaskedLM.from_pretrained('text_style_mlm_negative', return_dict=True).to(device).train(True)
14
- bert_classifier = BertForSequenceClassification.from_pretrained('text_style_classifier', num_labels=2).to(device).train(True)
15
  return tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier
16
 
17
 
 
9
  @st.cache
10
  def load_models():
11
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
12
+ bert_mlm_positive = BertForMaskedLM.from_pretrained('any0019/text_style_mlm_positive', return_dict=True).to(device).train(True)
13
+ bert_mlm_negative = BertForMaskedLM.from_pretrained('any0019/text_style_mlm_negative', return_dict=True).to(device).train(True)
14
+ bert_classifier = BertForSequenceClassification.from_pretrained('any0019/text_style_classifier', num_labels=2).to(device).train(True)
15
  return tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier
16
 
17