Update modeling_custom.py
Browse files - modeling_custom.py (+2 -3)
modeling_custom.py
CHANGED
@@ -38,14 +38,13 @@ class GatingNetwork(nn.Module):
|
|
38 |
dropout_rate = dropout
|
39 |
for i in range(n_hidden):
|
40 |
layers.append(nn.Linear(in_features, hidden_dim, bias=False)) # for BN
|
41 |
-
nn.init.kaiming_normal_(layers[-1].weight, mode='fan_in', nonlinearity='relu')
|
42 |
layers.append(nn.ReLU())
|
43 |
layers.append(nn.BatchNorm1d(hidden_dim))
|
44 |
if dropout_rate > 0 and i < n_hidden - 1: # no dropout before last layer for more stability and precision
|
45 |
layers.append(nn.Dropout(dropout_rate))
|
46 |
|
47 |
in_features = hidden_dim
|
48 |
-
# self.dropout_list = [nn.Dropout(dropout_rate), nn.Dropout(dropout_rate)]
|
49 |
layers.append(nn.Linear(in_features, out_features, bias=bias))
|
50 |
self.layers = nn.ModuleList(layers)
|
51 |
# print("Gating network layers:", self.layers)
|
@@ -117,7 +116,7 @@ class Gemma2ForQuantileSequenceClassification(Gemma2PreTrainedModel):
|
|
117 |
|
118 |
# Initialize weights and apply final processing
|
119 |
self.gating = GatingNetwork(config.hidden_size, self.num_objectives,
|
120 |
-
temperature=config_dict.get("gating_temperature",
|
121 |
hidden_dim=config_dict.get("gating_hidden_dim", 1024),
|
122 |
n_hidden=config_dict.get("gating_n_hidden", 3))
|
123 |
|
|
|
38 |
dropout_rate = dropout
|
39 |
for i in range(n_hidden):
|
40 |
layers.append(nn.Linear(in_features, hidden_dim, bias=False)) # for BN
|
41 |
+
#nn.init.kaiming_normal_(layers[-1].weight, mode='fan_in', nonlinearity='relu')
|
42 |
layers.append(nn.ReLU())
|
43 |
layers.append(nn.BatchNorm1d(hidden_dim))
|
44 |
if dropout_rate > 0 and i < n_hidden - 1: # no dropout before last layer for more stability and precision
|
45 |
layers.append(nn.Dropout(dropout_rate))
|
46 |
|
47 |
in_features = hidden_dim
|
|
|
48 |
layers.append(nn.Linear(in_features, out_features, bias=bias))
|
49 |
self.layers = nn.ModuleList(layers)
|
50 |
# print("Gating network layers:", self.layers)
|
|
|
116 |
|
117 |
# Initialize weights and apply final processing
|
118 |
self.gating = GatingNetwork(config.hidden_size, self.num_objectives,
|
119 |
+
temperature=config_dict.get("gating_temperature", 2),
|
120 |
hidden_dim=config_dict.get("gating_hidden_dim", 1024),
|
121 |
n_hidden=config_dict.get("gating_n_hidden", 3))
|
122 |
|