João Pedro commited on
Commit
23057e8
·
1 Parent(s): edcda91

dummy wandb to training code

Browse files
Files changed (1) hide show
  1. training.py +56 -39
training.py CHANGED
@@ -90,36 +90,28 @@ def prepare_dataset(
90
  .prefetch(tf.data.experimental.AUTOTUNE)
91
 
92
 
93
- metadata_df: DataFrame[PageMetadata] = pd.read_csv(METADATA_FILEPATH)
94
- metadata_df = metadata_df.sample(n=50, random_state=42)
95
-
96
- median_height = int(metadata_df['height'].median())
97
- median_width = int(metadata_df['width'].median())
98
- img_size: ImageSize = (median_height, median_width)
99
- img_input_shape: ImageInputShape = img_size + (3,)
100
-
101
- label_names: List[str] = sorted(
102
- [d.name for d in PROCESSED_DATA_DIR.iterdir() if d.is_dir()]
103
- )
104
- num_classes = len(label_names)
105
 
106
- print('Splitting the DataFrame into training, validation and test')
107
- train_df, val_df, test_df = stratified_split(
108
- metadata_df,
109
- train_frac=0.7,
110
- val_frac=0.15,
111
- test_frac=0.15,
112
- )
113
 
114
- print('Batching and shuffling the datasets')
115
- train_ds = dataset_from_dataframe(train_df)
116
- train_ds = prepare_dataset(train_ds, img_size, batch_size=BATCH_SIZE)
117
 
118
- val_ds = dataset_from_dataframe(val_df)
119
- val_ds = prepare_dataset(val_ds, img_size, batch_size=BATCH_SIZE)
120
 
121
- test_ds = dataset_from_dataframe(test_df)
122
- test_ds = prepare_dataset(test_ds, img_size, batch_size=BATCH_SIZE)
123
 
124
 
125
  def build_image_model(input_shape: ImageInputShape) -> keras.Model:
@@ -199,16 +191,41 @@ def build_multimodal_model(
199
  return multimodal_model
200
 
201
 
202
- multimodal_model = build_multimodal_model(num_classes, img_input_shape)
203
- multimodal_model.summary()
204
- multimodal_model.compile(
205
- optimizer='adam',
206
- loss='sparse_categorical_crossentropy',
207
- metrics=['accuracy']
208
- )
209
- multimodal_model.fit(
210
- train_ds,
211
- epochs=EPOCHS,
212
- batch_size=BATCH_SIZE,
213
- validation_data=val_ds,
214
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  .prefetch(tf.data.experimental.AUTOTUNE)
91
 
92
 
93
+ def prepare_data(
94
+ df: DataFrame[PageMetadata]
95
+ ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
96
+ print('Splitting the DataFrame into training, validation and test')
97
+ train_df, val_df, test_df = stratified_split(
98
+ df,
99
+ train_frac=0.7,
100
+ val_frac=0.15,
101
+ test_frac=0.15,
102
+ )
 
 
103
 
104
+ print('Batching and shuffling the datasets')
105
+ train_ds = dataset_from_dataframe(train_df)
106
+ train_ds = prepare_dataset(train_ds, img_size, batch_size=BATCH_SIZE)
 
 
 
 
107
 
108
+ val_ds = dataset_from_dataframe(val_df)
109
+ val_ds = prepare_dataset(val_ds, img_size, batch_size=BATCH_SIZE)
 
110
 
111
+ test_ds = dataset_from_dataframe(test_df)
112
+ test_ds = prepare_dataset(test_ds, img_size, batch_size=BATCH_SIZE)
113
 
114
+ return train_ds, val_ds, test_ds
 
115
 
116
 
117
  def build_image_model(input_shape: ImageInputShape) -> keras.Model:
 
191
  return multimodal_model
192
 
193
 
194
+ def train():
195
+ metadata_df: DataFrame[PageMetadata] = pd.read_csv(METADATA_FILEPATH)
196
+
197
+ median_height = int(metadata_df['height'].median())
198
+ median_width = int(metadata_df['width'].median())
199
+
200
+ img_size: ImageSize = (median_height, median_width)
201
+ img_input_shape: ImageInputShape = img_size + (3,)
202
+
203
+ label_names: List[str] = sorted(
204
+ [d.name for d in PROCESSED_DATA_DIR.iterdir() if d.is_dir()]
205
+ )
206
+ num_classes = len(label_names)
207
+
208
+ train_ds, val_ds, test_ds = prepare_data(metadata_df)
209
+
210
+ multimodal_model = build_multimodal_model(num_classes, img_input_shape)
211
+ multimodal_model.summary()
212
+ multimodal_model.compile(
213
+ optimizer='adam',
214
+ loss='sparse_categorical_crossentropy',
215
+ metrics=['accuracy']
216
+ )
217
+ multimodal_model.fit(
218
+ train_ds,
219
+ epochs=EPOCHS,
220
+ batch_size=BATCH_SIZE,
221
+ validation_data=val_ds,
222
+ )
223
+
224
+
225
+ def evaluate():
226
+ return
227
+
228
+
229
+ if __name__ = '__main__':
230
+ train()
231
+ evaluate()