File size: 678 Bytes
310a06c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# import the necessary packages
from tensorflow import keras
import tensorflow as tf

# PatchConvNet: a Keras model that chains a stem, a trunk and an attention-pooling head
class PatchConvNet(keras.Model):
    """Keras model composed of three caller-supplied stages.

    The forward pass feeds the input images through ``stem``, then
    ``trunk``, and finally ``attention_pooling``. The last stage must return
    exactly two values, which are passed back to the caller as
    ``(predictions, viz_weights)``.
    """

    def __init__(
        self,
        stem,
        trunk,
        attention_pooling,
        **kwargs,
    ):
        """Store the three stages on the model.

        Args:
            stem: Callable applied first, to the raw input images.
            trunk: Callable applied to the output of ``stem``.
            attention_pooling: Callable applied to the output of ``trunk``;
                it must return a pair ``(predictions, viz_weights)``.
            **kwargs: Forwarded unchanged to ``keras.Model.__init__``.
        """
        super().__init__(**kwargs)
        # Assignment order is kept stable on purpose: Keras tracks sub-layers
        # in the order they are attached, which fixes the ordering of
        # ``model.layers`` and of the saved weights.
        self.stem = stem
        self.trunk = trunk
        self.attention_pooling = attention_pooling

    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.uint8),
        ]
    )
    def call(self, images):
        """Run the stem -> trunk -> attention-pooling pipeline.

        The ``input_signature`` restricts ``images`` to a rank-4 ``tf.uint8``
        tensor whose last dimension is 3 (presumably batched NHWC RGB images;
        confirm against the stem). Any dimension other than the last may vary.

        Returns:
            A 2-tuple ``(predictions, viz_weights)`` unpacked from the output
            of ``attention_pooling``.
        """
        stem_features = self.stem(images)
        trunk_features = self.trunk(stem_features)
        # Unpack explicitly: the pooling head is expected to return exactly
        # two values, and we always hand back a plain tuple.
        predictions, viz_weights = self.attention_pooling(trunk_features)
        return predictions, viz_weights