nicolinho commited on
Commit
f1a53bb
·
verified ·
1 Parent(s): 47bb4a8

Update modeling_custom.py

Browse files
Files changed (1) hide show
  1. modeling_custom.py +2 -3
modeling_custom.py CHANGED
@@ -38,14 +38,13 @@ class GatingNetwork(nn.Module):
38
  dropout_rate = dropout
39
  for i in range(n_hidden):
40
  layers.append(nn.Linear(in_features, hidden_dim, bias=False)) # for BN
41
- nn.init.kaiming_normal_(layers[-1].weight, mode='fan_in', nonlinearity='relu')
42
  layers.append(nn.ReLU())
43
  layers.append(nn.BatchNorm1d(hidden_dim))
44
  if dropout_rate > 0 and i < n_hidden - 1: # no dropout before last layer for more stability and precision
45
  layers.append(nn.Dropout(dropout_rate))
46
 
47
  in_features = hidden_dim
48
- # self.dropout_list = [nn.Dropout(dropout_rate), nn.Dropout(dropout_rate)]
49
  layers.append(nn.Linear(in_features, out_features, bias=bias))
50
  self.layers = nn.ModuleList(layers)
51
  # print("Gating network layers:", self.layers)
@@ -117,7 +116,7 @@ class Gemma2ForQuantileSequenceClassification(Gemma2PreTrainedModel):
117
 
118
  # Initialize weights and apply final processing
119
  self.gating = GatingNetwork(config.hidden_size, self.num_objectives,
120
- temperature=config_dict.get("gating_temperature", 1),
121
  hidden_dim=config_dict.get("gating_hidden_dim", 1024),
122
  n_hidden=config_dict.get("gating_n_hidden", 3))
123
 
 
38
  dropout_rate = dropout
39
  for i in range(n_hidden):
40
  layers.append(nn.Linear(in_features, hidden_dim, bias=False)) # for BN
41
+ #nn.init.kaiming_normal_(layers[-1].weight, mode='fan_in', nonlinearity='relu')
42
  layers.append(nn.ReLU())
43
  layers.append(nn.BatchNorm1d(hidden_dim))
44
  if dropout_rate > 0 and i < n_hidden - 1: # no dropout before last layer for more stability and precision
45
  layers.append(nn.Dropout(dropout_rate))
46
 
47
  in_features = hidden_dim
 
48
  layers.append(nn.Linear(in_features, out_features, bias=bias))
49
  self.layers = nn.ModuleList(layers)
50
  # print("Gating network layers:", self.layers)
 
116
 
117
  # Initialize weights and apply final processing
118
  self.gating = GatingNetwork(config.hidden_size, self.num_objectives,
119
+ temperature=config_dict.get("gating_temperature", 2),
120
  hidden_dim=config_dict.get("gating_hidden_dim", 1024),
121
  n_hidden=config_dict.get("gating_n_hidden", 3))
122