Pushpankar
committed on
Commit
·
fad545a
1
Parent(s):
6c7dea5
Add support for finetuning more than 2 classes
Browse files- torchmoji/finetuning.py +21 -9
torchmoji/finetuning.py
CHANGED
@@ -15,6 +15,7 @@ import numpy as np
|
|
15 |
import torch
|
16 |
import torch.nn as nn
|
17 |
import torch.optim as optim
|
|
|
18 |
from torch.autograd import Variable
|
19 |
from torch.utils.data import Dataset, DataLoader
|
20 |
from torch.utils.data.sampler import BatchSampler, SequentialSampler
|
@@ -356,15 +357,18 @@ def evaluate_using_acc(model, test_gen):
|
|
356 |
|
357 |
# Validate on test_data
|
358 |
model.eval()
|
359 |
-
|
360 |
-
total_y = sum(len(y) for _, y in test_gen)
|
361 |
for i, data in enumerate(test_gen):
|
362 |
x, y = data
|
363 |
outs = model(x)
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
|
|
|
|
|
|
|
|
368 |
|
369 |
|
370 |
def chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op,
|
@@ -482,6 +486,14 @@ def train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs,
|
|
482 |
if verbose >= 2:
|
483 |
print("Loaded weights from {}".format(checkpoint_path))
|
484 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
485 |
def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
|
486 |
checkpoint_path, patience):
|
487 |
""" Analog to Keras fit_generator function.
|
@@ -505,7 +517,7 @@ def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
|
|
505 |
torch.save(model.state_dict(), checkpoint_path)
|
506 |
|
507 |
model.eval()
|
508 |
-
best_loss = np.mean([loss_op
|
509 |
print("original val loss", best_loss)
|
510 |
|
511 |
epoch_without_impr = 0
|
@@ -517,7 +529,7 @@ def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
|
|
517 |
model.train()
|
518 |
optim_op.zero_grad()
|
519 |
output = model(X_train)
|
520 |
-
loss = loss_op
|
521 |
loss.backward()
|
522 |
clip_grad_norm(model.parameters(), 1)
|
523 |
optim_op.step()
|
@@ -529,7 +541,7 @@ def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
|
|
529 |
acc = evaluate_using_acc(model, val_gen)
|
530 |
print("val acc", acc)
|
531 |
|
532 |
-
val_loss = np.mean([loss_op
|
533 |
print("val loss", val_loss)
|
534 |
if best_loss is not None and val_loss >= best_loss:
|
535 |
epoch_without_impr += 1
|
|
|
15 |
import torch
|
16 |
import torch.nn as nn
|
17 |
import torch.optim as optim
|
18 |
+
from sklearn.metrics import accuracy_score
|
19 |
from torch.autograd import Variable
|
20 |
from torch.utils.data import Dataset, DataLoader
|
21 |
from torch.utils.data.sampler import BatchSampler, SequentialSampler
|
|
|
357 |
|
358 |
# Validate on test_data
|
359 |
model.eval()
|
360 |
+
accs = []
|
|
|
361 |
for i, data in enumerate(test_gen):
|
362 |
x, y = data
|
363 |
outs = model(x)
|
364 |
+
if model.nb_classes > 2:
|
365 |
+
pred = torch.max(outs, 1)[1]
|
366 |
+
acc = accuracy_score(y.squeeze().numpy(), pred.squeeze().numpy())
|
367 |
+
else:
|
368 |
+
pred = (outs >= 0).long()
|
369 |
+
acc = (pred == y).double().sum() / len(pred)
|
370 |
+
accs.append(acc)
|
371 |
+
return np.mean(accs)
|
372 |
|
373 |
|
374 |
def chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op,
|
|
|
486 |
if verbose >= 2:
|
487 |
print("Loaded weights from {}".format(checkpoint_path))
|
488 |
|
489 |
+
|
490 |
+
def calc_loss(loss_op, pred, yv):
    """Apply ``loss_op`` to the squeezed predictions and targets.

    Both ``pred`` and ``yv`` have all size-1 dimensions removed with
    ``squeeze()`` before the loss is computed.  For ``nn.CrossEntropyLoss``
    (or any subclass of it) the targets are passed through with their
    dtype unchanged; for every other loss the targets are cast to float.

    # Arguments:
        loss_op: Loss callable, invoked as ``loss_op(pred, target)``.
        pred: Model output for one batch.
        yv: Targets for the same batch.

    # Returns:
        Whatever ``loss_op`` returns for the batch.
    """
    targets = yv.squeeze()
    # isinstance instead of an exact type comparison: a subclass of
    # CrossEntropyLoss must also keep its targets un-cast, otherwise the
    # class indices would be converted to float.
    if not isinstance(loss_op, nn.CrossEntropyLoss):
        # NOTE(review): presumably the binary case (e.g. a BCE-style loss
        # that needs float targets) -- confirm against the callers.
        targets = targets.float()
    return loss_op(pred.squeeze(), targets)
|
495 |
+
|
496 |
+
|
497 |
def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
|
498 |
checkpoint_path, patience):
|
499 |
""" Analog to Keras fit_generator function.
|
|
|
517 |
torch.save(model.state_dict(), checkpoint_path)
|
518 |
|
519 |
model.eval()
|
520 |
+
best_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy()[0] for xv, yv in val_gen])
|
521 |
print("original val loss", best_loss)
|
522 |
|
523 |
epoch_without_impr = 0
|
|
|
529 |
model.train()
|
530 |
optim_op.zero_grad()
|
531 |
output = model(X_train)
|
532 |
+
loss = calc_loss(loss_op, output, y_train)
|
533 |
loss.backward()
|
534 |
clip_grad_norm(model.parameters(), 1)
|
535 |
optim_op.step()
|
|
|
541 |
acc = evaluate_using_acc(model, val_gen)
|
542 |
print("val acc", acc)
|
543 |
|
544 |
+
val_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy()[0] for xv, yv in val_gen])
|
545 |
print("val loss", val_loss)
|
546 |
if best_loss is not None and val_loss >= best_loss:
|
547 |
epoch_without_impr += 1
|