Update models/unet_gray2color.py
Browse files
models/unet_gray2color.py
CHANGED
@@ -26,9 +26,14 @@ class SelfAttentionLayer(Layer):
|
|
26 |
self.ln = LayerNormalization(epsilon=1e-6)
|
27 |
|
28 |
def build(self, input_shape):
|
29 |
-
#
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
32 |
super(SelfAttentionLayer, self).build(input_shape)
|
33 |
|
34 |
def call(self, x):
|
|
|
26 |
self.ln = LayerNormalization(epsilon=1e-6)
|
27 |
|
28 |
def build(self, input_shape):
|
29 |
+
# input_shape: (batch_size, height, width, channels)
|
30 |
+
# For self-attention, query, key, and value have the same shape
|
31 |
+
batch_size, height, width, channels = input_shape
|
32 |
+
attention_shape = (batch_size, height * width, channels) # Shape after reshape
|
33 |
+
# Build MultiHeadAttention with query, key, and value shapes
|
34 |
+
self.mha.build(query_shape=attention_shape, value_shape=attention_shape)
|
35 |
+
# Build LayerNormalization with the original input shape
|
36 |
+
self.ln.build(input_shape)
|
37 |
super(SelfAttentionLayer, self).build(input_shape)
|
38 |
|
39 |
def call(self, x):
|