Pushpankar committed on
Commit
fad545a
·
1 Parent(s): 6c7dea5

Add support for finetuning more than 2 classes

Browse files
Files changed (1) hide show
  1. torchmoji/finetuning.py +21 -9
torchmoji/finetuning.py CHANGED
@@ -15,6 +15,7 @@ import numpy as np
15
  import torch
16
  import torch.nn as nn
17
  import torch.optim as optim
 
18
  from torch.autograd import Variable
19
  from torch.utils.data import Dataset, DataLoader
20
  from torch.utils.data.sampler import BatchSampler, SequentialSampler
@@ -356,15 +357,18 @@ def evaluate_using_acc(model, test_gen):
356
 
357
  # Validate on test_data
358
  model.eval()
359
- correct_count = 0.0
360
- total_y = sum(len(y) for _, y in test_gen)
361
  for i, data in enumerate(test_gen):
362
  x, y = data
363
  outs = model(x)
364
- pred = (outs >= 0).long()
365
- added_counts = (pred == y).double().sum()
366
- correct_count += added_counts
367
- return correct_count/total_y
 
 
 
 
368
 
369
 
370
  def chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op,
@@ -482,6 +486,14 @@ def train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs,
482
  if verbose >= 2:
483
  print("Loaded weights from {}".format(checkpoint_path))
484
 
 
 
 
 
 
 
 
 
485
  def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
486
  checkpoint_path, patience):
487
  """ Analog to Keras fit_generator function.
@@ -505,7 +517,7 @@ def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
505
  torch.save(model.state_dict(), checkpoint_path)
506
 
507
  model.eval()
508
- best_loss = np.mean([loss_op(model(Variable(xv)).squeeze(), Variable(yv.float()).squeeze()).data.cpu().numpy()[0] for xv, yv in val_gen])
509
  print("original val loss", best_loss)
510
 
511
  epoch_without_impr = 0
@@ -517,7 +529,7 @@ def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
517
  model.train()
518
  optim_op.zero_grad()
519
  output = model(X_train)
520
- loss = loss_op(output, y_train.float())
521
  loss.backward()
522
  clip_grad_norm(model.parameters(), 1)
523
  optim_op.step()
@@ -529,7 +541,7 @@ def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
529
  acc = evaluate_using_acc(model, val_gen)
530
  print("val acc", acc)
531
 
532
- val_loss = np.mean([loss_op(model(Variable(xv)).squeeze(), Variable(yv.float()).squeeze()).data.cpu().numpy()[0] for xv, yv in val_gen])
533
  print("val loss", val_loss)
534
  if best_loss is not None and val_loss >= best_loss:
535
  epoch_without_impr += 1
 
15
  import torch
16
  import torch.nn as nn
17
  import torch.optim as optim
18
+ from sklearn.metrics import accuracy_score
19
  from torch.autograd import Variable
20
  from torch.utils.data import Dataset, DataLoader
21
  from torch.utils.data.sampler import BatchSampler, SequentialSampler
 
357
 
358
  # Validate on test_data
359
  model.eval()
360
+ accs = []
 
361
  for i, data in enumerate(test_gen):
362
  x, y = data
363
  outs = model(x)
364
+ if model.nb_classes > 2:
365
+ pred = torch.max(outs, 1)[1]
366
+ acc = accuracy_score(y.squeeze().numpy(), pred.squeeze().numpy())
367
+ else:
368
+ pred = (outs >= 0).long()
369
+ acc = (pred == y).double().sum() / len(pred)
370
+ accs.append(acc)
371
+ return np.mean(accs)
372
 
373
 
374
  def chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op,
 
486
  if verbose >= 2:
487
  print("Loaded weights from {}".format(checkpoint_path))
488
 
489
+
490
def calc_loss(loss_op, pred, yv):
    """ Applies a loss function to predictions and targets, converting the
        targets to float unless the loss is a cross-entropy loss.

    # Arguments:
        loss_op: Loss module to call (e.g. nn.CrossEntropyLoss for the
            multi-class case, or a binary loss otherwise).
        pred: Model outputs; singleton dimensions are squeezed away.
        yv: Targets; singleton dimensions are squeezed away.

    # Returns:
        The value returned by loss_op(pred, targets).
    """
    target = yv.squeeze()
    # CrossEntropyLoss consumes integer class indices, so its targets are
    # passed through as-is; every other loss receives float targets.
    # isinstance (rather than an exact type comparison) keeps subclasses of
    # CrossEntropyLoss on the class-index path as well.
    if not isinstance(loss_op, nn.CrossEntropyLoss):
        target = target.float()
    return loss_op(pred.squeeze(), target)
495
+
496
+
497
  def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
498
  checkpoint_path, patience):
499
  """ Analog to Keras fit_generator function.
 
517
  torch.save(model.state_dict(), checkpoint_path)
518
 
519
  model.eval()
520
+ best_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy()[0] for xv, yv in val_gen])
521
  print("original val loss", best_loss)
522
 
523
  epoch_without_impr = 0
 
529
  model.train()
530
  optim_op.zero_grad()
531
  output = model(X_train)
532
+ loss = calc_loss(loss_op, output, y_train)
533
  loss.backward()
534
  clip_grad_norm(model.parameters(), 1)
535
  optim_op.step()
 
541
  acc = evaluate_using_acc(model, val_gen)
542
  print("val acc", acc)
543
 
544
+ val_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy()[0] for xv, yv in val_gen])
545
  print("val loss", val_loss)
546
  if best_loss is not None and val_loss >= best_loss:
547
  epoch_without_impr += 1