danhtran2mind commited on
Commit
772dbaa
·
verified ·
1 Parent(s): 61e51d0

Update models/unet_gray2color.py

Browse files
Files changed (1) hide show
  1. models/unet_gray2color.py +8 -3
models/unet_gray2color.py CHANGED
@@ -26,9 +26,14 @@ class SelfAttentionLayer(Layer):
26
  self.ln = LayerNormalization(epsilon=1e-6)
27
 
28
  def build(self, input_shape):
29
- # Initialize the MultiHeadAttention and LayerNormalization layers
30
- self.mha.build(input_shape) # Build MultiHeadAttention
31
- self.ln.build(input_shape) # Build LayerNormalization
 
 
 
 
 
32
  super(SelfAttentionLayer, self).build(input_shape)
33
 
34
  def call(self, x):
 
26
  self.ln = LayerNormalization(epsilon=1e-6)
27
 
28
  def build(self, input_shape):
29
+ # input_shape: (batch_size, height, width, channels)
30
+ # For self-attention, query, key, and value have the same shape
31
+ batch_size, height, width, channels = input_shape
32
+ attention_shape = (batch_size, height * width, channels) # Shape after reshape
33
+ # Build MultiHeadAttention with query, key, and value shapes
34
+ self.mha.build(query_shape=attention_shape, value_shape=attention_shape)
35
+ # Build LayerNormalization with the original input shape
36
+ self.ln.build(input_shape)
37
  super(SelfAttentionLayer, self).build(input_shape)
38
 
39
  def call(self, x):