ypesk commited on
Commit
821e9d2
·
verified ·
1 Parent(s): b11f360

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +6 -6
tasks/text.py CHANGED
@@ -37,7 +37,7 @@ class ConspiracyClassification768(
37
  PyTorchModelHubMixin,
38
  # optionally, you can add metadata which gets pushed to the model card
39
  ):
40
- def __init__(self, num_classes):
41
  super().__init__()
42
  self.h1 = nn.Linear(768, 100)
43
  self.h2 = nn.Linear(100, 100)
@@ -70,7 +70,7 @@ class CTBERT(
70
  PyTorchModelHubMixin,
71
  # optionally, you can add metadata which gets pushed to the model card
72
  ):
73
- def __init__(self, num_classes):
74
  super().__init__()
75
  self.bert = BertForPreTraining.from_pretrained('digitalepidemiologylab/covid-twitter-bert-v2')
76
  self.bert.cls.seq_relationship = nn.Linear(1024, num_classes)
@@ -86,7 +86,7 @@ class conspiracyModelBase(
86
  PyTorchModelHubMixin,
87
  # optionally, you can add metadata which gets pushed to the model card
88
  ):
89
- def __init__(self, num_classes):
90
  super().__init__()
91
  self.n_classes = num_classes
92
  self.bert = ModernBertForSequenceClassification.from_pretrained('answerdotai/ModernBERT-base', num_labels=num_classes)
@@ -101,7 +101,7 @@ class conspiracyModelLarge(
101
  PyTorchModelHubMixin,
102
  # optionally, you can add metadata which gets pushed to the model card
103
  ):
104
- def __init__(self, num_classes):
105
  super().__init__()
106
  self.n_classes = num_classes
107
  self.bert = ModernBertForSequenceClassification.from_pretrained('answerdotai/ModernBERT-large', num_labels=num_classes)
@@ -116,7 +116,7 @@ class gteModelLarge(
116
  PyTorchModelHubMixin,
117
  # optionally, you can add metadata which gets pushed to the model card
118
  ):
119
- def __init__(self, num_classes):
120
  super().__init__()
121
  self.n_classes = num_classes
122
  self.gte = AutoModel.from_pretrained('Alibaba-NLP/gte-large-en-v1.5', trust_remote_code=True)
@@ -133,7 +133,7 @@ class gteModel(
133
  PyTorchModelHubMixin,
134
  # optionally, you can add metadata which gets pushed to the model card
135
  ):
136
- def __init__(self, num_classes):
137
  super().__init__()
138
  self.n_classes = num_classes
139
  self.gte = AutoModel.from_pretrained('Alibaba-NLP/gte-base-en-v1.5', trust_remote_code=True)
 
37
  PyTorchModelHubMixin,
38
  # optionally, you can add metadata which gets pushed to the model card
39
  ):
40
+ def __init__(self, num_classes=8):
41
  super().__init__()
42
  self.h1 = nn.Linear(768, 100)
43
  self.h2 = nn.Linear(100, 100)
 
70
  PyTorchModelHubMixin,
71
  # optionally, you can add metadata which gets pushed to the model card
72
  ):
73
+ def __init__(self, num_classes=8):
74
  super().__init__()
75
  self.bert = BertForPreTraining.from_pretrained('digitalepidemiologylab/covid-twitter-bert-v2')
76
  self.bert.cls.seq_relationship = nn.Linear(1024, num_classes)
 
86
  PyTorchModelHubMixin,
87
  # optionally, you can add metadata which gets pushed to the model card
88
  ):
89
+ def __init__(self, num_classes=8):
90
  super().__init__()
91
  self.n_classes = num_classes
92
  self.bert = ModernBertForSequenceClassification.from_pretrained('answerdotai/ModernBERT-base', num_labels=num_classes)
 
101
  PyTorchModelHubMixin,
102
  # optionally, you can add metadata which gets pushed to the model card
103
  ):
104
+ def __init__(self, num_classes=8):
105
  super().__init__()
106
  self.n_classes = num_classes
107
  self.bert = ModernBertForSequenceClassification.from_pretrained('answerdotai/ModernBERT-large', num_labels=num_classes)
 
116
  PyTorchModelHubMixin,
117
  # optionally, you can add metadata which gets pushed to the model card
118
  ):
119
+ def __init__(self, num_classes=8):
120
  super().__init__()
121
  self.n_classes = num_classes
122
  self.gte = AutoModel.from_pretrained('Alibaba-NLP/gte-large-en-v1.5', trust_remote_code=True)
 
133
  PyTorchModelHubMixin,
134
  # optionally, you can add metadata which gets pushed to the model card
135
  ):
136
+ def __init__(self, num_classes=8):
137
  super().__init__()
138
  self.n_classes = num_classes
139
  self.gte = AutoModel.from_pretrained('Alibaba-NLP/gte-base-en-v1.5', trust_remote_code=True)