Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -9,9 +9,9 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
9 |
@st.cache
|
10 |
def load_models():
|
11 |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
12 |
-
bert_mlm_positive = BertForMaskedLM.from_pretrained('text_style_mlm_positive', return_dict=True).to(device).train(True)
|
13 |
-
bert_mlm_negative = BertForMaskedLM.from_pretrained('text_style_mlm_negative', return_dict=True).to(device).train(True)
|
14 |
-
bert_classifier = BertForSequenceClassification.from_pretrained('text_style_classifier', num_labels=2).to(device).train(True)
|
15 |
return tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier
|
16 |
|
17 |
|
|
|
9 |
@st.cache
|
10 |
def load_models():
|
11 |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
12 |
+
bert_mlm_positive = BertForMaskedLM.from_pretrained('any0019/text_style_mlm_positive', return_dict=True).to(device).train(True)
|
13 |
+
bert_mlm_negative = BertForMaskedLM.from_pretrained('any0019/text_style_mlm_negative', return_dict=True).to(device).train(True)
|
14 |
+
bert_classifier = BertForSequenceClassification.from_pretrained('any0019/text_style_classifier', num_labels=2).to(device).train(True)
|
15 |
return tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier
|
16 |
|
17 |
|