MarieAngeA13 commited on
Commit
06e46d2
·
1 Parent(s): 026cb58

Update Sentiment_analysis_with_bert.py

Browse files
Files changed (1) hide show
  1. Sentiment_analysis_with_bert.py +40 -42
Sentiment_analysis_with_bert.py CHANGED
@@ -303,48 +303,46 @@ def eval_model(model, data_loader, loss_fn, device, n_examples):
303
 
304
  return correct_predictions.double() / n_examples, np.mean(losses)
305
 
306
- # Commented out IPython magic to ensure Python compatibility.
307
- # %%time
308
- #
309
- # history = defaultdict(list)
310
- # best_accuracy = 0
311
- #
312
- # for epoch in range(EPOCHS):
313
- #
314
- # print(f'Epoch {epoch + 1}/{EPOCHS}')
315
- # print('-' * 10)
316
- #
317
- # train_acc, train_loss = train_epoch(
318
- # model,
319
- # train_data_loader,
320
- # loss_fn,
321
- # optimizer,
322
- # device,
323
- # scheduler,
324
- # len(df_train)
325
- # )
326
- #
327
- # print(f'Train loss {train_loss} accuracy {train_acc}')
328
- #
329
- # val_acc, val_loss = eval_model(
330
- # model,
331
- # val_data_loader,
332
- # loss_fn,
333
- # device,
334
- # len(df_val)
335
- # )
336
- #
337
- # print(f'Val loss {val_loss} accuracy {val_acc}')
338
- # print()
339
- #
340
- # history['train_acc'].append(train_acc)
341
- # history['train_loss'].append(train_loss)
342
- # history['val_acc'].append(val_acc)
343
- # history['val_loss'].append(val_loss)
344
- #
345
- # if val_acc > best_accuracy:
346
- # torch.save(model.state_dict(), 'best_model_state.bin')
347
- # best_accuracy = val_acc
348
 
349
  print(history['train_acc'])
350
 
 
303
 
304
  return correct_predictions.double() / n_examples, np.mean(losses)
305
 
306
+ %%time
307
+ history = defaultdict(list)
308
+ best_accuracy = 0
309
+
310
+ for epoch in range(EPOCHS):
311
+
312
+ print(f'Epoch {epoch + 1}/{EPOCHS}')
313
+ print('-' * 10)
314
+
315
+ train_acc, train_loss = train_epoch(
316
+ model,
317
+ train_data_loader,
318
+ loss_fn,
319
+ optimizer,
320
+ device,
321
+ scheduler,
322
+ len(df_train)
323
+ )
324
+
325
+ print(f'Train loss {train_loss} accuracy {train_acc}')
326
+
327
+ val_acc, val_loss = eval_model(
328
+ model,
329
+ val_data_loader,
330
+ loss_fn,
331
+ device,
332
+ len(df_val)
333
+ )
334
+
335
+ print(f'Val loss {val_loss} accuracy {val_acc}')
336
+ print()
337
+
338
+ history['train_acc'].append(train_acc)
339
+ history['train_loss'].append(train_loss)
340
+ history['val_acc'].append(val_acc)
341
+ history['val_loss'].append(val_loss)
342
+
343
+ if val_acc > best_accuracy:
344
+ torch.save(model.state_dict(), 'best_model_state.bin')
345
+ best_accuracy = val_acc
 
 
346
 
347
  print(history['train_acc'])
348