Commit
·
06e46d2
1
Parent(s):
026cb58
Update Sentiment_analysis_with_bert.py
Browse files- Sentiment_analysis_with_bert.py +40 -42
Sentiment_analysis_with_bert.py
CHANGED
@@ -303,48 +303,46 @@ def eval_model(model, data_loader, loss_fn, device, n_examples):
|
|
303 |
|
304 |
return correct_predictions.double() / n_examples, np.mean(losses)
|
305 |
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
# torch.save(model.state_dict(), 'best_model_state.bin')
|
347 |
-
# best_accuracy = val_acc
|
348 |
|
349 |
print(history['train_acc'])
|
350 |
|
|
|
303 |
|
304 |
return correct_predictions.double() / n_examples, np.mean(losses)
|
305 |
|
306 |
+
%%time
|
307 |
+
history = defaultdict(list)
|
308 |
+
best_accuracy = 0
|
309 |
+
|
310 |
+
for epoch in range(EPOCHS):
|
311 |
+
|
312 |
+
print(f'Epoch {epoch + 1}/{EPOCHS}')
|
313 |
+
print('-' * 10)
|
314 |
+
|
315 |
+
train_acc, train_loss = train_epoch(
|
316 |
+
model,
|
317 |
+
train_data_loader,
|
318 |
+
loss_fn,
|
319 |
+
optimizer,
|
320 |
+
device,
|
321 |
+
scheduler,
|
322 |
+
len(df_train)
|
323 |
+
)
|
324 |
+
|
325 |
+
print(f'Train loss {train_loss} accuracy {train_acc}')
|
326 |
+
|
327 |
+
val_acc, val_loss = eval_model(
|
328 |
+
model,
|
329 |
+
val_data_loader,
|
330 |
+
loss_fn,
|
331 |
+
device,
|
332 |
+
len(df_val)
|
333 |
+
)
|
334 |
+
|
335 |
+
print(f'Val loss {val_loss} accuracy {val_acc}')
|
336 |
+
print()
|
337 |
+
|
338 |
+
history['train_acc'].append(train_acc)
|
339 |
+
history['train_loss'].append(train_loss)
|
340 |
+
history['val_acc'].append(val_acc)
|
341 |
+
history['val_loss'].append(val_loss)
|
342 |
+
|
343 |
+
if val_acc > best_accuracy:
|
344 |
+
torch.save(model.state_dict(), 'best_model_state.bin')
|
345 |
+
best_accuracy = val_acc
|
|
|
|
|
346 |
|
347 |
print(history['train_acc'])
|
348 |
|