Spaces:
Sleeping
Sleeping
João Pedro
commited on
Commit
·
23057e8
1
Parent(s):
edcda91
dummy wandb to training code
Browse files- training.py +56 -39
training.py
CHANGED
@@ -90,36 +90,28 @@ def prepare_dataset(
|
|
90 |
.prefetch(tf.data.experimental.AUTOTUNE)
|
91 |
|
92 |
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
)
|
104 |
-
num_classes = len(label_names)
|
105 |
|
106 |
-
print('
|
107 |
-
|
108 |
-
|
109 |
-
train_frac=0.7,
|
110 |
-
val_frac=0.15,
|
111 |
-
test_frac=0.15,
|
112 |
-
)
|
113 |
|
114 |
-
|
115 |
-
|
116 |
-
train_ds = prepare_dataset(train_ds, img_size, batch_size=BATCH_SIZE)
|
117 |
|
118 |
-
|
119 |
-
|
120 |
|
121 |
-
|
122 |
-
test_ds = prepare_dataset(test_ds, img_size, batch_size=BATCH_SIZE)
|
123 |
|
124 |
|
125 |
def build_image_model(input_shape: ImageInputShape) -> keras.Model:
|
@@ -199,16 +191,41 @@ def build_multimodal_model(
|
|
199 |
return multimodal_model
|
200 |
|
201 |
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
)
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
.prefetch(tf.data.experimental.AUTOTUNE)
|
91 |
|
92 |
|
93 |
+
def prepare_data(
|
94 |
+
df: DataFrame[PageMetadata]
|
95 |
+
) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
|
96 |
+
print('Splitting the DataFrame into training, validation and test')
|
97 |
+
train_df, val_df, test_df = stratified_split(
|
98 |
+
df,
|
99 |
+
train_frac=0.7,
|
100 |
+
val_frac=0.15,
|
101 |
+
test_frac=0.15,
|
102 |
+
)
|
|
|
|
|
103 |
|
104 |
+
print('Batching and shuffling the datasets')
|
105 |
+
train_ds = dataset_from_dataframe(train_df)
|
106 |
+
train_ds = prepare_dataset(train_ds, img_size, batch_size=BATCH_SIZE)
|
|
|
|
|
|
|
|
|
107 |
|
108 |
+
val_ds = dataset_from_dataframe(val_df)
|
109 |
+
val_ds = prepare_dataset(val_ds, img_size, batch_size=BATCH_SIZE)
|
|
|
110 |
|
111 |
+
test_ds = dataset_from_dataframe(test_df)
|
112 |
+
test_ds = prepare_dataset(test_ds, img_size, batch_size=BATCH_SIZE)
|
113 |
|
114 |
+
return train_ds, val_ds, test_ds
|
|
|
115 |
|
116 |
|
117 |
def build_image_model(input_shape: ImageInputShape) -> keras.Model:
|
|
|
191 |
return multimodal_model
|
192 |
|
193 |
|
194 |
+
def train():
|
195 |
+
metadata_df: DataFrame[PageMetadata] = pd.read_csv(METADATA_FILEPATH)
|
196 |
+
|
197 |
+
median_height = int(metadata_df['height'].median())
|
198 |
+
median_width = int(metadata_df['width'].median())
|
199 |
+
|
200 |
+
img_size: ImageSize = (median_height, median_width)
|
201 |
+
img_input_shape: ImageInputShape = img_size + (3,)
|
202 |
+
|
203 |
+
label_names: List[str] = sorted(
|
204 |
+
[d.name for d in PROCESSED_DATA_DIR.iterdir() if d.is_dir()]
|
205 |
+
)
|
206 |
+
num_classes = len(label_names)
|
207 |
+
|
208 |
+
train_ds, val_ds, test_ds = prepare_data(metadata_df)
|
209 |
+
|
210 |
+
multimodal_model = build_multimodal_model(num_classes, img_input_shape)
|
211 |
+
multimodal_model.summary()
|
212 |
+
multimodal_model.compile(
|
213 |
+
optimizer='adam',
|
214 |
+
loss='sparse_categorical_crossentropy',
|
215 |
+
metrics=['accuracy']
|
216 |
+
)
|
217 |
+
multimodal_model.fit(
|
218 |
+
train_ds,
|
219 |
+
epochs=EPOCHS,
|
220 |
+
batch_size=BATCH_SIZE,
|
221 |
+
validation_data=val_ds,
|
222 |
+
)
|
223 |
+
|
224 |
+
|
225 |
+
def evaluate():
|
226 |
+
return
|
227 |
+
|
228 |
+
|
229 |
+
if __name__ = '__main__':
|
230 |
+
train()
|
231 |
+
evaluate()
|