change cuz model parallel was not working..
Browse files- modular_Sewy2.py +2 -0
modular_Sewy2.py
CHANGED
@@ -2042,6 +2042,8 @@ class SewyV2ForCausalLM(SewyV2PreTrainedModel):
|
|
2042 |
|
2043 |
s_z = self.s_z * (self.s_z_init/self.s_z_scale)
|
2044 |
|
|
|
|
|
2045 |
logits = logits * s_z.view(1, 1, -1)
|
2046 |
|
2047 |
loss = None
|
|
|
2042 |
|
2043 |
s_z = self.s_z * (self.s_z_init/self.s_z_scale)
|
2044 |
|
2045 |
+
s_z = s_z.to(logits.device)
|
2046 |
+
|
2047 |
logits = logits * s_z.view(1, 1, -1)
|
2048 |
|
2049 |
loss = None
|