# -*- coding: utf-8 -*-
# @Time    : 2022/4/21 5:30 PM
# @Author  : JianingWang
# @File    : span_proto.py

"""
This code is implemented for the paper ""SpanProto: A Two-stage Span-based Prototypical Network for Few-shot Named Entity Recognition""
"""

import os
from typing import Optional, Union
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss
from transformers import MegatronBertModel, MegatronBertPreTrainedModel
from transformers.file_utils import ModelOutput
from transformers.models.bert import BertPreTrainedModel, BertModel


class RawGlobalPointer(nn.Module):
    def __init__(self, encoder, ent_type_size, inner_dim, RoPE=True):
        # encoder: RoBERTa-Large (or another BERT-style model) used as the backbone
        # inner_dim: 64
        # ent_type_size: ent_cls_num (number of entity classes)
        super().__init__()
        self.encoder = encoder
        self.ent_type_size = ent_type_size
        self.inner_dim = inner_dim
        self.hidden_size = encoder.config.hidden_size
        self.dense = nn.Linear(self.hidden_size, self.ent_type_size * self.inner_dim * 2)

        self.RoPE = RoPE

    def sinusoidal_position_embedding(self, batch_size, seq_len, output_dim):
        position_ids = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(-1)

        indices = torch.arange(0, output_dim // 2, dtype=torch.float)
        indices = torch.pow(10000, -2 * indices / output_dim)
        embeddings = position_ids * indices
        embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
        embeddings = embeddings.repeat((batch_size, *([1] * len(embeddings.shape))))
        embeddings = torch.reshape(embeddings, (batch_size, seq_len, output_dim))
        embeddings = embeddings.to(self.device)
        return embeddings

    def forward(self, input_ids, attention_mask, token_type_ids):
        self.device = input_ids.device

        context_outputs = self.encoder(input_ids, attention_mask, token_type_ids)
        # last_hidden_state:(batch_size, seq_len, hidden_size)
        last_hidden_state = context_outputs[0]

        batch_size = last_hidden_state.size()[0]
        seq_len = last_hidden_state.size()[1]

        outputs = self.dense(last_hidden_state)
        outputs = torch.split(outputs, self.inner_dim * 2, dim=-1)
        outputs = torch.stack(outputs, dim=-2)
        qw, kw = outputs[..., :self.inner_dim], outputs[..., self.inner_dim:]
        if self.RoPE:
            # pos_emb:(batch_size, seq_len, inner_dim)
            pos_emb = self.sinusoidal_position_embedding(batch_size, seq_len, self.inner_dim)
            cos_pos = pos_emb[..., None, 1::2].repeat_interleave(2, dim=-1)
            sin_pos = pos_emb[..., None, ::2].repeat_interleave(2, dim=-1)
            qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], -1)
            qw2 = qw2.reshape(qw.shape)
            qw = qw * cos_pos + qw2 * sin_pos
            kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], -1)
            kw2 = kw2.reshape(kw.shape)
            kw = kw * cos_pos + kw2 * sin_pos
        # logits:(batch_size, ent_type_size, seq_len, seq_len)
        logits = torch.einsum("bmhd,bnhd->bhmn", qw, kw)

        # padding mask
        pad_mask = attention_mask.unsqueeze(1).unsqueeze(1).expand(batch_size, self.ent_type_size, seq_len, seq_len)
        logits = logits * pad_mask - (1 - pad_mask) * 1e12

        # mask out the lower triangle (only spans with start <= end are valid)
        mask = torch.tril(torch.ones_like(logits), -1)
        logits = logits - mask * 1e12

        return logits / self.inner_dim ** 0.5
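
# A minimal usage sketch for RawGlobalPointer (not part of the original code; the
# checkpoint name and input tensors are illustrative assumptions):
#
#     from transformers import BertModel
#     encoder = BertModel.from_pretrained("bert-base-chinese")
#     gp = RawGlobalPointer(encoder, ent_type_size=9, inner_dim=64)
#     logits = gp(input_ids, attention_mask, token_type_ids)
#     # logits: [batch_size, ent_type_size, seq_len, seq_len]; entry (b, t, i, j)
#     # scores tokens i..j as an entity span of type t, with the lower triangle masked.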


class SinusoidalPositionEmbedding(nn.Module):
    """定义Sin-Cos位置Embedding
    """

    def __init__(
            self, output_dim, merge_mode="add", custom_position_ids=False):
        super(SinusoidalPositionEmbedding, self).__init__()
        self.output_dim = output_dim
        self.merge_mode = merge_mode
        self.custom_position_ids = custom_position_ids

    def forward(self, inputs):
        if self.custom_position_ids:
            inputs, position_ids = inputs
            seq_len = inputs.shape[1]
            position_ids = position_ids.type(torch.float)
        else:
            input_shape = inputs.shape
            batch_size, seq_len = input_shape[0], input_shape[1]
            position_ids = torch.arange(seq_len).type(torch.float)[None]
        indices = torch.arange(self.output_dim // 2).type(torch.float)
        indices = torch.pow(10000.0, -2 * indices / self.output_dim)
        embeddings = torch.einsum("bn,d->bnd", position_ids, indices)
        embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
        embeddings = torch.reshape(embeddings, (-1, seq_len, self.output_dim))
        if self.merge_mode == "add":
            return inputs + embeddings.to(inputs.device)
        elif self.merge_mode == "mul":
            return inputs * (embeddings + 1.0).to(inputs.device)
        elif self.merge_mode == "zero":
            return embeddings.to(inputs.device)
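
# A minimal usage sketch for SinusoidalPositionEmbedding (not part of the original
# code): with merge_mode="zero" the module returns only the position encodings,
# which is how SpanDetector uses it below.
#
#     pos_embed = SinusoidalPositionEmbedding(output_dim=64, merge_mode="zero")
#     x = torch.randn(2, 128, 64)             # [batch, seq_len, output_dim]
#     pos = pos_embed(x)                      # [1, 128, 64], sin/cos interleaved,
#                                             # broadcast over the batch dimension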


def multilabel_categorical_crossentropy(y_pred, y_true):
    y_pred = (1 - 2 * y_true) * y_pred  # -1 -> pos classes, 1 -> neg classes
    y_pred_neg = y_pred - y_true * 1e12  # mask the pred outputs of pos classes
    y_pred_pos = y_pred - (1 - y_true) * 1e12  # mask the pred outputs of neg classes
    zeros = torch.zeros_like(y_pred[..., :1])
    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
    # print(y_pred, y_true, pos_loss)
    return (neg_loss + pos_loss).mean()
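
# A minimal sanity-check sketch (not part of the original code): the loss takes raw
# logits and a {0, 1} multi-hot target of the same shape, flattened over the
# candidate axis, and returns a scalar.
#
#     y_pred = torch.randn(4, 16)                    # [batch, num_candidates]
#     y_true = torch.randint(0, 2, (4, 16)).float()  # multi-hot targets
#     loss = multilabel_categorical_crossentropy(y_pred, y_true)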


def multilabel_categorical_crossentropy2(y_pred, y_true):
    y_pred = (1 - 2 * y_true) * y_pred  # -1 -> pos classes, 1 -> neg classes
    y_pred_neg = y_pred.clone()
    y_pred_pos = y_pred.clone()
    y_pred_neg[y_true>0] -= float("inf")
    y_pred_pos[y_true<1] -= float("inf")
    # y_pred_neg = y_pred - y_true * float("inf")  # mask the pred outputs of pos classes
    # y_pred_pos = y_pred - (1 - y_true) * float("inf")  # mask the pred outputs of neg classes
    zeros = torch.zeros_like(y_pred[..., :1])
    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
    # print(y_pred, y_true, pos_loss)
    return (neg_loss + pos_loss).mean()

@dataclass
class GlobalPointerOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    topk_probs: torch.FloatTensor = None
    topk_indices: torch.IntTensor = None
    last_hidden_state: torch.FloatTensor = None


@dataclass
class SpanProtoOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    query_spans: list = None
    proto_logits: list = None
    topk_probs: torch.FloatTensor = None
    topk_indices: torch.IntTensor = None


class SpanDetector(BertPreTrainedModel):
    def __init__(self, config):
        # encoder: RoBERTa-Large (or another BERT-style model) used as the backbone
        # inner_dim: 64
        # ent_type_size: ent_cls_num (number of entity classes)
        super().__init__(config)
        self.bert = BertModel(config)
        # self.ent_type_size = config.ent_type_size
        self.ent_type_size = 1
        self.inner_dim = 64
        self.hidden_size = config.hidden_size
        self.RoPE = True

        self.dense_1 = nn.Linear(self.hidden_size, self.inner_dim * 2)
        self.dense_2 = nn.Linear(self.hidden_size, self.ent_type_size * 2)  # the original GlobalPointer uses a dense_2 of shape (inner_dim * 2, ent_type_size * 2)


    def sequence_masking(self, x, mask, value="-inf", axis=None):
        if mask is None:
            return x
        else:
            if value == "-inf":
                value = -1e12
            elif value == "inf":
                value = 1e12
            assert axis > 0, "axis must be greater than 0"
            for _ in range(axis - 1):
                mask = torch.unsqueeze(mask, 1)
            for _ in range(x.ndim - mask.ndim):
                mask = torch.unsqueeze(mask, mask.ndim)
            return x * mask + value * (1 - mask)

    def add_mask_tril(self, logits, mask):
        if mask.dtype != logits.dtype:
            mask = mask.type(logits.dtype)
        logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 2)
        logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 1)
        # mask out the lower triangle
        mask = torch.tril(torch.ones_like(logits), diagonal=-1)
        logits = logits - mask * 1e12
        return logits

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None, short_labels=None):
        # with torch.no_grad():
        context_outputs = self.bert(input_ids, attention_mask, token_type_ids)
        last_hidden_state = context_outputs.last_hidden_state # [bz, seq_len, hidden_dim]
        del context_outputs
        outputs = self.dense_1(last_hidden_state) # [bz, seq_len, 2*inner_dim]
        qw, kw = outputs[..., ::2], outputs[..., 1::2]  # split the last dimension: even positions -> qw, odd positions -> kw
        batch_size = input_ids.shape[0]
        if self.RoPE:  # whether to apply RoPE rotary position embedding
            pos = SinusoidalPositionEmbedding(self.inner_dim, "zero")(outputs)
            cos_pos = pos[..., 1::2].repeat_interleave(2, dim=-1) # e.g. [0.34, 0.90] -> [0.34, 0.34, 0.90, 0.90]
            sin_pos = pos[..., ::2].repeat_interleave(2, dim=-1)
            qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3)
            qw2 = torch.reshape(qw2, qw.shape)
            qw = qw * cos_pos + qw2 * sin_pos
            kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], 3)
            kw2 = torch.reshape(kw2, kw.shape)
            kw = kw * cos_pos + kw2 * sin_pos
        logits = torch.einsum("bmd,bnd->bmn", qw, kw) / self.inner_dim ** 0.5
        bias = torch.einsum("bnh->bhn", self.dense_2(last_hidden_state)) / 2
        logits = logits[:, None] + bias[:, ::2, None] + bias[:, 1::2, :, None]  # logits[:, None] adds the entity-type dimension
        # logit_mask = self.add_mask_tril(logits, mask=attention_mask)
        loss = None

        mask = torch.triu(attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1))  # upper-triangular mask over valid token pairs
        # mask = torch.where(mask > 0, 0.0, 1)
        if labels is not None:
            # y_pred = torch.zeros(input_ids.shape[0], self.ent_type_size, input_ids.shape[1], input_ids.shape[1], device=input_ids.device)
            # for i in range(input_ids.shape[0]):
            #     for j in range(self.ent_type_size):
            #         y_pred[i, j, labels[i, j, 0], labels[i, j, 1]] = 1
            # y_true = labels.reshape(input_ids.shape[0] * self.ent_type_size, -1)
            # y_pred = logit_mask.reshape(input_ids.shape[0] * self.ent_type_size, -1)
            # loss = multilabel_categorical_crossentropy(y_pred, y_true)
            #

            # weight = ((labels == 0).sum() / labels.sum())/5
            # loss_fct = nn.BCEWithLogitsLoss(weight=weight)
            # loss_fct = nn.BCEWithLogitsLoss(reduction="none")
            # unmask_labels = labels.view(-1)[mask.view(-1) > 0]
            # loss = loss_fct(logits.view(-1)[mask.view(-1) > 0], unmask_labels.float())
            # if unmask_labels.sum() > 0:
            #     loss = (loss[unmask_labels > 0].mean()+loss[unmask_labels < 1].mean())/2
            # else:
            #     loss = loss[unmask_labels < 1].mean()
            # y_pred = logits.view(-1)[mask.view(-1) > 0]
            # y_true = labels.view(-1)[mask.view(-1) > 0]
            # loss = multilabel_categorical_crossentropy2(y_pred, y_true)
            # y_pred = logits - torch.where(mask > 0, 0.0, float("inf")).unsqueeze(1)
            y_pred = logits - (1-mask.unsqueeze(1))*1e12
            y_true = labels.view(input_ids.shape[0] * self.ent_type_size, -1)
            y_pred = y_pred.view(input_ids.shape[0] * self.ent_type_size, -1)
            loss = multilabel_categorical_crossentropy(y_pred, y_true)

        with torch.no_grad():
            prob = torch.sigmoid(logits) * mask.unsqueeze(1)
            topk = torch.topk(prob.view(batch_size, self.ent_type_size, -1), 50, dim=-1)


        return GlobalPointerOutput(
            loss=loss,
            topk_probs=topk.values,
            topk_indices=topk.indices,
            last_hidden_state=last_hidden_state
        )
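
# A minimal usage sketch for the span detector (not part of the original code; the
# checkpoint name and input tensors are illustrative assumptions):
#
#     detector = SpanDetector.from_pretrained("bert-base-uncased")
#     out = detector(input_ids, attention_mask, token_type_ids)
#     # out.topk_probs / out.topk_indices: [batch, 1, 50] scores and flattened
#     # indices over the seq_len x seq_len span grid (decoded in get_topk_spans);
#     # out.last_hidden_state is reused by SpanProto for the prototype module.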


class SpanProto(nn.Module):
    def __init__(self, config):
        """
        word_encoder: Sentence encoder

        You need to set self.cost as your own loss function.
        """
        nn.Module.__init__(self)
        self.config = config
        self.output_dir = "./outputs"
        # self.predict_dir = self.predict_result_path(self.output_dir)
        self.drop = nn.Dropout()
        self.global_span_detector = SpanDetector(config=self.config) # global span detector
        self.projector = nn.Sequential( # projector
            nn.Linear(self.config.hidden_size, self.config.hidden_size),
            nn.Sigmoid(),
            # nn.LayerNorm(2)
        )
        self.tag_embeddings = nn.Embedding(2, self.config.hidden_size) # tag for labeled / unlabeled span set
        # self.tag_mlp = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        self.max_length = 64
        self.margin_distance = 6.0
        self.global_step = 0

    def predict_result_path(self, path=None):
        if path is None:
            predict_dir = os.path.join(
                self.output_dir, "{}-{}-{}".format(self.mode, self.num_class, self.num_example), "predict"
            )
        else:
            predict_dir = os.path.join(
                path, "predict"
            )
        # if os.path.exists(predict_dir):
        #     os.rmdir(predict_dir)  # remove stale prediction results
        if not os.path.exists(predict_dir):  # create a fresh directory if it does not exist
            os.makedirs(predict_dir)
        return predict_dir


    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
        config = kwargs.pop("config", None)
        model = SpanProto(config=config)
        # load the pretrained (BERT) weights into the span detector
        model.global_span_detector = SpanDetector.from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            **kwargs
        )
        # the remaining modules keep their default initialization
        return model

    # @classmethod
    # def resize_token_embeddings(self, new_num_tokens: Optional[int] = None):
    #     self.global_span_detector.resize_token_embeddings(new_num_tokens)

    def __dist__(self, x, y, dim, use_dot=False):
        # x: [1, class_num, hidden_dim], y: [span_num, 1, hidden_dim]
        # x - y: [span_num, class_num, hidden_dim]
        # (x - y)^2.sum(2): [span_num, class_num]
        if use_dot:
            return (x * y).sum(dim)
        else:
            return -(torch.pow(x - y, 2)).sum(dim)

    def __get_proto__(self, support_emb: torch.Tensor, support_span: list, support_span_type: list, use_tag=False):
        """
        support_emb: [n", seq_len, dim]
        support_span: [n", m, 2] e.g. [[[3, 6], [12, 13]], [[1, 3]], ...]
        support_span_type: [n", m] e.g. [[2, 1], [5], ...]
        """
        prototype = list()  # prototype of each class
        all_span_embs = list()  # embedding of every span (kept for visualization)
        all_span_tags = list()
        # iterate over the classes
        for tag in range(self.num_class):
            # tag_id = torch.Tensor([1 if tag == self.num_class else 0]).long().cuda()
            # tag_embeddings = self.tag_embeddings(tag_id).view(-1)
            tag_prototype = list() # [k, dim]
            # iterate over every sentence in the current episode
            for emb, span, type in zip(support_emb, support_span, support_span_type):
                # emb: [seq_len, dim], span: [m, 2], type: [m]
                span = torch.Tensor(span).long().cuda() # e.g. [[3, 4], [9, 11]]
                type = torch.Tensor(type).long().cuda() # e.g. [1, 4]
                # collect the spans of the current sentence that belong to class tag
                try:
                    tag_span = span[type == tag] # e.g. span==[[3, 4]], tag==1

                    # build the span embedding of each retrieved span
                    for (s, e) in tag_span:
                        # tag_emb = torch.cat([emb[s], emb[e - 1]]) # [2*dim]
                        tag_emb = emb[s] + emb[e] # [dim]
                        # if use_tag:  # add a tag embedding marking labeled (0) vs. unlabeled spans
                        #     tag_emb = tag_emb + tag_embeddings
                        tag_prototype.append(tag_emb)
                        all_span_embs.append(tag_emb)
                        all_span_tags.append(tag)
                except:
                    # the class has no span in this sentence, fall back to a random vector
                    tag_prototype.append(torch.randn(support_emb.shape[-1]).cuda())
                    # assert 1 > 2
            try:
                prototype.append(torch.mean(torch.stack(tag_prototype), dim=0))
            except:
                # print("the class {} has no span".format(tag))
                prototype.append(torch.randn(support_emb.shape[-1]).cuda())
                # assert 1 > 2
        all_span_embs = torch.stack(all_span_embs).detach().cpu().numpy().tolist()

        return torch.stack(prototype), all_span_embs, all_span_tags  # prototypes: [num_class, dim]


    def __batch_dist__(self, prototype: torch.Tensor, query_emb: torch.Tensor, query_spans: list, query_span_type: Union[list, None]):
        """
        Classify each query span by its distance to every prototype.
        """
        # first build the representation of every span in every query sentence
        all_logits = list()  # predicted logits of all spans in each sentence
        all_types = list()
        visual_all_types, visual_all_embs = list(), list()  # kept for visualization
        # num = 0
        for emb, span in zip(query_emb, query_spans):  # iterate over the sentences
            # assert len(span) == len(query_span_type[num]), "span={}\ntype{}".format(span, query_span_type[num])
            # print("len(span)={}, len(type)= {}".format(len(span), len(query_span_type[num])))
            span_emb = list()  # embeddings of all spans in the current sentence [m', dim]
            try:
                for (s, e) in span:  # iterate over the spans
                    tag_emb = emb[s] + emb[e]  # [dim]
                    span_emb.append(tag_emb)
            except:
                span_emb = []
            if len(span_emb) != 0:
                span_emb = torch.stack(span_emb) # [span_num, dim]
                # distance between every span and every prototype
                logits = self.__dist__(prototype.unsqueeze(0), span_emb.unsqueeze(1), 2) # [span_num, num_class]
                # pred_types = torch.argmax(logits, -1).detach().cpu().numpy().tolist()
                with torch.no_grad():
                    pred_dist, pred_types = torch.max(logits, -1)  # nearest prototype per span (negative squared distance)
                    pred_dist = torch.pow(-1 * pred_dist, 0.5)  # Euclidean distance to the nearest prototype
                    # print("pred_dist=", pred_dist)
                    # spans farther than the margin from every prototype are treated as unlabeled and get the extra class id
                    pred_types[pred_dist > self.margin_distance] = self.num_class
                    pred_types = pred_types.detach().cpu().numpy().tolist()
                # # alternative: use a softmax probability threshold instead of the distance margin
                # with torch.no_grad():
                #     prob = torch.softmax(logits, -1)
                #     pred_proba, pred_types = torch.max(logits, -1)  # most likely class and its score per span
                #     pred_types[pred_proba <= 0.6] = self.num_class  # low-confidence spans fall back to the extra class
                #     pred_types = pred_types.detach().cpu().numpy().tolist()

                all_logits.append(logits)
                all_types.append(pred_types)
                visual_all_types.extend(pred_types)
                visual_all_embs.extend(span_emb.detach().cpu().numpy().tolist())
            else:
                all_logits.append([])
                all_types.append([])
            # num += 1

        if query_span_type is not None:
            # query_span_type: [n", m]
            try:
                all_type = torch.Tensor([type for types in query_span_type for type in types]).long().cuda() # [span_num]
                loss = nn.CrossEntropyLoss()(torch.cat(all_logits, 0), all_type)
            except:
                all_logit, all_type = list(), list()
                for logits, types in zip(all_logits, query_span_type):
                    if len(logits) != 0 and len(types) != 0 and len(logits) == len(types):
                        # print("len(logits)=", len(logits))
                        # print("len(types)=", len(types))
                        # print("logits=", logits)
                        all_logit.append(logits)
                        all_type.extend(types)
                # print("all_logit=", all_logit)
                if len(all_logit) != 0:
                    all_logit = torch.cat(all_logit, 0)
                    all_type = torch.Tensor(all_type).long().cuda()
                    # print("len(all_logits)=", len(all_logits))
                    # print("len(query_span_type)=", len(query_span_type))

                    # print("types.shape=", torch.Tensor(all_type).shape)

                    # min_len = min(len(all_type), len(all_type))
                    # all_logit, all_type = all_logit[: min_len], all_type[: min_len]
                    # print("logits.shape=", all_logit.shape)
                    # print("all_type=", all_type)
                    loss = nn.CrossEntropyLoss()(all_logit, all_type)
                else:
                    loss = 0.


        else:
            loss = None
        all_logits = [i.detach().cpu().numpy().tolist() for i in all_logits if len(i) != 0]
        return loss, all_logits, all_types, visual_all_types, visual_all_embs


    def __batch_margin__(self, prototype: torch.Tensor, query_emb: torch.Tensor, query_unlabeled_spans: list,
                         query_labeled_spans: list, query_span_type: list):
        """
        Push unlabeled spans away from every prototype (and, in the commented-out variant,
        pull labeled spans towards the prototype of their own class).
        """

        # prototype: [num_class, dim], negative: [span_num, dim]
        # compute the distance of every unlabeled span to every prototype; each distance should exceed the margin
        def distance(input1, input2, p=2, eps=1e-6):
            # Compute the distance (p-norm)
            norm = torch.pow(torch.abs((input1 - input2 + eps)), p)
            pnorm = torch.pow(torch.sum(norm, -1), 1.0 / p)
            return pnorm

        unlabeled_span_emb, labeled_span_emb, labeled_span_type = list(), list(), list()
        for emb, span in zip(query_emb, query_unlabeled_spans):  # iterate over the sentences
            # collect the embedding of every unlabeled span in the current sentence [m', dim]
            for (s, e) in span:  # iterate over the spans
                tag_emb = emb[s] + emb[e]  # [dim]
                unlabeled_span_emb.append(tag_emb)

        # for emb, span, type in zip(query_emb, query_labeled_spans, query_span_type):  # iterate over the sentences
        #     # embeddings of all labeled spans in the current sentence [m', dim]
        #     for (s, e) in span:  # iterate over the spans
        #         tag_emb = emb[s] + emb[e]  # [dim]
        #         labeled_span_emb.append(tag_emb)
        #     labeled_span_type.extend(type)

        try:
            unlabeled_span_emb = torch.stack(unlabeled_span_emb) # [span_num, dim]
            # labeled_span_emb = torch.stack(labeled_span_emb) # [span_num, dim]
            # labeled_span_type = torch.stack(labeled_span_type) # [span_num]
        except:
            return 0.

        unlabeled_dist = distance(prototype.unsqueeze(0), unlabeled_span_emb.unsqueeze(1)) # [span_num, num_class]
        # labeled_dist = distance(prototype.unsqueeze(0), labeled_span_emb.unsqueeze(1)) # [span_num, num_class]
        # distance of each labeled span to the prototype of its ground-truth class
        # labeled_type_dist = torch.gather(labeled_dist, -1, labeled_span_type.unsqueeze(1)) # [span_num, 1]
        # print(dist)
        unlabeled_output = torch.maximum(torch.zeros_like(unlabeled_dist), self.margin_distance - unlabeled_dist)
        # labeled_output = torch.maximum(torch.zeros_like(labeled_type_dist), labeled_type_dist)
        # return torch.mean(unlabeled_output) + torch.mean(labeled_output)
        return torch.mean(unlabeled_output)
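
    # Worked example for the margin loss above (not part of the original code): with
    # margin_distance = 6.0, an unlabeled span at L2 distance 2.5 from some prototype
    # contributes max(0, 6.0 - 2.5) = 3.5 for that prototype, while one at distance
    # 7.0 contributes 0; averaging over all (span, prototype) pairs pushes unlabeled
    # spans outside the margin of every class prototype.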


    def forward(
            self,
            episode_ids,
            support, query,
            num_class,
            num_example,
            mode=None,
            short_labels=None,
            stage: str = "train",
            path: str = None
    ):
        """
        episode_ids: indices of the episodes in this batch (list).
        support: inputs of the support set.
        query: inputs of the query set.
        num_class: N in the N-way K-shot setting (number of classes).
        num_example: K in the N-way K-shot setting (instances per class in the support set).
        return: SpanProtoOutput
        """
        if stage.startswith("train"):
            self.global_step += 1
        self.num_class = num_class  # the N in N-way K-shot
        self.num_example = num_example  # the K in N-way K-shot
        # print("num_class=", num_class)
        self.mode = mode # FewNERD mode=inter/intra
        self.max_length = support["input_ids"].shape[1]
        support_inputs, support_attention_masks, support_type_ids = \
            support["input_ids"], support["attention_mask"], support["token_type_ids"] # torch, [n, seq_len]
        query_inputs, query_attention_masks, query_type_ids = \
            query["input_ids"], query["attention_mask"], query["token_type_ids"] # torch, [n, seq_len]
        support_labels = support["labels"] # torch,
        query_labels = query["labels"] # torch,
        # global span detector: obtain all mention span and loss
        support_detector_outputs = self.global_span_detector(
            support_inputs, support_attention_masks, support_type_ids, support_labels, short_labels=short_labels
        )
        query_detector_outputs = self.global_span_detector(
            query_inputs, query_attention_masks, query_type_ids, query_labels, short_labels=short_labels
        )
        device_id = support_inputs.device.index

        # if stage == "train_span":
        if self.global_step <= 500 and stage == "train":
            # only train span detector
            return SpanProtoOutput(
                loss=support_detector_outputs.loss,
                topk_probs=query_detector_outputs.topk_probs,
                topk_indices=query_detector_outputs.topk_indices,
            )
        # obtain labeled span from the support set
        support_labeled_spans = support["labeled_spans"] # all labeled span, list, [n, m, 2], n sentence, m entity span, 2 (start / end)
        support_labeled_types = support["labeled_types"] # all labeled ent type id, list, [n, m],
        query_labeled_spans = query["labeled_spans"]  # all labeled span, list, [n, m, 2], n sentence, m entity span, 2 (start / end)
        query_labeled_types = query["labeled_types"]  # all labeled ent type id, list, [n, m],

        # for span, type in zip(query_labeled_spans, query_labeled_types):  # iterate over the sentences
        #     assert len(span) == len(type), "span={}\ntype{}".format(span, type)

        # obtain unlabeled span from the support set
        # according to the detector, we can obtain multiple unlabeled span, which generated by the detector
        # but not labeled in n-way k-shot episode
        # support_predict_spans = self.get_topk_spans( #
        #     support_detector_outputs.topk_probs,
        #     support_detector_outputs.topk_indices,
        #     support["input_ids"]
        # ) # [n, m, 2]
        # print("predicted support span num={}".format([len(i) for i in support_predict_spans]))
        # e.g. one entry per sentence giving its number of spans, [5, 50, 4, 43, 5, 5, 1, 50, 2, 5, 6, 4, 50, 8, 12, 28, 17]

        # we can also obtain all predicted span from the query set
        query_predict_spans = self.get_topk_spans(  #
            query_detector_outputs.topk_probs,
            query_detector_outputs.topk_indices,
            query["input_ids"],
            threshold=0.9 if stage.startswith("train") else 0.95,
            is_query=True
        )  # [n, m, 2]
        # print("predicted query span num={}".format([len(i) for i in query_predict_spans]))


        # merge predicted span and labeled span, and generate other class for unlabeled span set
        # support_all_spans, support_span_types = self.merge_span(
        #     labeled_spans=support_labeled_spans,
        #     labeled_types=support_labeled_types,
        #     predict_spans=support_predict_spans,
        #     stage=stage
        # )  # [n, m, 2]: n sentences, each with several spans
        # print("merged support span num={}".format([len(i) for i in support_all_spans]))


        if stage.startswith("train"):
            # at training time we need to know which of the detector's spans are labeled and
            # which are unlabeled, and separate out all unlabeled spans
            query_unlabeled_spans = self.split_span(  # split off the unlabeled spans, used for the margin loss below
                labeled_spans=query_labeled_spans,
                labeled_types=query_labeled_types,
                predict_spans=query_predict_spans,
                stage=stage
            )  # [n, m, 2]: n sentences, each with several spans
            # print("merged query span num={}".format([len(i) for i in query_all_spans]))
            query_all_spans = query_labeled_spans
            query_span_types = query_labeled_types

        else:
            # at inference time, simply merge the labeled and predicted spans
            query_unlabeled_spans = None
            query_all_spans, _ = self.merge_span(
                labeled_spans=query_labeled_spans,
                labeled_types=query_labeled_types,
                predict_spans=query_predict_spans,
                stage=stage
            )  # [n, m, 2]: n sentences, each with several spans
            # on dev/test the query spans come entirely from the detector
            # query_all_spans = query_predict_spans
            query_span_types = None
            # for inspecting the detector's predictions on dev/test queries:
            # for query_label, query_pred in zip(query_labeled_spans, query_predict_spans):
            #     print(" ==== ")
            #     print("query_labeled_spans=", query_label)
            #     print("query_predict_spans=", query_pred)

        # obtain representations of each token
        support_emb, query_emb = support_detector_outputs.last_hidden_state, \
                                 query_detector_outputs.last_hidden_state # [n, seq_len, dim]
        support_emb, query_emb = self.projector(support_emb), self.projector(query_emb) # [n, seq_len, dim]

        # all_query_spans = list()  # predicted spans of all sentences per episode
        # all_proto_logits = list()  # predicted entity type of each predicted span per episode
        batch_result = dict()
        proto_losses = list()  # loss of each episode
        # batch_visual = list()  # span representations per episode, for visualization
        current_support_num = 0
        current_query_num = 0
        typing_loss = None
        # iterate over the episodes
        for i, sent_support_num in enumerate(support["sentence_num"]):
            sent_query_num = query["sentence_num"][i]
            id_ = episode_ids[i]  # index of the current episode

            # for the support set, prototypes are built from labeled spans only
            # locate one episode and obtain the span prototype
            # [n", seq_len, dim] n" sentence in one episode
            # support_proto: [num_class, dim]
            support_proto, all_span_embs, all_span_tags = self.__get_proto__(
                support_emb[current_support_num: current_support_num + sent_support_num], # [n", seq_len, dim]
                support_labeled_spans[current_support_num: current_support_num + sent_support_num],  # [n", m]
                support_labeled_types[current_support_num: current_support_num + sent_support_num],  # [n", m]
            )


            # for each labeled span in the query set, apply standard prototypical learning
            # for each query, we first obtain corresponding span, and then calculate distance between it and each prototype
            # # [n", seq_len, dim] n" sentence in one episode
            proto_loss, proto_logits, all_types, visual_all_types, visual_all_embs = self.__batch_dist__(
                support_proto,
                query_emb[current_query_num: current_query_num + sent_query_num], # [n", seq_len, dim]
                query_all_spans[current_query_num: current_query_num + sent_query_num],  # [n", m]
                query_span_types[current_query_num: current_query_num + sent_query_num] if query_span_types else None,  # [n", m]
            )

            visual_data = {
                "data": all_span_embs + visual_all_embs,
                "target": all_span_tags + visual_all_types,
            }

            # for each unlabeled query span, push it away from all prototypes via the margin loss
            if stage.startswith("train"):

                margin_loss = self.__batch_margin__(
                    support_proto,
                    query_emb[current_query_num: current_query_num + sent_query_num],  # [n", seq_len, dim]
                    query_unlabeled_spans[current_query_num: current_query_num + sent_query_num],  # [n", span_num]
                    query_all_spans[current_query_num: current_query_num + sent_query_num],
                    query_span_types[current_query_num: current_query_num + sent_query_num],
                )

                proto_losses.append(proto_loss + margin_loss)

            batch_result[id_] = {
                "spans": query_all_spans[current_query_num: current_query_num + sent_query_num],
                "types": all_types,
                "visualization": visual_data
            }

            current_query_num += sent_query_num
            current_support_num += sent_support_num
        # proto_logits = torch.stack(proto_logits)
        if stage.startswith("train"):
            typing_loss = torch.mean(torch.stack(proto_losses), dim=-1)


        if not stage.startswith("train"):
            self.__save_evaluate_predicted_result__(batch_result, device_id=device_id, stage=stage, path=path)

        # return SpanProtoOutput(
        #         loss=((support_detector_outputs.loss + query_detector_outputs.loss) / 2.0 + typing_loss)
        #         if stage.startswith("train") else (support_detector_outputs.loss + query_detector_outputs.loss),
        #     )
        return SpanProtoOutput(
            loss=(support_detector_outputs.loss + typing_loss)
            if stage.startswith("train") else query_detector_outputs.loss,
        )  # whatever the outermost container of the returned logits is (list or tuple), the innermost
        # elements must be tensors, otherwise HuggingFace's nested_detach will raise an error

    def __save_evaluate_predicted_result__(self, new_result: dict, device_id: int = 0, stage="dev", path=None):
        """
        Called inside forward to save the predicted spans and span types of each batch.
        new_result / result: {
            "(id)": { # id-th episode query
                "spans": [[[1, 4], [6, 7], xxx], ... ] # [sent_num, span_num, 2]
                "types": [[2, 0, xxx], ...] # [sent_num, span_num]
            },
            xxx
        }
        """
        # load the results already predicted for the current task
        self.predict_dir = self.predict_result_path(path)
        npy_file_name = os.path.join(self.predict_dir, "{}_predictions_{}.npy".format(stage, device_id))
        result = dict()
        if os.path.exists(npy_file_name):
            result = np.load(npy_file_name, allow_pickle=True)[()]
        # merge the new batch into the stored results
        for episode_id, query_res in new_result.items():
            result[episode_id] = query_res
        # save back to disk
        np.save(npy_file_name, result, allow_pickle=True)


    def get_topk_spans(self, probs, indices, input_ids, threshold=0.60, low_threshold=0.1, is_query=False):
        """
        probs: [n, m]
        indices: [n, m]
        input_texts: [n, seq_len]
        is_query: if true, each sentence must recall at least one span
        """
        probs = probs.squeeze(1).detach().cpu()  # top-k probabilities [n, m], already sorted in descending order
        indices = indices.squeeze(1).detach().cpu()  # top-k indices [n, m], aligned with probs
        input_ids = input_ids.detach().cpu()
        # print("probs=", probs) # [n, m]
        # print("indices=", indices) # [n, m]
        predict_span = list()
        if is_query:
            low_threshold = 0.0
        for prob, index, text in zip(probs, indices, input_ids):  # iterate over the sentences; each has several predicted spans with probabilities
            threshold_ = threshold
            index_ids = torch.Tensor([i for i in range(len(index))]).long()
            span = set()
            # TODO: 1. tune the threshold  2. handle overlapping predicted entities
            entity_index = index[prob >= low_threshold]
            index_ids = index_ids[prob >= low_threshold]
            while threshold_ >= low_threshold:  # lower the threshold dynamically so the number of recalled spans stays roughly balanced across sentences (a single fixed threshold recalls very uneven span counts)
                for ei, entity in enumerate(entity_index):
                    p = prob[index_ids[ei]]
                    if p < threshold_:  # candidates are sorted by score, so once one drops below the threshold all later ones do too
                        break
                    # convert the flat 1D index into a 2D (start, end) index
                    start_end = np.unravel_index(entity, (self.max_length, self.max_length))
                    # print("self.max_length=", self.max_length)
                    s, e = start_end[0], start_end[1]
                    ans = text[s: e]
                    # if ans not in answer:
                    #     answer.append(ans)
                    #     topk_answer_dict[ans] = {"prob": float(prob[index_ids[ei]]), "pos": [(s, e)]}
                    span.add((s, e))
                # if too few spans were recalled, lower the threshold and filter again
                if len(span) <= 3:
                    threshold_ -= 0.05
                else:
                    break
            if len(span) == 0:
                # if no span was recalled at all, fall back to [CLS], i.e. [0, 0] (analogous to "unanswerable" in MRC)
                span = [[0, 0]]
            span = [list(i) for i in list(span)]
            # print("prob=", prob) e.g. [0.96, 0.85, 0.04, 0.00, ...]
            # print("span=", span) e.g. [[20, 23], [11, 14]]
            predict_span.append(span)
        return predict_span
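
    # Index decoding sketch for get_topk_spans above (not part of the original code):
    # each top-k index is a flattened position in the max_length x max_length span grid,
    # so with self.max_length = 64 an index of 215 decodes via
    # np.unravel_index(215, (64, 64)) to (start, end) = (3, 23), since 3 * 64 + 23 = 215.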


    def split_span(self, labeled_spans: list, labeled_types: list, predict_spans: list, stage: str = "train"):
        """
        # 对detector预测的所有span,划分出哪些是labeled span,哪些是unlabeled span

        """
        def check_similar_span(span1, span2):
            """
            检测两个span是否接近,例如[12, 16], [11, 16], [13, 15], [12, 17]是接近的
            """
            # 考虑一个特殊情况,例如 [12, 12], [13, 13]
            if len(span1) == 0 or len(span2) == 0:
                return False
            if span1[0] == span1[1] and span2[0] == span2[1] and abs(span1[0] - span2[0]) == 1:
                return False
            if abs(span1[0] - span2[0]) <= 1 and abs(span1[1] - span2[1]) <= 1: # 两个区间的起点和终点分别相差1以内
                return True
            return False

        all_spans, span_types = list(), list() # [n, m]
        num = 0
        unlabeled_spans = list()
        for labeled_span, labeled_type, predict_span in zip(labeled_spans, labeled_types, predict_spans):
            # for each sentence, decide which of the detector's spans are labeled and which are unlabeled
            unlabeled_span = list()
            # if len(all_span) != len(span_type):
            #     length = min(len(all_span), len(span_type))
            #     all_span, span_type = all_span[: length], span_type[: length]
            for span in predict_span:  # iterate over the predicted spans
                if span not in labeled_span:  # a span that is not in the labeled set is unlabeled
                    # the global pointer sometimes predicts fuzzy boundaries, so spans whose
                    # boundaries are very close to a labeled span are discarded
                    is_remove = False
                    for span_x in labeled_span:  # iterate over the labeled spans
                        is_remove = check_similar_span(span_x, span)  # drop the span if it is close to an existing one
                        if is_remove is True:
                            break
                    if is_remove is True:
                        continue
                    unlabeled_span.append(span)
            # if self.global_step % 1000 == 0:
            #     print(" === ")
            #     print("labeled_span=", labeled_span) # [[1, 3], [12, 14], [25, 25], [7, 7]]
            #     print("predict_span=", predict_span) # [[25, 25], [1, 3], [12, 14], [7, 7]]
            # if len(unlabeled_span) == 0 and stage.startswith("train"):
            #     # if the sentence has no unlabeled span, negatively sample one so the unlabeled set is not empty
            #     # print("unlabeled span is empty, so we randomly select one span as the unlabeled span")
            #     # all_span.append([0, 0])
            #     # span_type.append(self.num_class)
            #     while True:
            #         random_span = np.random.randint(0, 32, 2).tolist()
            #         if abs(random_span[0] - random_span[1]) > 10:
            #             continue
            #         random_span = [random_span[1], random_span[0]] if random_span[0] > random_span[1] else random_span
            #         if random_span in labeled_span or random_span in unlabeled_span:
            #             continue
            #         unlabeled_span.append(random_span)
            #         break
            num += len(unlabeled_span)
            unlabeled_spans.append(unlabeled_span)
        # print("num=", num)
        return unlabeled_spans


    def merge_span(self, labeled_spans: list, labeled_types: list, predict_spans: list, stage: str = "train"):
        """
        Merge the labeled spans with the spans predicted by the detector; predicted spans that are
        not among the labels are added with the extra "unlabeled" class id (self.num_class).
        """
        def check_similar_span(span1, span2):
            """
            Check whether two spans are close, e.g. [12, 16], [11, 16], [13, 15] and [12, 17] are all close.
            """
            # special case: adjacent single-token spans such as [12, 12] and [13, 13] are not considered close
            if len(span1) == 0 or len(span2) == 0:
                return False
            if span1[0] == span1[1] and span2[0] == span2[1] and abs(span1[0] - span2[0]) == 1:
                return False
            if abs(span1[0] - span2[0]) <= 1 and abs(span1[1] - span2[1]) <= 1:  # both endpoints differ by at most 1
                return True
            return False

        all_spans, span_types = list(), list() # [n, m]
        for labeled_span, labeled_type, predict_span in zip(labeled_spans, labeled_types, predict_spans):
            # iterate over the sentences and merge their spans
            unlabeled_num = 0
            all_span, span_type = labeled_span, labeled_type  # start from the labeled spans
            if len(all_span) != len(span_type):
                length = min(len(all_span), len(span_type))
                all_span, span_type = all_span[: length], span_type[: length]
            for span in predict_span:  # iterate over the predicted spans
                if span not in all_span:  # a span that is not in the labeled set is unlabeled
                    # the global pointer sometimes predicts fuzzy boundaries, so spans whose
                    # boundaries are very close to an already merged span are discarded
                    is_remove = False
                    for span_x in all_span:  # iterate over the spans merged so far
                        is_remove = check_similar_span(span_x, span)  # drop the span if it is close to an existing one
                        if is_remove is True:
                            break
                    if is_remove is True:
                        continue
                    all_span.append(span)
                    span_type.append(self.num_class)  # e.g. in a 5-way task labels 0..4 exist, so the new label is 5
                    unlabeled_num += 1
            # if self.global_step % 1000 == 0:
            #     print(" === ")
            #     print("labeled_span=", labeled_span) # [[1, 3], [12, 14], [25, 25], [7, 7]]
            #     print("predict_span=", predict_span) # [[25, 25], [1, 3], [12, 14], [7, 7]]
            if unlabeled_num == 0 and stage.startswith("train"):
                # if the sentence has no unlabeled span, negatively sample one so the unlabeled set is not empty
                # print("unlabeled span is empty, so we randomly select one span as the unlabeled span")
                # all_span.append([0, 0])
                # span_type.append(self.num_class)
                while True:
                    random_span = np.random.randint(0, 32, 2).tolist()
                    if abs(random_span[0] - random_span[1]) > 10:
                        continue
                    random_span = [random_span[1], random_span[0]] if random_span[0] > random_span[1] else random_span
                    if random_span in all_span:
                        continue
                    all_span.append(random_span)
                    span_type.append(self.num_class)
                    break

            # if len(all_span) != len(span_type):
            #     all_span = [[0, 0]]
            #     span_type = [self.num_class]

            all_spans.append(all_span)
            span_types.append(span_type)

        return all_spans, span_types