alisrbdni commited on
Commit
6019563
·
verified ·
1 Parent(s): 92bae51

adding model_name hubert

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -308,12 +308,12 @@ def test(net, testloader):
308
  accuracy = metric.compute()["accuracy"]
309
  return loss, accuracy
310
 
311
- net = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2).to(DEVICE)
312
 
313
  def main():
314
  st.write("## Federated Learning with dynamic models and datasets for mobile devices")
315
  dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
316
- model_name = st.selectbox("Model", ["bert-base-uncased", "distilbert-base-uncased"])
317
  NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
318
  NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
319
 
 
308
  accuracy = metric.compute()["accuracy"]
309
  return loss, accuracy
310
 
311
+ net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
312
 
313
  def main():
314
  st.write("## Federated Learning with dynamic models and datasets for mobile devices")
315
  dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
316
+ model_name = st.selectbox("Model", ["facebook/hubert-base-ls960","bert-base-uncased", "distilbert-base-uncased"])
317
  NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
318
  NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
319