BecomeAllan committed on
Commit
8bf76cf
·
1 Parent(s): 9c0c4aa

update funs

Files changed (3)
  1. .vscode/settings.json +7 -0
  2. ML_SLRC.py +382 -44
  3. Util_funs.py +305 -418
.vscode/settings.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "workbench.colorCustomizations": {
3
+ "activityBar.background": "#093518",
4
+ "titleBar.activeBackground": "#0D4A21",
5
+ "titleBar.activeForeground": "#F3FDF6"
6
+ }
7
+ }
ML_SLRC.py CHANGED
@@ -1,33 +1,18 @@
1
- import torch.nn.functional as F
2
- import torch.nn as nn
3
- import math
4
  import torch
5
  import numpy as np
6
- import pandas as pd
7
- import time
8
- import transformers
9
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
10
- from sklearn.manifold import TSNE
11
- from copy import deepcopy, copy
12
- import seaborn as sns
13
- import matplotlib.pylab as plt
14
- from pprint import pprint
15
- import shutil
16
- import datetime
17
  import re
18
- import json
19
- from pathlib import Path
20
-
21
- import torch
22
- import torch.nn as nn
23
- from torch.utils.data import Dataset, DataLoader
24
  import unicodedata
25
- import re
26
-
 
 
 
27
  import torch
28
- import torch.nn as nn
29
- from torch.utils.data import Dataset, DataLoader
30
-
31
 
32
 
33
  # Pre-trained model
@@ -117,7 +102,6 @@ class SLR_Classifier(nn.Module):
117
 
118
  return [loss, [feature, logit], predict]
119
 
120
-
121
  # Undesirable patterns within texts
122
  patterns = {
123
  'CONCLUSIONS AND IMPLICATIONS':'',
@@ -157,27 +141,50 @@ patterns = {
157
  '</p>':'',
158
  '<<ETX>>':'',
159
 '+/-':'',
160
  }
161
 
162
  patterns = {x.lower():y for x,y in patterns.items()}
163
 
164
- LABEL_MAP = {'negative': 0, 'positive': 1}
165
 
166
  class SLR_DataSet(Dataset):
167
- def __init__(self, **args):
168
  self.tokenizer = args.get('tokenizer')
169
  self.data = args.get('data')
170
  self.max_seq_length = args.get("max_seq_length", 512)
171
  self.INPUT_NAME = args.get("input", 'x')
172
  self.LABEL_NAME = args.get("output", 'y')
 
173
 
174
  # Tokenizing and processing text
175
  def encode_text(self, example):
176
  comment_text = example[self.INPUT_NAME]
177
- comment_text = self.treat_text(comment_text)
 
178
 
179
  try:
180
- labels = LABEL_MAP[example[self.LABEL_NAME]]
181
  except:
182
  labels = -1
183
 
@@ -200,15 +207,6 @@ class SLR_DataSet(Dataset):
200
  torch.tensor([torch.tensor(labels).to(int)])
201
  ))
202
 
203
- # Text processing function
204
- def treat_text(self, text):
205
- text = unicodedata.normalize("NFKD",str(text))
206
- text = multiple_replace(patterns,text.lower())
207
- text = re.sub('(\(.+\))|(\[.+\])|( \d )|(<)|(>)|(- )','', text)
208
- text = re.sub('( +)',' ', text)
209
- text = re.sub('(, ,)|(,,)',',', text)
210
- text = re.sub('(%)|(per cent)',' percent', text)
211
- return text
212
 
213
  def __len__(self):
214
  return len(self.data)
@@ -221,6 +219,350 @@ class SLR_DataSet(Dataset):
221
  return temp_data
222
 
223
224
 
225
  # Regex multiple replace function
226
  def multiple_replace(dict, text):
@@ -229,8 +571,4 @@ def multiple_replace(dict, text):
229
  regex = re.compile("(%s)" % "|".join(map(re.escape, dict.keys())))
230
 
231
  # Substitution
232
- return regex.sub(lambda mo: dict[mo.string[mo.start():mo.end()]], text)
233
-
234
- # Undesirable patterns within texts
235
-
236
-
 
1
+ from torch import nn
 
 
2
  import torch
3
  import numpy as np
4
+ from copy import deepcopy
 
 
 
 
 
 
 
 
 
 
5
  import re
 
 
 
 
 
 
6
  import unicodedata
7
+ from torch.utils.data import Dataset, DataLoader,TensorDataset, RandomSampler
8
+ from sklearn.model_selection import train_test_split
9
+ from torch.optim import Adam
10
+ from copy import deepcopy
11
+ import gc
12
  import torch
13
+ import numpy as np
14
+ from torchmetrics import functional as fn
15
+ import random
16
 
17
 
18
  # Pre-trained model
 
102
 
103
  return [loss, [feature, logit], predict]
104
 
 
105
  # Undesirable patterns within texts
106
  patterns = {
107
  'CONCLUSIONS AND IMPLICATIONS':'',
 
141
  '</p>':'',
142
  '<<ETX>>':'',
143
  '+/-':'',
144
+ '\(.+\)':'',
145
+ '\[.+\]':'',
146
+ ' \d ':'',
147
+ '<':'',
148
+ '>':'',
149
+ '- ':'',
150
+ ' +':' ',
151
+ ', ,':',',
152
+ ',,':',',
153
+ '%':' percent',
154
+ 'per cent':' percent'
155
  }
156
 
157
  patterns = {x.lower():y for x,y in patterns.items()}
158
 
159
+
160
+ LABEL_MAP = {'negative': 0,
161
+ 'not included':0,
162
+ '0':0,
163
+ 0:0,
164
+ 'excluded':0,
165
+ 'positive': 1,
166
+ 'included':1,
167
+ '1':1,
168
+ 1:1,
169
+ }
170
 
171
  class SLR_DataSet(Dataset):
172
+ def __init__(self,treat_text =None, **args):
173
  self.tokenizer = args.get('tokenizer')
174
  self.data = args.get('data')
175
  self.max_seq_length = args.get("max_seq_length", 512)
176
  self.INPUT_NAME = args.get("input", 'x')
177
  self.LABEL_NAME = args.get("output", 'y')
178
+ self.treat_text = treat_text
179
 
180
  # Tokenizing and processing text
181
  def encode_text(self, example):
182
  comment_text = example[self.INPUT_NAME]
183
+ if self.treat_text:
184
+ comment_text = self.treat_text(comment_text)
185
 
186
  try:
187
+ labels = LABEL_MAP[example[self.LABEL_NAME].lower()]
188
  except:
189
  labels = -1
190
 
 
207
  torch.tensor([torch.tensor(labels).to(int)])
208
  ))
209
 
 
 
 
 
 
 
 
 
 
210
 
211
  def __len__(self):
212
  return len(self.data)
 
219
  return temp_data
220
 
221
 
222
+ class Learner(nn.Module):
223
+
224
+ def __init__(self, **args):
225
+ """
226
+ :param args:
227
+ """
228
+ super(Learner, self).__init__()
229
+
230
+ self.inner_print = args.get('inner_print')
231
+ self.inner_batch_size = args.get('inner_batch_size')
232
+ self.outer_update_lr = args.get('outer_update_lr')
233
+ self.inner_update_lr = args.get('inner_update_lr')
234
+ self.inner_update_step = args.get('inner_update_step')
235
+ self.inner_update_step_eval = args.get('inner_update_step_eval')
236
+ self.model = args.get('model')
237
+ self.device = args.get('device')
238
+
239
+ # Outer optimizer
240
+ self.outer_optimizer = Adam(self.model.parameters(), lr=self.outer_update_lr)
241
+ self.model.train()
242
+
243
+ def forward(self, batch_tasks, training = True, valid_train = True):
244
+ """
245
+ batch = [(support TensorDataset, query TensorDataset),
246
+ (support TensorDataset, query TensorDataset),
247
+ (support TensorDataset, query TensorDataset),
248
+ (support TensorDataset, query TensorDataset)]
249
+
250
+ # support = TensorDataset(all_input_ids, all_attention_mask, all_segment_ids, all_label_ids)
251
+ """
252
+ task_accs = []
253
+ task_f1 = []
254
+ task_recall = []
255
+ sum_gradients = []
256
+ num_task = len(batch_tasks)
257
+ num_inner_update_step = self.inner_update_step if training else self.inner_update_step_eval
258
+
259
+ # Outer loop tasks
260
+ for task_id, task in enumerate(batch_tasks):
261
+ support = task[0]
262
+ query = task[1]
263
+ name = task[2]
264
+
265
+ # Copying model
266
+ fast_model = deepcopy(self.model)
267
+ fast_model.to(self.device)
268
+
269
+ # Inner trainer optimizer
270
+ inner_optimizer = Adam(fast_model.parameters(), lr=self.inner_update_lr)
271
+
272
+ # Creating training data loaders
273
+ if len(support) % self.inner_batch_size == 1 :
274
+ support_dataloader = DataLoader(support, sampler=RandomSampler(support),
275
+ batch_size=self.inner_batch_size,
276
+ drop_last=True)
277
+ else:
278
+ support_dataloader = DataLoader(support, sampler=RandomSampler(support),
279
+ batch_size=self.inner_batch_size,
280
+ drop_last=False)
281
+
282
+ # steps_per_epoch=len(support) // self.inner_batch_size
283
+ # total_training_steps = steps_per_epoch * 5
284
+ # warmup_steps = total_training_steps // 3
285
+ #
286
+
287
+ # scheduler = get_linear_schedule_with_warmup(
288
+ # inner_optimizer,
289
+ # num_warmup_steps=warmup_steps,
290
+ # num_training_steps=total_training_steps
291
+ # )
292
+
293
+ fast_model.train()
294
+
295
+ # Inner loop training epoch (support set)
296
+ if valid_train:
297
+ print('----Task',task_id,":", name, '----')
298
+
299
+ for i in range(0, num_inner_update_step):
300
+ all_loss = []
301
+
302
+ # Inner loop training batch (support set)
303
+ for inner_step, batch in enumerate(support_dataloader):
304
+ batch = tuple(t.to(self.device) for t in batch)
305
+ input_ids, attention_mask, token_type_ids, label_id = batch
306
+
307
+ # Feed forward
308
+ loss, _, _ = fast_model(input_ids, attention_mask, token_type_ids=token_type_ids, labels = label_id)
309
+
310
+ # Computing gradients
311
+ loss.backward()
312
+ # torch.nn.utils.clip_grad_norm_(fast_model.parameters(), max_norm=1)
313
+
314
+ # Updating inner training parameters
315
+ inner_optimizer.step()
316
+ inner_optimizer.zero_grad()
317
+
318
+ # Appending losses
319
+ all_loss.append(loss.item())
320
+
321
+ del batch, input_ids, attention_mask, label_id
322
+ torch.cuda.empty_cache()
323
+
324
+ if valid_train:
325
+ if (i+1) % self.inner_print == 0:
326
+ print("Inner Loss: ", np.mean(all_loss))
327
+
328
+ fast_model.to(torch.device('cpu'))
329
+
330
+ # Inner training phase weights
331
+ if training:
332
+ meta_weights = list(self.model.parameters())
333
+ fast_weights = list(fast_model.parameters())
334
+
335
+ # Appending gradients
336
+ gradients = []
337
+ for i, (meta_params, fast_params) in enumerate(zip(meta_weights, fast_weights)):
338
+ gradient = meta_params - fast_params
339
+ if task_id == 0:
340
+ sum_gradients.append(gradient)
341
+ else:
342
+ sum_gradients[i] += gradient
343
+
344
+
345
+ # Inner test (query set)
346
+ fast_model.to(self.device)
347
+ fast_model.eval()
348
+
349
+ if valid_train:
350
+ # Inner test (query set)
351
+ fast_model.to(self.device)
352
+ fast_model.eval()
353
+
354
+ with torch.no_grad():
355
+ # Data loader
356
+ query_dataloader = DataLoader(query, sampler=None, batch_size=len(query))
357
+ query_batch = iter(query_dataloader).next()
358
+ query_batch = tuple(t.to(self.device) for t in query_batch)
359
+ q_input_ids, q_attention_mask, q_token_type_ids, q_label_id = query_batch
360
+
361
+ # Feedforward
362
+ _, _, pre_label_id = fast_model(q_input_ids, q_attention_mask, q_token_type_ids, labels = q_label_id)
363
+
364
+ # Predictions
365
+ pre_label_id = pre_label_id.detach().cpu().squeeze()
366
+ # Labels
367
+ q_label_id = q_label_id.detach().cpu()
368
+
369
+ # Calculating metrics
370
+ acc = fn.accuracy(pre_label_id, q_label_id).item()
371
+ recall = fn.recall(pre_label_id, q_label_id).item(),
372
+ f1 = fn.f1_score(pre_label_id, q_label_id).item()
373
+
374
+ # appending metrics
375
+ task_accs.append(acc)
376
+ task_f1.append(f1)
377
+ task_recall.append(recall)
378
+
379
+ fast_model.to(torch.device('cpu'))
380
+
381
+ del fast_model, inner_optimizer
382
+ torch.cuda.empty_cache()
383
+
384
+ print("\n")
385
+ print("f1:",np.mean(task_f1))
386
+ print("recall:",np.mean(task_recall))
387
+
388
+ # Updating outer training parameters
389
+ if training:
390
+ # Mean of gradients
391
+ for i in range(0,len(sum_gradients)):
392
+ sum_gradients[i] = sum_gradients[i] / float(num_task)
393
+
394
+ # Indexing parameters to model
395
+ for i, params in enumerate(self.model.parameters()):
396
+ params.grad = sum_gradients[i]
397
+
398
+ # Updating parameters
399
+ self.outer_optimizer.step()
400
+ self.outer_optimizer.zero_grad()
401
+
402
+ del sum_gradients
403
+ gc.collect()
404
+ torch.cuda.empty_cache()
405
+
406
+ if valid_train:
407
+ return np.mean(task_accs)
408
+ else:
409
+ return np.array(0)
410
+
411
+
412
+
413
+ # Creating Meta Tasks
414
+ class MetaTask(Dataset):
415
+ def __init__(self, examples, num_task, k_support, k_query,
416
+ tokenizer, training=True, max_seq_length=512,
417
+ treat_text =None, **args):
418
+ """
419
+ :param samples: list of samples
420
+ :param num_task: number of training tasks.
421
+ :param k_support: number of classes support samples per task
422
+ :param k_query: number of classes query sample per task
423
+ """
424
+ self.examples = examples
425
+
426
+ self.num_task = num_task
427
+ self.k_support = k_support
428
+ self.k_query = k_query
429
+ self.tokenizer = tokenizer
430
+ self.max_seq_length = max_seq_length
431
+ self.treat_text = treat_text
432
+
433
+ # Randomly generating tasks
434
+ self.create_batch(self.num_task, training)
435
+
436
+ # Creating batch
437
+ def create_batch(self, num_task, training):
438
+ self.supports = [] # support set
439
+ self.queries = [] # query set
440
+ self.task_names = [] # Name of task
441
+ self.supports_indexs = [] # index of supports
442
+ self.queries_indexs = [] # index of queries
443
+ self.num_task=num_task
444
+
445
+ # Available tasks
446
+ domains = self.examples['domain'].unique()
447
+
448
+ # If not training, create all tasks
449
+ if not(training):
450
+ self.task_names = domains
451
+ num_task = len(self.task_names)
452
+ self.num_task=num_task
453
+
454
+
455
+ for b in range(num_task): # For each task,
456
+ total_per_class = self.k_support + self.k_query
457
+ task_size = 2*self.k_support + 2*self.k_query
458
+
459
+ # Select a task at random
460
+ if training:
461
+ domain = random.choice(domains)
462
+ self.task_names.append(domain)
463
+ else:
464
+ domain = self.task_names[b]
465
+
466
+ # Task data
467
+ domainExamples = self.examples[self.examples['domain'] == domain]
468
+
469
+ # Minimal label quantity
470
+ min_per_class = min(domainExamples['label'].value_counts())
471
+
472
+ if total_per_class > min_per_class:
473
+ total_per_class = min_per_class
474
+
475
+ # Select k_support + k_query task examples
476
+ # Sample (n) from each label(class)
477
+ selected_examples = domainExamples.groupby("label").sample(total_per_class, replace = False)
478
+
479
+ # Split data into support (training) and query (testing) sets
480
+ s, q = train_test_split(selected_examples,
481
+ stratify= selected_examples["label"],
482
+ test_size= 2*self.k_query/task_size,
483
+ shuffle=True)
484
+
485
+ # Permuting data
486
+ s = s.sample(frac=1)
487
+ q = q.sample(frac=1)
488
+
489
+ # Appending indexes
490
+ if not(training):
491
+ self.supports_indexs.append(s.index)
492
+ self.queries_indexs.append(q.index)
493
+
494
+ # Creating list of support (training) and query (testing) tasks
495
+ self.supports.append(s.to_dict('records'))
496
+ self.queries.append(q.to_dict('records'))
497
+
498
+ # Creating task tensors
499
+ def create_feature_set(self, examples):
500
+ all_input_ids = torch.empty(len(examples), self.max_seq_length, dtype = torch.long)
501
+ all_attention_mask = torch.empty(len(examples), self.max_seq_length, dtype = torch.long)
502
+ all_token_type_ids = torch.empty(len(examples), self.max_seq_length, dtype = torch.long)
503
+ all_label_ids = torch.empty(len(examples), dtype = torch.long)
504
+
505
+ for _id, e in enumerate(examples):
506
+ all_input_ids[_id], all_attention_mask[_id], all_token_type_ids[_id], all_label_ids[_id] = self.encode_text(e)
507
+
508
+ return TensorDataset(
509
+ all_input_ids,
510
+ all_attention_mask,
511
+ all_token_type_ids,
512
+ all_label_ids
513
+ )
514
+
515
+ # Data encoding
516
+ def encode_text(self, example):
517
+ comment_text = example["text"]
518
+
519
+ if self.treat_text:
520
+ comment_text = self.treat_text(comment_text)
521
+
522
+ labels = LABEL_MAP[example["label"]]
523
+
524
+ encoding = self.tokenizer.encode_plus(
525
+ (comment_text, "It is a great text."),
526
+ add_special_tokens=True,
527
+ max_length=self.max_seq_length,
528
+ return_token_type_ids=True,
529
+ padding="max_length",
530
+ truncation=True,
531
+ return_attention_mask=True,
532
+ return_tensors='pt',
533
+ )
534
+
535
+ return tuple((
536
+ encoding["input_ids"].flatten(),
537
+ encoding["attention_mask"].flatten(),
538
+ encoding["token_type_ids"].flatten(),
539
+ torch.tensor([torch.tensor(labels).to(int)])
540
+ ))
541
+
542
+ # Returns data upon calling
543
+ def __getitem__(self, index):
544
+ support_set = self.create_feature_set(self.supports[index])
545
+ query_set = self.create_feature_set(self.queries[index])
546
+ name = self.task_names[index]
547
+ return support_set, query_set, name
548
+
549
+ def __len__(self):
550
+ return self.num_task
551
+
552
+
553
+ class treat_text:
554
+ def __init__(self, patterns):
555
+ self.patterns = patterns
556
+
557
+ def __call__(self,text):
558
+ text = unicodedata.normalize("NFKD",str(text))
559
+ text = multiple_replace(self.patterns,text.lower())
560
+ text = re.sub('(\(.+\))|(\[.+\])|( \d )|(<)|(>)|(- )','', text)
561
+ text = re.sub('( +)',' ', text)
562
+ text = re.sub('(, ,)|(,,)',',', text)
563
+ text = re.sub('(%)|(per cent)',' percent', text)
564
+ return text
565
+
566
 
567
  # Regex multiple replace function
568
  def multiple_replace(dict, text):
 
571
  regex = re.compile("(%s)" % "|".join(map(re.escape, dict.keys())))
572
 
573
  # Substitution
574
+ return regex.sub(lambda mo: dict[mo.string[mo.start():mo.end()]], text)
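The additions above change how ML_SLRC.py is meant to be driven: text cleaning is now an injectable treat_text(patterns) callable, SLR_DataSet and MetaTask receive it explicitly, and the new Learner class wraps the inner/outer meta-update loop. The following is a minimal wiring sketch only; the checkpoint name, column layout and hyperparameter values are illustrative assumptions, not part of this commit.

    # Hypothetical wiring of the updated ML_SLRC.py API (all concrete values are assumptions).
    import pandas as pd
    from transformers import AutoTokenizer, AutoModel
    from ML_SLRC import patterns, treat_text, SLR_Classifier, MetaTask, Learner

    checkpoint = "allenai/scibert_scivocab_uncased"   # assumed encoder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoder = AutoModel.from_pretrained(checkpoint)

    cleaner = treat_text(patterns)                    # text normalizer introduced in this commit
    model = SLR_Classifier(model=encoder)             # encoder + feature map + binary head

    # MetaTask expects 'text', 'label' and 'domain' columns in the examples DataFrame.
    data = pd.read_csv("labeled_reviews.csv")         # hypothetical pooled SLR data

    tasks = MetaTask(data, num_task=4, k_support=5, k_query=5,
                     tokenizer=tokenizer, max_seq_length=512,
                     treat_text=cleaner)

    learner = Learner(model=model, device="cpu",
                      inner_print=1, inner_batch_size=4,
                      outer_update_lr=1e-4, inner_update_lr=1e-4,
                      inner_update_step=3, inner_update_step_eval=3)

    # One meta-update over a batch of (support set, query set, task name) tuples;
    # valid_train=True would additionally score each task on its query split.
    acc = learner([tasks[0], tasks[1]], valid_train=False)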
 
 
 
 
Util_funs.py CHANGED
@@ -1,49 +1,49 @@
 
 
1
  import os
2
- import torch
3
  import numpy as np
 
 
 
 
 
 
 
 
 
4
  import random
5
- import json, pickle
6
 
7
- import torch.nn.functional as F
8
- import torch.nn as nn
9
- import math
 
 
 
 
 
 
 
 
10
  import torch
11
- import numpy as np
12
- import pandas as pd
13
  import time
14
- import transformers
15
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
16
  from sklearn.manifold import TSNE
17
- from copy import deepcopy, copy
18
  import seaborn as sns
19
  import matplotlib.pylab as plt
20
- from pprint import pprint
21
- import shutil
22
- import datetime
23
- import re
24
  import json
25
  from pathlib import Path
26
- import torch
27
- import torch.nn as nn
28
- from torch.utils.data import Dataset, DataLoader
29
- from torch import nn
30
- from torch.nn import functional as F
31
- from torch.utils.data import TensorDataset, DataLoader, RandomSampler
32
- from torch.optim import Adam
33
- from torch.nn import CrossEntropyLoss
34
- from transformers import BertForSequenceClassification
35
- from copy import deepcopy
36
- import gc
37
- from sklearn.metrics import accuracy_score
38
- import torch
39
- import numpy as np
40
- import torchmetrics
41
- from torchmetrics import functional as fn
42
 
43
 
44
- SEED = 2222
45
 
46
- gen_seed = torch.Generator().manual_seed(SEED)
47
 
48
 
49
  # Random seed function
@@ -54,7 +54,7 @@ def random_seed(value):
54
  np.random.seed(value)
55
  random.seed(value)
56
 
57
- # Batch creation function
58
  def create_batch_of_tasks(taskset, is_shuffle = True, batch_size = 4):
59
  idxs = list(range(0,len(taskset)))
60
  if is_shuffle:
@@ -63,48 +63,51 @@ def create_batch_of_tasks(taskset, is_shuffle = True, batch_size = 4):
63
  yield [taskset[idxs[i]] for i in range(i, min(i + batch_size,len(taskset)))]
64
 
65
 
66
-
67
- def prepare_data(data, batch_size,tokenizer,max_seq_length,
68
  input = 'text', output = 'label',
69
- train_size_per_class = 5):
 
70
  data = data.reset_index().drop("index", axis=1)
71
 
72
- labaled_data = data.loc[~data['label'].isna()]
73
-
74
- data_train = labaled_data.groupby('label').sample(train_size_per_class)
75
 
76
- rest_labaled_data = labaled_data.loc[~labaled_data.index.isin(data_train.index),:]
77
- unlabaled_data = data.loc[data['label'].isna()]
 
78
 
79
- data_test=data
 
80
 
81
 
82
- # Train
83
- ## Transforma em dataset
84
  dataset_train = SLR_DataSet(
85
  data = data_train.sample(frac=1),
86
  input = input,
87
  output = output,
88
  tokenizer=tokenizer,
89
- max_seq_length =max_seq_length)
 
90
 
91
- # Test
92
- # Dataloaders
93
- ## Transforma em dataset
94
  dataset_test = SLR_DataSet(
95
  data = data_test,
96
  input = input,
97
  output = output,
98
  tokenizer=tokenizer,
99
- max_seq_length =max_seq_length)
 
100
 
101
  # Dataloaders
102
- ## Treino
103
  data_train_loader = DataLoader(dataset_train,
104
  shuffle=True,
105
  batch_size=batch_size['train']
106
  )
107
 
 
108
  if len(dataset_test) % batch_size['test'] == 1 :
109
  data_test_loader = DataLoader(dataset_test,
110
  batch_size=batch_size['test'],
@@ -117,50 +120,54 @@ def prepare_data(data, batch_size,tokenizer,max_seq_length,
117
  return data_train_loader, data_test_loader, data_train, data_test
118
 
119
 
 
 
 
 
 
120
 
121
-
122
-
123
- from tqdm import tqdm
124
-
125
- def meta_train(data, model, device, Info, print_epoch =True, size_layer=0, Test_resource =None):
126
-
127
  learner = Learner(model = model, device = device, **Info)
128
 
129
  # Testing tasks
130
  if isinstance(Test_resource, pd.DataFrame):
131
  test = MetaTask(Test_resource, num_task = 0, k_support=10, k_query=10,
132
- training=False, **Info)
133
 
134
 
135
  torch.clear_autocast_cache()
136
  gc.collect()
137
  torch.cuda.empty_cache()
138
 
139
- # Meta epoca
140
  for epoch in tqdm(range(Info['meta_epoch']), desc= "Meta epoch ", ncols=80):
141
- # print("Meta Epoca:", epoch)
142
 
143
- # Tarefas de treino
144
  train = MetaTask(data,
145
  num_task = Info['num_task_train'],
146
  k_support=Info['k_qry'],
147
- k_query=Info['k_spt'], **Info)
 
148
 
149
- # Batchs de tarefas
150
  db = create_batch_of_tasks(train, is_shuffle = True, batch_size = Info["outer_batch_size"])
151
 
152
  if print_epoch:
153
  # Outer loop bach training
154
  for step, task_batch in enumerate(db):
155
  print("\n-----------------Training Mode","Meta_epoch:", epoch ,"-----------------\n")
156
- # meta-feedfoward
 
157
  acc = learner(task_batch, valid_train= print_epoch)
158
  print('Step:', step, '\ttraining Acc:', acc)
 
159
  if isinstance(Test_resource, pd.DataFrame):
160
- # Validating Model
161
  if ((epoch+1) % 4) + step == 0:
162
  random_seed(123)
163
  print("\n-----------------Testing Mode-----------------\n")
 
 
164
  db_test = create_batch_of_tasks(test, is_shuffle = False, batch_size = 1)
165
  acc_all_test = []
166
 
@@ -174,10 +181,10 @@ def meta_train(data, model, device, Info, print_epoch =True, size_layer=0, Test_
174
 
175
  # Restarting training randomly
176
  random_seed(int(time.time() % 10))
177
-
178
-
179
  else:
180
  for step, task_batch in enumerate(db):
 
181
  acc = learner(task_batch, print_epoch, valid_train= print_epoch)
182
 
183
  torch.clear_autocast_cache()
@@ -187,14 +194,14 @@ def meta_train(data, model, device, Info, print_epoch =True, size_layer=0, Test_
187
 
188
 
189
  def train_loop(data_train_loader, data_test_loader, model, device, epoch = 4, lr = 1, print_info = True, name = 'name'):
190
- # Inicia o modelo
191
  model_meta = deepcopy(model)
192
  optimizer = Adam(model_meta.parameters(), lr=lr)
193
 
194
  model_meta.to(device)
195
  model_meta.train()
196
 
197
- # Loop de treino da tarefa
198
  for i in range(0, epoch):
199
  all_loss = []
200
 
@@ -203,13 +210,13 @@ def train_loop(data_train_loader, data_test_loader, model, device, epoch = 4, lr
203
  batch = tuple(t.to(device) for t in batch)
204
  input_ids, attention_mask,q_token_type_ids, label_id = batch
205
 
206
- # Feedfoward
207
  loss, _, _ = model_meta(input_ids, attention_mask,q_token_type_ids, labels = label_id.squeeze())
208
 
209
- # Calcula gradientes
210
  loss.backward()
211
 
212
- # Atualiza os parametros
213
  optimizer.step()
214
  optimizer.zero_grad()
215
 
@@ -220,39 +227,43 @@ def train_loop(data_train_loader, data_test_loader, model, device, epoch = 4, lr
220
  print("Loss: ", np.mean(all_loss))
221
 
222
 
223
- # Predicao no banco de teste
224
  model_meta.eval()
225
  all_loss = []
226
- # all_acc = []
227
  features = []
228
  labels = []
229
  predi_logit = []
230
 
231
  with torch.no_grad():
 
232
  for inner_step, batch in enumerate(tqdm(data_test_loader,
233
  desc="Test validation | " + name,
234
  ncols=80)) :
235
  batch = tuple(t.to(device) for t in batch)
236
  input_ids, attention_mask,q_token_type_ids, label_id = batch
237
 
238
- # Predicoes
239
  _, feature, prediction = model_meta(input_ids, attention_mask,q_token_type_ids, labels = label_id.squeeze())
240
 
 
241
  prediction = prediction.detach().cpu().squeeze()
242
  label_id = label_id.detach().cpu()
 
 
243
  logit = feature[1].detach().cpu()
244
- feature_lat = feature[0].detach().cpu()
245
 
246
- labels.append(label_id.numpy().squeeze())
247
  features.append(feature_lat.numpy())
248
- predi_logit.append(logit.numpy())
249
 
250
- # acc = fn.accuracy(prediction, label_id).item()
251
- # all_acc.append(acc)
 
252
  del input_ids, attention_mask, label_id, batch
253
 
254
- # if print_info:
255
- # print("acc:", np.mean(all_acc))
256
 
257
  model_meta.to('cpu')
258
  gc.collect()
@@ -260,26 +271,32 @@ def train_loop(data_train_loader, data_test_loader, model, device, epoch = 4, lr
260
 
261
  del model_meta, optimizer
262
 
 
263
 
 
 
 
264
  features = np.concatenate(np.array(features,dtype=object))
265
- labels = np.concatenate(np.array(labels,dtype=object))
266
- logits = np.concatenate(np.array(predi_logit,dtype=object))
267
-
268
  features = torch.tensor(features.astype(np.float32)).detach().clone()
 
 
269
  labels = torch.tensor(labels.astype(int)).detach().clone()
 
 
270
  logits = torch.tensor(logits.astype(np.float32)).detach().clone()
271
 
272
- # Reducao de dimensionalidade
273
  X_embedded = TSNE(n_components=2, learning_rate='auto',
274
  init='random').fit_transform(features.detach().clone())
275
 
276
  return logits.detach().clone(), X_embedded, labels.detach().clone(), features.detach().clone()
277
-
278
-
279
  def wss_calc(logit, labels, trsh = 0.5):
280
 
281
- # Predicao com base nos treshould
282
  predict_trash = torch.sigmoid(logit).squeeze() >= trsh
 
 
283
  CM = confusion_matrix(labels, predict_trash.to(int) )
284
  tn, fp, fne, tp = CM.ravel()
285
 
@@ -287,36 +304,22 @@ def wss_calc(logit, labels, trsh = 0.5):
287
  N = (tn + fp)
288
  recall = tp/(tp+fne)
289
 
290
- # Wss antigo
291
- wss_old = (tn + fne)/len(labels) -(1- recall)
292
 
293
- # WSS novo
294
- wss_new = (tn/N - fne/P)
295
 
296
  return {
297
- "wss": round(wss_old,4),
298
- "awss": round(wss_new,4),
299
  "R": round(recall,4),
300
  "CM": CM
301
  }
302
 
303
 
304
-
305
-
306
- from sklearn.metrics import confusion_matrix
307
- from torchmetrics import functional as fn
308
- import matplotlib.pyplot as plt
309
- from sklearn.metrics import roc_curve, auc
310
- from sklearn.metrics import roc_auc_score
311
- import ipywidgets as widgets
312
- from IPython.display import HTML, display, clear_output
313
- import matplotlib.pyplot as plt
314
- import seaborn as sns
315
- import warnings
316
-
317
- warnings.simplefilter(action='ignore', category=FutureWarning)
318
-
319
- def plot(logits, X_embedded, labels, tresh, show = True,
320
  namefig = "plot", make_plot = True, print_stats = True, save = True):
321
  col = pd.MultiIndex.from_tuples([
322
  ("Predict", "0"),
@@ -329,30 +332,27 @@ def plot(logits, X_embedded, labels, tresh, show = True,
329
 
330
  predict = torch.sigmoid(logits).detach().clone()
331
 
332
- roc_auc = dict()
333
-
334
  fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())
335
 
336
- # Sem especificar o tresh
337
- # WSS
338
- ## indice do recall 0.95
339
  idx_wss95 = sum(tpr < 0.95)
 
340
  thresholds95 = thresholds[idx_wss95]
341
 
 
342
  wss95_info = wss_calc(logits,labels, thresholds95 )
343
  acc_wss95 = fn.accuracy(predict, labels, threshold=thresholds95)
344
  f1_wss95 = fn.f1_score(predict, labels, threshold=thresholds95)
345
 
346
 
347
- # Especificando o tresh
348
- # Treshold avaliation
349
-
350
-
351
- ## WSS
352
- wss_info = wss_calc(logits,labels, tresh )
353
- # Accuraci
354
- acc_wssR = fn.accuracy(predict, labels, threshold=tresh)
355
- f1_wssR = fn.f1_score(predict, labels, threshold=tresh)
356
 
357
 
358
  metrics= {
@@ -370,12 +370,11 @@ def plot(logits, X_embedded, labels, tresh, show = True,
370
  # f1
371
  "f1@95": f1_wss95.item(),
372
  "f1@R": f1_wssR.item(),
373
- # treshould 95
374
- "treshould@95": thresholds95
375
  }
376
 
377
- # print stats
378
-
379
  if print_stats:
380
  wss95= f"WSS@95:{wss95_info['wss']}, R: {wss95_info['R']}"
381
  wss95_adj= f"ASSWSS@95:{wss95_info['awss']}"
@@ -383,14 +382,14 @@ def plot(logits, X_embedded, labels, tresh, show = True,
383
  print(wss95_adj)
384
  print('Acc.:', round(acc_wss95.item(), 4))
385
  print('F1-score:', round(f1_wss95.item(), 4))
386
- print(f"Treshold to wss95: {round(thresholds95, 4)}")
387
  cm = pd.DataFrame(wss95_info['CM'],
388
  index=index,
389
  columns=col)
390
 
391
  print("\nConfusion matrix:")
392
  print(cm)
393
- print("\n---Metrics with threshold:", tresh, "----\n")
394
  wss= f"WSS@R:{wss_info['wss']}, R: {wss_info['R']}"
395
  print(wss)
396
  wss_adj= f"AWSS@R:{wss_info['awss']}"
@@ -405,51 +404,53 @@ def plot(logits, X_embedded, labels, tresh, show = True,
405
  print(cm)
406
 
407
 
408
- # Graficos
409
 
410
  if make_plot:
411
 
412
  fig, axes = plt.subplots(1, 4, figsize=(25,10))
413
  alpha = torch.squeeze(predict).numpy()
414
 
415
- # plots
416
-
417
  p1 = sns.scatterplot(x=X_embedded[:, 0],
418
  y=X_embedded[:, 1],
419
  hue=labels,
420
- alpha=alpha, ax = axes[0]).set_title('Predictions-TSNE')
421
 
 
 
422
  t_wss = predict >= thresholds95
423
  t_wss = t_wss.squeeze().numpy()
424
-
425
  p2 = sns.scatterplot(x=X_embedded[t_wss, 0],
426
  y=X_embedded[t_wss, 1],
427
  hue=labels[t_wss],
428
- alpha=alpha[t_wss], ax = axes[1]).set_title('WSS@95')
429
 
430
- t = predict >= tresh
 
431
  t = t.squeeze().numpy()
432
-
433
  p3 = sns.scatterplot(x=X_embedded[t, 0],
434
  y=X_embedded[t, 1],
435
  hue=labels[t],
436
- alpha=alpha[t], ax = axes[2]).set_title(f'Predictions-Treshold {tresh}')
437
-
438
 
 
439
  roc_auc = auc(fpr, tpr)
440
  lw = 2
441
-
442
  axes[3].plot(
443
  fpr,
444
  tpr,
445
  color="darkorange",
446
  lw=lw,
447
  label="ROC curve (area = %0.2f)" % roc_auc)
448
-
449
  axes[3].plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
450
  axes[3].axhline(y=0.95, color='r', linestyle='-')
451
- axes[3].set(xlabel="False Positive Rate", ylabel="True Positive Rate", title= "ROC")
452
  axes[3].legend(loc="lower right")
 
 
 
 
453
 
454
  if show:
455
  plt.show()
@@ -459,6 +460,7 @@ def plot(logits, X_embedded, labels, tresh, show = True,
459
 
460
  return metrics
461
 
 
462
  def auc_plot(logits,labels, color = "darkorange", label = "test"):
463
  predict = torch.sigmoid(logits).detach().clone()
464
  fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())
@@ -478,45 +480,40 @@ def auc_plot(logits,labels, color = "darkorange", label = "test"):
478
  plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
479
  plt.axhline(y=0.95, color='r', linestyle='-')
480
 
481
-
482
- from sklearn.metrics import confusion_matrix
483
- from torchmetrics import functional as fn
484
- import matplotlib.pyplot as plt
485
- from sklearn.metrics import roc_curve, auc
486
- from sklearn.metrics import roc_auc_score
487
- import ipywidgets as widgets
488
- from IPython.display import HTML, display, clear_output
489
- import matplotlib.pyplot as plt
490
- import seaborn as sns
491
- import warnings
492
-
493
-
494
  class diagnosis():
495
- def __init__(self, names, Valid_resource, batch_size_test, model,Info,start = 0):
 
496
  self.names=names
497
  self.Valid_resource=Valid_resource
498
  self.batch_size_test=batch_size_test
499
  self.model=model
500
- self.start=start
 
 
 
 
501
 
 
502
  self.value_trash = widgets.FloatText(
503
  value=0.95,
504
- description='tresh',
505
  disabled=False
506
  )
507
-
508
  self.valueb = widgets.IntText(
509
  value=10,
510
  description='size',
511
  disabled=False
512
  )
513
 
 
514
  self.train_b = widgets.Button(description="Train")
515
  self.next_b = widgets.Button(description="Next")
516
  self.eval_b = widgets.Button(description="Evaluation")
517
 
518
  self.hbox = widgets.HBox([self.train_b, self.valueb])
519
 
 
520
  self.next_b.on_click(self.Next_button)
521
  self.train_b.on_click(self.Train_button)
522
  self.eval_b.on_click(self.Evaluation_button)
@@ -527,36 +524,37 @@ class diagnosis():
527
  clear_output()
528
  self.i=self.i+1
529
 
530
- # global domain
531
- self.domain = names[self.i]
532
- print("Name:", self.domain)
533
-
534
- # global data
535
  self.data = self.Valid_resource[self.Valid_resource['domain'] == self.domain]
 
 
536
  print(self.data['label'].value_counts())
537
-
538
  display(self.hbox)
539
  display(self.next_b)
540
 
 
541
  # Train button
542
  def Train_button(self, y):
543
  clear_output()
544
  print(self.domain)
545
 
546
- # Preparing data for training
547
  self.data_train_loader, self.data_test_loader, self.data_train, self.data_test = prepare_data(self.data,
548
  train_size_per_class = self.valueb.value,
549
- batch_size = {'train': Info['inner_batch_size'],
550
- 'test': batch_size_test},
551
- max_seq_length = Info['max_seq_length'],
552
- tokenizer = Info['tokenizer'],
553
  input = "text",
554
- output = "label")
 
555
 
 
556
  self.logits, self.X_embedded, self.labels, self.features = train_loop(self.data_train_loader, self.data_test_loader,
557
- model, device,
558
- epoch = Info['inner_update_step'],
559
- lr=Info['inner_update_lr'],
560
  print_info=True,
561
  name = self.domain)
562
 
@@ -565,6 +563,7 @@ class diagnosis():
565
  display(tresh_box)
566
  display(self.next_b)
567
 
 
568
  # Evaluation button
569
  def Evaluation_button(self, te):
570
  clear_output()
@@ -573,19 +572,18 @@ class diagnosis():
573
  print(self.domain)
574
  # print("\n")
575
  print("-------Train data-------")
576
- print(self.data_train['label'].value_counts())
577
  print("-------Test data-------")
578
- print(self.data_test['label'].value_counts())
579
  # print("\n")
580
 
581
  display(self.next_b)
582
  display(tresh_box)
583
  display(self.hbox)
584
 
585
-
586
  metrics = plot(self.logits, self.X_embedded, self.labels,
587
- tresh=Info['tresh'], show = True,
588
- # namefig= "./"+base_path +"/"+"Results/size_layer/"+ name_domain+'/' +str(n_layers) + '/img/' + str(attempt) + 'plots',
589
  namefig= 'test',
590
  make_plot = True,
591
  print_stats = True,
@@ -593,261 +591,150 @@ class diagnosis():
593
 
594
  def __call__(self):
595
  self.i= self.start-1
596
-
597
  clear_output()
598
  display(self.next_b)
599
 
600
 
601
 
602
 
 
 
 
 
 
 
603
 
 
 
 
 
604
 
 
 
 
 
 
 
 
605
 
606
 
607
-
608
-
609
- import torch.nn.functional as F
610
- import torch.nn as nn
611
- import math
612
- import torch
613
- import numpy as np
614
- import pandas as pd
615
- import time
616
- import transformers
617
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
618
- from sklearn.manifold import TSNE
619
- from copy import deepcopy, copy
620
- import seaborn as sns
621
- import matplotlib.pylab as plt
622
- from pprint import pprint
623
- import shutil
624
- import datetime
625
- import re
626
- import json
627
- from pathlib import Path
628
-
629
- import torch
630
- import torch.nn as nn
631
- from torch.utils.data import Dataset, DataLoader
632
- import unicodedata
633
- import re
634
-
635
- import torch
636
- import torch.nn as nn
637
- from torch.utils.data import Dataset, DataLoader
638
-
639
-
640
-
641
- # Pre-trained model
642
- class Encoder(nn.Module):
643
- def __init__(self, layers, freeze_bert, model):
644
- super(Encoder, self).__init__()
645
-
646
- # Dummy Parameter
647
- self.dummy_param = nn.Parameter(torch.empty(0))
648
-
649
- # Pre-trained model
650
- self.model = deepcopy(model)
651
-
652
- # Freezing bert parameters
653
- if freeze_bert:
654
- for param in self.model.parameters():
655
- param.requires_grad = freeze_bert
656
-
657
- # Selecting hidden layers of the pre-trained model
658
- old_model_encoder = self.model.encoder.layer
659
- new_model_encoder = nn.ModuleList()
660
-
661
- for i in layers:
662
- new_model_encoder.append(old_model_encoder[i])
663
-
664
- self.model.encoder.layer = new_model_encoder
665
 
666
- # Feed forward
667
- def forward(self, **x):
668
- return self.model(**x)['pooler_output']
669
-
670
- # Complete model
671
- class SLR_Classifier(nn.Module):
672
- def __init__(self, **data):
673
- super(SLR_Classifier, self).__init__()
674
-
675
- # Dummy Parameter
676
- self.dummy_param = nn.Parameter(torch.empty(0))
677
-
678
- # Loss function
679
- # Binary Cross Entropy with logits reduced to mean
680
- self.loss_fn = nn.BCEWithLogitsLoss(reduction = 'mean',
681
- pos_weight=torch.FloatTensor([data.get("pos_weight", 2.5)]))
682
-
683
- # Pre-trained model
684
- self.Encoder = Encoder(layers = data.get("bert_layers", range(12)),
685
- freeze_bert = data.get("freeze_bert", False),
686
- model = data.get("model"),
687
- )
688
-
689
- # Feature Map Layer
690
- self.feature_map = nn.Sequential(
691
- # nn.LayerNorm(self.Encoder.model.config.hidden_size),
692
- nn.BatchNorm1d(self.Encoder.model.config.hidden_size),
693
- # nn.Dropout(data.get("drop", 0.5)),
694
- nn.Linear(self.Encoder.model.config.hidden_size, 200),
695
- nn.Dropout(data.get("drop", 0.5)),
696
- )
697
-
698
- # Classifier Layer
699
- self.classifier = nn.Sequential(
700
- # nn.LayerNorm(self.Encoder.model.config.hidden_size),
701
- # nn.Dropout(data.get("drop", 0.5)),
702
- # nn.BatchNorm1d(self.Encoder.model.config.hidden_size),
703
- # nn.Dropout(data.get("drop", 0.5)),
704
- nn.Tanh(),
705
- nn.Linear(200, 1)
706
- )
707
 
708
- # Initializing layer parameters
709
- nn.init.normal_(self.feature_map[1].weight, mean=0, std=0.00001)
710
- nn.init.zeros_(self.feature_map[1].bias)
711
-
712
- # Feed forward
713
- def forward(self, input_ids, attention_mask, token_type_ids, labels):
714
-
715
- predict = self.Encoder(**{"input_ids":input_ids,
716
- "attention_mask":attention_mask,
717
- "token_type_ids":token_type_ids})
718
- feature = self.feature_map(predict)
719
- logit = self.classifier(feature)
720
-
721
- predict = torch.sigmoid(logit)
722
 
723
- # Loss function
724
- loss = self.loss_fn(logit.to(torch.float), labels.to(torch.float).unsqueeze(1))
725
-
726
- return [loss, [feature, logit], predict]
727
-
728
-
729
- # Undesirable patterns within texts
730
- patterns = {
731
- 'CONCLUSIONS AND IMPLICATIONS':'',
732
- 'BACKGROUND AND PURPOSE':'',
733
- 'EXPERIMENTAL APPROACH':'',
734
- 'KEY RESULTS AEA':'',
735
- '©':'',
736
- '®':'',
737
- 'μ':'',
738
- '(C)':'',
739
- 'OBJECTIVE:':'',
740
- 'MATERIALS AND METHODS:':'',
741
- 'SIGNIFICANCE:':'',
742
- 'BACKGROUND:':'',
743
- 'RESULTS:':'',
744
- 'METHODS:':'',
745
- 'CONCLUSIONS:':'',
746
- 'AIM:':'',
747
- 'STUDY DESIGN:':'',
748
- 'CLINICAL RELEVANCE:':'',
749
- 'CONCLUSION:':'',
750
- 'HYPOTHESIS:':'',
751
- 'CLINICAL RELEVANCE:':'',
752
- 'Questions/Purposes:':'',
753
- 'Introduction:':'',
754
- 'PURPOSE:':'',
755
- 'PATIENTS AND METHODS:':'',
756
- 'FINDINGS:':'',
757
- 'INTERPRETATIONS:':'',
758
- 'FUNDING:':'',
759
- 'PROGRESS:':'',
760
- 'CONTEXT:':'',
761
- 'MEASURES:':'',
762
- 'DESIGN:':'',
763
- 'BACKGROUND AND OBJECTIVES:':'',
764
- '<p>':'',
765
- '</p>':'',
766
- '<<ETX>>':'',
767
- '+/-':'',
768
- }
769
-
770
- patterns = {x.lower():y for x,y in patterns.items()}
771
-
772
- LABEL_MAP = {'negative': 0,
773
- 'not included':0,
774
- '0':0,
775
- 0:0,
776
- 'excluded':0,
777
- 'positive': 1,
778
- 'included':1,
779
- '1':1,
780
- 1:1,
781
- }
782
-
783
- class SLR_DataSet(Dataset):
784
- def __init__(self, **args):
785
- self.tokenizer = args.get('tokenizer')
786
- self.data = args.get('data')
787
- self.max_seq_length = args.get("max_seq_length", 512)
788
- self.INPUT_NAME = args.get("input", 'x')
789
- self.LABEL_NAME = args.get("output", 'y')
790
-
791
- # Tokenizing and processing text
792
- def encode_text(self, example):
793
- comment_text = example[self.INPUT_NAME]
794
- comment_text = self.treat_text(comment_text)
795
-
796
- try:
797
- labels = LABEL_MAP[example[self.LABEL_NAME].lower()]
798
- except:
799
- labels = -1
800
-
801
- encoding = self.tokenizer.encode_plus(
802
- (comment_text, "It is great text"),
803
- add_special_tokens=True,
804
- max_length=self.max_seq_length,
805
- return_token_type_ids=True,
806
- padding="max_length",
807
- truncation=True,
808
- return_attention_mask=True,
809
- return_tensors='pt',
810
- )
811
-
812
-
813
- return tuple((
814
- encoding["input_ids"].flatten(),
815
- encoding["attention_mask"].flatten(),
816
- encoding["token_type_ids"].flatten(),
817
- torch.tensor([torch.tensor(labels).to(int)])
818
- ))
819
-
820
- # Text processing function
821
- def treat_text(self, text):
822
- text = unicodedata.normalize("NFKD",str(text))
823
- text = multiple_replace(patterns,text.lower())
824
- text = re.sub('(\(.+\))|(\[.+\])|( \d )|(<)|(>)|(- )','', text)
825
- text = re.sub('( +)',' ', text)
826
- text = re.sub('(, ,)|(,,)',',', text)
827
- text = re.sub('(%)|(per cent)',' percent', text)
828
- return text
829
-
830
- def __len__(self):
831
- return len(self.data)
832
-
833
- # Returning data
834
- def __getitem__(self, index: int):
835
- # print(index)
836
- data_row = self.data.reset_index().iloc[index]
837
- temp_data = self.encode_text(data_row)
838
- return temp_data
839
-
840
 
 
 
 
 
 
841
 
842
- # Regex multiple replace function
843
- def multiple_replace(dict, text):
 
 
 
844
 
845
- # Building regex from dict keys
846
- regex = re.compile("(%s)" % "|".join(map(re.escape, dict.keys())))
 
 
 
 
 
 
 
 
 
847
 
848
- # Substitution
849
- return regex.sub(lambda mo: dict[mo.string[mo.start():mo.end()]], text)
 
 
 
850
 
851
- # Undesirable patterns within texts
852
 
853
 
 
1
+ from ML_SLRC import *
2
+
3
  import os
 
4
  import numpy as np
5
+ import pandas as pd
6
+
7
+
8
+ from torch.utils.data import DataLoader
9
+ from torch.optim import Adam
10
+
11
+ import gc
12
+ from torchmetrics import functional as fn
13
+
14
  import random
 
15
 
16
+
17
+ warnings.simplefilter(action='ignore', category=FutureWarning)
18
+
19
+ from tqdm import tqdm
20
+
21
+ from sklearn.metrics import confusion_matrix
22
+ from sklearn.metrics import roc_curve, auc
23
+ import ipywidgets as widgets
24
+ from IPython.display import display, clear_output
25
+ import matplotlib.pyplot as plt
26
+ import warnings
27
  import torch
28
+
 
29
  import time
 
 
30
  from sklearn.manifold import TSNE
31
+ from copy import deepcopy
32
  import seaborn as sns
33
  import matplotlib.pylab as plt
 
 
 
 
34
  import json
35
  from pathlib import Path
36
+
37
+ import re
38
+ from collections import defaultdict
39
+
40
+ # SEED = 2222
41
+
42
+ # gen_seed = torch.Generator().manual_seed(SEED)
43
+
 
 
 
 
 
 
 
 
44
 
45
 
 
46
 
 
47
 
48
 
49
  # Random seed function
 
54
  np.random.seed(value)
55
  random.seed(value)
56
 
57
+ # Tasks for meta-learner
58
  def create_batch_of_tasks(taskset, is_shuffle = True, batch_size = 4):
59
  idxs = list(range(0,len(taskset)))
60
  if is_shuffle:
 
63
  yield [taskset[idxs[i]] for i in range(i, min(i + batch_size,len(taskset)))]
64
 
65
 
66
+ # Prepare data to be processed by the domain-learner
67
+ def prepare_data(data, batch_size, tokenizer,max_seq_length,
68
  input = 'text', output = 'label',
69
+ train_size_per_class = 5, global_datasets = False,
70
+ treat_text_fun =None):
71
  data = data.reset_index().drop("index", axis=1)
72
 
73
+ if global_datasets:
74
+ global data_train, data_test
 
75
 
76
+ # Sample task for training
77
+ data_train = data.groupby('label').sample(train_size_per_class, replace=False)
78
+ idex = data.index.isin(data_train.index)
79
 
80
+ # The test set to be labeled by the model
81
+ data_test = data[~idex].reset_index()
82
 
83
 
84
+ # Transform into datasets for the model
85
+ ## Train
86
  dataset_train = SLR_DataSet(
87
  data = data_train.sample(frac=1),
88
  input = input,
89
  output = output,
90
  tokenizer=tokenizer,
91
+ max_seq_length =max_seq_length,
92
+ treat_text =treat_text_fun)
93
 
94
+ ## Test
 
 
95
  dataset_test = SLR_DataSet(
96
  data = data_test,
97
  input = input,
98
  output = output,
99
  tokenizer=tokenizer,
100
+ max_seq_length =max_seq_length,
101
+ treat_text =treat_text_fun)
102
 
103
  # Dataloaders
104
+ ## Train
105
  data_train_loader = DataLoader(dataset_train,
106
  shuffle=True,
107
  batch_size=batch_size['train']
108
  )
109
 
110
+ ## Test
111
  if len(dataset_test) % batch_size['test'] == 1 :
112
  data_test_loader = DataLoader(dataset_test,
113
  batch_size=batch_size['test'],
 
120
  return data_train_loader, data_test_loader, data_train, data_test
121
 
122
 
123
+ # Meta trainer
124
+ def meta_train(data, model, device, Info,
125
+ print_epoch =True,
126
+ Test_resource =None,
127
+ treat_text_fun =None):
128
 
129
+ # Meta-learner model
 
 
 
 
 
130
  learner = Learner(model = model, device = device, **Info)
131
 
132
  # Testing tasks
133
  if isinstance(Test_resource, pd.DataFrame):
134
  test = MetaTask(Test_resource, num_task = 0, k_support=10, k_query=10,
135
+ training=False,treat_text =treat_text_fun, **Info)
136
 
137
 
138
  torch.clear_autocast_cache()
139
  gc.collect()
140
  torch.cuda.empty_cache()
141
 
142
+ # Meta epoch (Outer epoch)
143
  for epoch in tqdm(range(Info['meta_epoch']), desc= "Meta epoch ", ncols=80):
 
144
 
145
+ # Train tasks
146
  train = MetaTask(data,
147
  num_task = Info['num_task_train'],
148
  k_support=Info['k_qry'],
149
+ k_query=Info['k_spt'],
150
+ treat_text =treat_text_fun, **Info)
151
 
152
+ # Batch of train tasks
153
  db = create_batch_of_tasks(train, is_shuffle = True, batch_size = Info["outer_batch_size"])
154
 
155
  if print_epoch:
156
  # Outer loop bach training
157
  for step, task_batch in enumerate(db):
158
  print("\n-----------------Training Mode","Meta_epoch:", epoch ,"-----------------\n")
159
+
160
+ # meta-feedforward (outer feedforward)
161
  acc = learner(task_batch, valid_train= print_epoch)
162
  print('Step:', step, '\ttraining Acc:', acc)
163
+
164
  if isinstance(Test_resource, pd.DataFrame):
165
+ # Validating Model
166
  if ((epoch+1) % 4) + step == 0:
167
  random_seed(123)
168
  print("\n-----------------Testing Mode-----------------\n")
169
+
170
+ # Batch of test tasks
171
  db_test = create_batch_of_tasks(test, is_shuffle = False, batch_size = 1)
172
  acc_all_test = []
173
 
 
181
 
182
  # Restarting training randomly
183
  random_seed(int(time.time() % 10))
184
+
 
185
  else:
186
  for step, task_batch in enumerate(db):
187
+ # meta-feedfoward (outer-feedfoward)
188
  acc = learner(task_batch, print_epoch, valid_train= print_epoch)
189
 
190
  torch.clear_autocast_cache()
 
194
 
195
 
196
  def train_loop(data_train_loader, data_test_loader, model, device, epoch = 4, lr = 1, print_info = True, name = 'name'):
197
+ # Start the model's parameters
198
  model_meta = deepcopy(model)
199
  optimizer = Adam(model_meta.parameters(), lr=lr)
200
 
201
  model_meta.to(device)
202
  model_meta.train()
203
 
204
+ # Task epoch (Inner epoch)
205
  for i in range(0, epoch):
206
  all_loss = []
207
 
 
210
  batch = tuple(t.to(device) for t in batch)
211
  input_ids, attention_mask,q_token_type_ids, label_id = batch
212
 
213
+ # Inner feedforward
214
  loss, _, _ = model_meta(input_ids, attention_mask,q_token_type_ids, labels = label_id.squeeze())
215
 
216
+ # compute grads
217
  loss.backward()
218
 
219
+ # update parameters
220
  optimizer.step()
221
  optimizer.zero_grad()
222
 
 
227
  print("Loss: ", np.mean(all_loss))
228
 
229
 
230
+ # Test evaluation
231
  model_meta.eval()
232
  all_loss = []
233
+ all_acc = []
234
  features = []
235
  labels = []
236
  predi_logit = []
237
 
238
  with torch.no_grad():
239
+ # Test's Batch loop
240
  for inner_step, batch in enumerate(tqdm(data_test_loader,
241
  desc="Test validation | " + name,
242
  ncols=80)) :
243
  batch = tuple(t.to(device) for t in batch)
244
  input_ids, attention_mask,q_token_type_ids, label_id = batch
245
 
246
+ # Predictions
247
  _, feature, prediction = model_meta(input_ids, attention_mask,q_token_type_ids, labels = label_id.squeeze())
248
 
249
+ # Save batch's predictions
250
  prediction = prediction.detach().cpu().squeeze()
251
  label_id = label_id.detach().cpu()
252
+ labels.append(label_id.numpy().squeeze())
253
+
254
  logit = feature[1].detach().cpu()
255
+ predi_logit.append(logit.numpy())
256
 
257
+ feature_lat = feature[0].detach().cpu()
258
  features.append(feature_lat.numpy())
 
259
 
260
+ # Accuracy over the test batch
261
+ acc = fn.accuracy(prediction, label_id).item()
262
+ all_acc.append(acc)
263
  del input_ids, attention_mask, label_id, batch
264
 
265
+ if print_info:
266
+ print("acc:", np.mean(all_acc))
267
 
268
  model_meta.to('cpu')
269
  gc.collect()
 
271
 
272
  del model_meta, optimizer
273
 
274
+ return map_feature_tsne(features, labels, predi_logit)
275
 
276
+ # Process predictions and map the feature_map in tsne
277
+ def map_feature_tsne(features, labels, predi_logit):
278
+
279
  features = np.concatenate(np.array(features,dtype=object))
 
 
 
280
  features = torch.tensor(features.astype(np.float32)).detach().clone()
281
+
282
+ labels = np.concatenate(np.array(labels,dtype=object))
283
  labels = torch.tensor(labels.astype(int)).detach().clone()
284
+
285
+ logits = np.concatenate(np.array(predi_logit,dtype=object))
286
  logits = torch.tensor(logits.astype(np.float32)).detach().clone()
287
 
288
+ # Dimensionality reduction
289
  X_embedded = TSNE(n_components=2, learning_rate='auto',
290
  init='random').fit_transform(features.detach().clone())
291
 
292
  return logits.detach().clone(), X_embedded, labels.detach().clone(), features.detach().clone()
293
+
 
294
  def wss_calc(logit, labels, trsh = 0.5):
295
 
296
+ # Prediction label given the threshold
297
  predict_trash = torch.sigmoid(logit).squeeze() >= trsh
298
+
299
+ # Compute confusion matrix values
300
  CM = confusion_matrix(labels, predict_trash.to(int) )
301
  tn, fp, fne, tp = CM.ravel()
302
 
 
304
  N = (tn + fp)
305
  recall = tp/(tp+fne)
306
 
307
+ # WSS
308
+ wss = (tn + fne)/len(labels) -(1- recall)
309
 
310
+ # AWSS
311
+ awss = (tn/N - fne/P)
312
 
313
  return {
314
+ "wss": round(wss,4),
315
+ "awss": round(awss,4),
316
  "R": round(recall,4),
317
  "CM": CM
318
  }
319
 
320
 
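As a sanity check on the two screening metrics wss_calc now returns, here is a tiny hand-worked example (the confusion-matrix counts are made up): WSS is work saved over sampling at the achieved recall, and AWSS is its class-balance-adjusted variant.

    # Toy confusion matrix (assumed numbers) over 100 screened documents.
    tn, fp, fn, tp = 70, 10, 5, 15
    recall = tp / (tp + fn)                                  # 0.75
    wss  = (tn + fn) / (tn + fp + fn + tp) - (1 - recall)    # 0.75 - 0.25 = 0.50
    awss = tn / (tn + fp) - fn / (fn + tp)                   # 0.875 - 0.25 = 0.625
    print(round(wss, 2), round(awss, 3))                     # 0.5 0.625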
321
+ # Compute the metrics
322
+ def plot(logits, X_embedded, labels, threshold, show = True,
 
 
 
 
323
  namefig = "plot", make_plot = True, print_stats = True, save = True):
324
  col = pd.MultiIndex.from_tuples([
325
  ("Predict", "0"),
 
332
 
333
  predict = torch.sigmoid(logits).detach().clone()
334
 
335
+ # Roc curve
 
336
  fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())
337
 
338
+ # Given by a Recall of 95% (threshold evaluation)
339
+ ## WSS
340
+ ### Index to recall
341
  idx_wss95 = sum(tpr < 0.95)
342
+ ### threshold
343
  thresholds95 = thresholds[idx_wss95]
344
 
345
+ ### Compute the metrics
346
  wss95_info = wss_calc(logits,labels, thresholds95 )
347
  acc_wss95 = fn.accuracy(predict, labels, threshold=thresholds95)
348
  f1_wss95 = fn.f1_score(predict, labels, threshold=thresholds95)
349
 
350
 
351
+ # Given by a threshold (recall evaluation)
352
+ ### Compute the metrics
353
+ wss_info = wss_calc(logits,labels, threshold )
354
+ acc_wssR = fn.accuracy(predict, labels, threshold=threshold)
355
+ f1_wssR = fn.f1_score(predict, labels, threshold=threshold)
 
 
 
 
356
 
357
 
358
  metrics= {
 
370
  # f1
371
  "f1@95": f1_wss95.item(),
372
  "f1@R": f1_wssR.item(),
373
+ # threshold 95
374
+ "threshold@95": thresholds95
375
  }
376
 
377
+ # Print stats
 
378
  if print_stats:
379
  wss95= f"WSS@95:{wss95_info['wss']}, R: {wss95_info['R']}"
380
  wss95_adj= f"ASSWSS@95:{wss95_info['awss']}"
 
382
  print(wss95_adj)
383
  print('Acc.:', round(acc_wss95.item(), 4))
384
  print('F1-score:', round(f1_wss95.item(), 4))
385
+ print(f"threshold to wss95: {round(thresholds95, 4)}")
386
  cm = pd.DataFrame(wss95_info['CM'],
387
  index=index,
388
  columns=col)
389
 
390
  print("\nConfusion matrix:")
391
  print(cm)
392
+ print("\n---Metrics with threshold:", threshold, "----\n")
393
  wss= f"WSS@R:{wss_info['wss']}, R: {wss_info['R']}"
394
  print(wss)
395
  wss_adj= f"AWSS@R:{wss_info['awss']}"
 
404
  print(cm)
405
 
406
 
407
+ # Plots
408
 
409
  if make_plot:
410
 
411
  fig, axes = plt.subplots(1, 4, figsize=(25,10))
412
  alpha = torch.squeeze(predict).numpy()
413
 
414
+ # TSNE
 
415
  p1 = sns.scatterplot(x=X_embedded[:, 0],
416
  y=X_embedded[:, 1],
417
  hue=labels,
418
+ alpha=alpha, ax = axes[0]).set_title('Predictions-TSNE', size=20)
419
 
420
+
421
+ # WSS@95
422
  t_wss = predict >= thresholds95
423
  t_wss = t_wss.squeeze().numpy()
 
424
  p2 = sns.scatterplot(x=X_embedded[t_wss, 0],
425
  y=X_embedded[t_wss, 1],
426
  hue=labels[t_wss],
427
+ alpha=alpha[t_wss], ax = axes[1]).set_title('WSS@95', size=20)
428
 
429
+ # WSS@R
430
+ t = predict >= threshold
431
  t = t.squeeze().numpy()
 
432
  p3 = sns.scatterplot(x=X_embedded[t, 0],
433
  y=X_embedded[t, 1],
434
  hue=labels[t],
435
+ alpha=alpha[t], ax = axes[2]).set_title(f'Predictions-threshold {threshold}', size=20)
 
436
 
437
+ # ROC-Curve
438
  roc_auc = auc(fpr, tpr)
439
  lw = 2
 
440
  axes[3].plot(
441
  fpr,
442
  tpr,
443
  color="darkorange",
444
  lw=lw,
445
  label="ROC curve (area = %0.2f)" % roc_auc)
 
446
  axes[3].plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
447
  axes[3].axhline(y=0.95, color='r', linestyle='-')
448
+ # axes[3].set(xlabel="False Positive Rate", ylabel="True Positive Rate")
449
  axes[3].legend(loc="lower right")
450
+ axes[3].set_title(label= "ROC", size = 20)
451
+ axes[3].set_ylabel("True Positive Rate", fontsize = 15)
452
+ axes[3].set_xlabel("False Positive Rate", fontsize = 15)
453
+
454
 
455
  if show:
456
  plt.show()
 
460
 
461
  return metrics
462
 
463
+
464
  def auc_plot(logits,labels, color = "darkorange", label = "test"):
465
  predict = torch.sigmoid(logits).detach().clone()
466
  fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())
 
480
  plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
481
  plt.axhline(y=0.95, color='r', linestyle='-')
482
 
483
+ # Interface to evaluation
 
 
 
 
 
 
 
 
 
 
 
 
484
  class diagnosis():
485
+ def __init__(self, names, Valid_resource, batch_size_test,
486
+ model,Info, device,treat_text_fun=None,start = 0):
487
  self.names=names
488
  self.Valid_resource=Valid_resource
489
  self.batch_size_test=batch_size_test
490
  self.model=model
491
+ self.start=start
492
+ self.Info = Info
493
+ self.device = device
494
+ self.treat_text_fun = treat_text_fun
495
+
496
 
497
+ # BOX INPUT
498
  self.value_trash = widgets.FloatText(
499
  value=0.95,
500
+ description='threshold',
501
  disabled=False
502
  )
 
503
  self.valueb = widgets.IntText(
504
  value=10,
505
  description='size',
506
  disabled=False
507
  )
508
 
509
+ # Buttons
510
  self.train_b = widgets.Button(description="Train")
511
  self.next_b = widgets.Button(description="Next")
512
  self.eval_b = widgets.Button(description="Evaluation")
513
 
514
  self.hbox = widgets.HBox([self.train_b, self.valueb])
515
 
516
+ # Click buttons functions
517
  self.next_b.on_click(self.Next_button)
518
  self.train_b.on_click(self.Train_button)
519
  self.eval_b.on_click(self.Evaluation_button)
 
524
  clear_output()
525
  self.i=self.i+1
526
 
527
+ # Select the domain data
528
+ self.domain = self.names[self.i]
 
 
 
529
  self.data = self.Valid_resource[self.Valid_resource['domain'] == self.domain]
530
+
531
+ print("Name:", self.domain)
532
  print(self.data['label'].value_counts())
 
533
  display(self.hbox)
534
  display(self.next_b)
535
 
536
+
537
  # Train button
538
  def Train_button(self, y):
539
  clear_output()
540
  print(self.domain)
541
 
542
+ # Prepare data for training (domain-learner)
543
  self.data_train_loader, self.data_test_loader, self.data_train, self.data_test = prepare_data(self.data,
544
  train_size_per_class = self.valueb.value,
545
+ batch_size = {'train': self.Info['inner_batch_size'],
546
+ 'test': self.batch_size_test},
547
+ max_seq_length = self.Info['max_seq_length'],
548
+ tokenizer = self.Info['tokenizer'],
549
  input = "text",
550
+ output = "label",
551
+ treat_text_fun=self.treat_text_fun)
552
 
553
+ # Train the model and predict in the test set
554
  self.logits, self.X_embedded, self.labels, self.features = train_loop(self.data_train_loader, self.data_test_loader,
555
+ self.model, self.device,
556
+ epoch = self.Info['inner_update_step'],
557
+ lr=self.Info['inner_update_lr'],
558
  print_info=True,
559
  name = self.domain)
560
 
 
563
  display(tresh_box)
564
  display(self.next_b)
565
 
566
+
567
  # Evaluation button
568
  def Evaluation_button(self, te):
569
  clear_output()
 
572
  print(self.domain)
573
  # print("\n")
574
  print("-------Train data-------")
575
+ print(data_train['label'].value_counts())
576
  print("-------Test data-------")
577
+ print(data_test['label'].value_counts())
578
  # print("\n")
579
 
580
  display(self.next_b)
581
  display(tresh_box)
582
  display(self.hbox)
583
 
584
+ # Compute metrics
585
  metrics = plot(self.logits, self.X_embedded, self.labels,
586
+ threshold=self.Info['threshold'], show = True,
 
587
  namefig= 'test',
588
  make_plot = True,
589
  print_stats = True,
 
591
 
592
  def __call__(self):
593
  self.i= self.start-1
 
594
  clear_output()
595
  display(self.next_b)
596
 
597
 
598
 
599
 
600
+ # Simulation attempts of the domain learner
601
+ def pipeline_simulation(Valid_resource, names_to_valid, path_save,
602
+ model, Info, device, initializer_model,
603
+ treat_text_fun=None):
604
+ n_attempt = 5
605
+ batch_test = 100
606
 
607
+ # Create a directory to save information
608
+ for name in names_to_valid:
609
+ name = re.sub("\.csv", "",name)
610
+ Path(path_save + name + "/img").mkdir(parents=True, exist_ok=True)
611
 
612
+ # Dict to save roc curves
613
+ roc_stats = defaultdict(lambda: defaultdict(
614
+ lambda: defaultdict(
615
+ list
616
+ )
617
+ )
618
+ )
619
 
620
 
 
 
 
 
 
 
 
621
 
 
 
 
 
 
 
 
622
 
623
+ all_metrics = []
624
+ # Loop over a list of domains
625
+ for name in names_to_valid:
 
 
 
 
 
 
 
 
 
 
 
626
 
627
+ # Select a domain dataset
628
+ data = Valid_resource[Valid_resource['domain'] == name].reset_index().drop("index", axis=1)
 
 
 
 
 
 
 
 
 
629
 
630
+ # Attempts simulation
631
+ for attempt in range(n_attempt):
632
+ print("---"*4,"attempt", attempt, "---"*4)
633
+
634
+ # Prepare data to pass to the model
635
+ data_train_loader, data_test_loader, _ , _ = prepare_data(data,
636
+ train_size_per_class = Info['k_spt'],
637
+ batch_size = {'train': Info['inner_batch_size'],
638
+ 'test': batch_test},
639
+ max_seq_length = Info['max_seq_length'],
640
+ tokenizer = Info['tokenizer'],
641
+ input = "text",
642
+ output = "label",
643
+ treat_text_fun=treat_text_fun)
644
+
645
+ # Train the model and evaluate on the test set of the domain
646
+ logits, X_embedded, labels, features = train_loop(data_train_loader, data_test_loader,
647
+ model, device,
648
+ epoch = Info['inner_update_step'],
649
+ lr=Info['inner_update_lr'],
650
+ print_info=False,
651
+ name = name)
652
+
653
+
654
+ name_domain = re.sub("\.csv", "",name)
655
 
656
+ # Compute the metrics
657
+ metrics = plot(logits, X_embedded, labels,
658
+ threshold=Info['threshold'], show = False,
659
+ namefig= path_save + name_domain + "/img/" + str(attempt) + 'plots',
660
+ make_plot = True, print_stats = False, save = True)
661
 
662
+ # Compute the roc-curve
663
+ fpr, tpr, _ = roc_curve(labels, torch.sigmoid(logits).squeeze())
664
+
665
+ # Save the corresponding information of the domain
666
+ metrics['name'] = name_domain
667
+ metrics['layer_size'] = Info['bert_layers']
668
+ metrics['attempt'] = attempt
669
+ roc_stats[name_domain][str(Info['bert_layers'])]['fpr'].append(fpr.tolist())
670
+ roc_stats[name_domain][str(Info['bert_layers'])]['tpr'].append(tpr.tolist())
671
+ all_metrics.append(metrics)
672
+
673
+ # Save the metrics and the roc curve of the attempt
674
+ pd.DataFrame(all_metrics).to_csv(path_save+ "metrics.csv")
675
+ roc_path = path_save + "roc_stats.json"
676
+ with open(roc_path, 'w') as fp:
677
+ json.dump(roc_stats, fp)
678
+
679
+
680
+ del fpr, tpr, logits, X_embedded, labels
681
+ del features, metrics, _
682
+
683
+
684
+ # Save the information used to evaluate the validation resource
685
+ save_info = Info.copy()
686
+ save_info['model'] = initializer_model.tokenizer.name_or_path
687
+ save_info.pop("tokenizer")
688
+ save_info.pop("bert_layers")
689
+
690
+ info_path = path_save+"info.json"
691
+ with open(info_path, 'w') as fp:
692
+ json.dump(save_info, fp)
693
+
694
+
695
+ # Loading dataset statistics
696
+ def load_data_statistics(paths, names):
697
+ size = []
698
+ pos = []
699
+ neg = []
700
+ for p in paths:
701
+ data = pd.read_csv(p)
702
+ data = data.dropna()
703
+ # Dataset size
704
+ size.append(len(data))
705
+ # Number of positive labels
706
+ pos.append(data['labels'].value_counts()[1])
707
+ # Number of negative labels
708
+ neg.append(data['labels'].value_counts()[0])
709
+ del data
710
+
711
+ info_load = pd.DataFrame({
712
+ "size":size,
713
+ "pos":pos,
714
+ "neg":neg,
715
+ "names":names,
716
+ "paths": paths })
717
+ return info_load
718
+
719
+ # Loading the datasets
720
+ def load_data(train_info_load):
721
+
722
+ col = ['abstract','title', 'labels', 'domain']
723
+
724
+ data_train = pd.DataFrame(columns=col)
725
+ for p in train_info_load['paths']:
726
+ data_temp = pd.read_csv(p).loc[:, ['labels', 'title', 'abstract']]
727
+ data_temp = pd.read_csv(p).loc[:, ['labels', 'title', 'abstract']]
728
+ data_temp['domain'] = os.path.basename(p)
729
+ data_train = pd.concat([data_train, data_temp])
730
+
731
+ data_train['text'] = data_train['title'] + data_train['abstract'].replace(np.nan, '')
732
 
733
+ return( data_train \
734
+ .replace({"labels":{0:"negative", 1:'positive'}})\
735
+ .rename({"labels":"label"} , axis=1)\
736
+ .loc[ :,("text","domain","label")]
737
+ )
738
 
 
739
 
740