# import the necessary packages
from tensorflow import keras
import tensorflow as tf


# Patch conv
class PatchConvNet(keras.Model):
    """Keras model that chains a stem, a trunk and an attention-pooling block.

    The three stages are supplied by the caller and stored unchanged; this
    class only wires them together in ``call``.

    Args:
        stem: Callable applied first, directly to the input images.
        trunk: Callable applied to the output of ``stem``.
        attention_pooling: Callable applied to the output of ``trunk``. Its
            result is unpacked into exactly two values.
        **kwargs: Forwarded untouched to ``keras.Model.__init__``.
    """

    def __init__(
        self,
        stem,
        trunk,
        attention_pooling,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Keep references to the three stages in pipeline order.
        self.stem = stem
        self.trunk = trunk
        self.attention_pooling = attention_pooling

    # The signature pins the traced input to a rank-4 uint8 tensor whose last
    # dimension is 3; the other three dimensions are left unconstrained.
    # NOTE(review): presumably (batch, height, width, channels) -- confirm.
    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.uint8)
        ]
    )
    def call(self, images):
        """Run ``images`` through stem, trunk and attention pooling in turn.

        Args:
            images: Input tensor matching the ``input_signature`` declared on
                this method.

        Returns:
            A ``(predictions, viz_weights)`` tuple holding the two values
            produced by ``attention_pooling``.
        """
        # stage 1: stem
        stem_out = self.stem(images)

        # stage 2: trunk
        trunk_out = self.trunk(stem_out)

        # stage 3: attention pooling yields a pair of outputs
        predictions, viz_weights = self.attention_pooling(trunk_out)
        return predictions, viz_weights