# hydra-classifier / training.py
import pandas as pd
import tensorflow as tf
import tf_keras as keras
import wandb
from constants import (PROCESSED_DATA_DIR,
METADATA_FILEPATH,
BATCH_SIZE,
EPOCHS,
BERT_BASE,
MAX_SEQUENCE_LENGHT,
PROJECT_NAME,
FilePath,
PageMetadata,
ImageSize,
ImageInputShape)
from pandera.typing import DataFrame
from typing import Tuple, List
from transformers import TFBertModel
from tf_keras import layers, models
from PIL import Image
# Allow for unlimited image size, some documents are pretty big...
Image.MAX_IMAGE_PIXELS = None
def stratified_split(
df: pd.DataFrame,
train_frac: float,
val_frac: float,
test_frac: float,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
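    """Split a DataFrame into train/val/test while preserving the per-label
    class balance. Rows are taken in order within each label group, so
    shuffle `df` beforehand if a random split is required."""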
train_dfs, val_dfs, test_dfs = [], [], []
for label, group in df.groupby('label'):
n = len(group)
train_end = int(n * train_frac)
val_end = train_end + int(n * val_frac)
train_dfs.append(group.iloc[:train_end])
val_dfs.append(group.iloc[train_end:val_end])
test_dfs.append(group.iloc[val_end:])
train_df = pd.concat(train_dfs).reset_index(drop=True)
val_df = pd.concat(val_dfs).reset_index(drop=True)
test_df = pd.concat(test_dfs).reset_index(drop=True)
return train_df, val_df, test_df
def dataset_from_dataframe(df: pd.DataFrame) -> tf.data.Dataset:
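    """Build a tf.data.Dataset of (image path, input_ids, attention_mask,
    label) tuples. Assumes `input_ids` and `attention_mask` already hold
    fixed-length integer sequences rather than their string representations."""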
return tf.data.Dataset.from_tensor_slices((
df['img_filepath'].values,
df['input_ids'].values,
df['attention_mask'].values,
df['label'].values,
))
def load_image(image_path: FilePath, image_size: ImageSize) -> tf.Tensor:
    img_height, img_width = image_size
    # Load, decode and resize the image, then scale pixels to [0, 1]
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [img_height, img_width])
    image /= 255.0
    return image
def prepare_dataset(
ds: tf.data.Dataset,
image_size: ImageSize,
batch_size=32,
buffer_size=1000
) -> tf.data.Dataset:
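    """Map image loading over the dataset, then shuffle, batch and prefetch
    it for training."""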
def load_image_and_format_tensor_shape(
img_path: FilePath,
input_ids: List[int],
attention_mask: List[int],
label: str
):
image = load_image(img_path, image_size)
return ((image, input_ids, attention_mask), label)
    return (
        ds.map(
            load_image_and_format_tensor_shape,
            num_parallel_calls=tf.data.AUTOTUNE,
        )
        .shuffle(buffer_size=buffer_size)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )
def prepare_data(
    df: DataFrame[PageMetadata],
    img_size: ImageSize,
) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
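    """Split the metadata, log the split tables to Weights & Biases as a
    dataset artifact, and return the batched train/val/test datasets."""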
    print('Splitting the DataFrame into training, validation and test splits')
train_df, val_df, test_df = stratified_split(
df,
train_frac=0.7,
val_frac=0.15,
test_frac=0.15,
)
    run = wandb.init(project=PROJECT_NAME, name='split-dataset')
split_dataset_artifact = wandb.Artifact('split-dataset-metadata', type='dataset')
train_table = wandb.Table(dataframe=train_df)
val_table = wandb.Table(dataframe=val_df)
test_table = wandb.Table(dataframe=test_df)
split_dataset_artifact.add(train_table, name='train_metadata')
split_dataset_artifact.add(val_table, name='val_metadata')
split_dataset_artifact.add(test_table, name='test_metadata')
run.log_artifact(split_dataset_artifact)
run.finish()
print('Batching and shuffling the datasets')
train_ds = dataset_from_dataframe(train_df)
train_ds = prepare_dataset(train_ds, img_size, batch_size=BATCH_SIZE)
val_ds = dataset_from_dataframe(val_df)
val_ds = prepare_dataset(val_ds, img_size, batch_size=BATCH_SIZE)
test_ds = dataset_from_dataframe(test_df)
test_ds = prepare_dataset(test_ds, img_size, batch_size=BATCH_SIZE)
return train_ds, val_ds, test_ds
def build_image_model(input_shape: ImageInputShape) -> keras.Model:
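    """Small CNN that encodes a page image into a 512-dimensional feature
    vector."""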
img_model = models.Sequential([
layers.Input(shape=input_shape),
layers.Conv2D(32, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(128, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(128, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Flatten(),
layers.Dense(512, activation='relu'),
], name='image_classification')
img_model.summary()
return img_model
def build_text_model() -> keras.Model:
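    """Wrap a pretrained BERT encoder that maps (input_ids, attention_mask)
    to the pooled [CLS] representation."""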
bert_model = TFBertModel.from_pretrained(BERT_BASE)
input_ids = layers.Input(
shape=(MAX_SEQUENCE_LENGHT,), dtype=tf.int32, name='input_ids'
)
attention_mask = layers.Input(
shape=(MAX_SEQUENCE_LENGHT,), dtype=tf.int32, name='attention_mask'
)
    # Use the pooled output, i.e. the representation of the [CLS] token,
    # as the text feature vector
    outputs = bert_model(
        input_ids=input_ids, attention_mask=attention_mask
    ).pooler_output
text_model = models.Model(
inputs=[input_ids, attention_mask],
outputs=outputs,
name='bert'
)
text_model.summary()
return text_model
def build_multimodal_model(
num_classes: int,
img_input_shape: ImageInputShape
) -> keras.Model:
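    """Fuse the CNN image features with the BERT text features by
    concatenation, then classify the page with a dense head."""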
img_model = build_image_model(img_input_shape)
text_model = build_text_model()
img_input = layers.Input(shape=img_input_shape, name='img_input')
text_input_ids = layers.Input(
shape=(MAX_SEQUENCE_LENGHT,), dtype=tf.int32, name='text_input_ids'
)
text_input_mask = layers.Input(
shape=(MAX_SEQUENCE_LENGHT,), dtype=tf.int32, name='text_input_mask'
)
img_features = img_model(img_input)
text_features = text_model([text_input_ids, text_input_mask])
    classification_layers = keras.Sequential([
        layers.Dense(512, activation='relu'),
        layers.Dense(num_classes, activation='softmax'),
    ], name='classification_layers')
concat_features = layers.concatenate([img_features, text_features],
name='concatenate_features')
outputs = classification_layers(concat_features)
multimodal_model = models.Model(
inputs=[img_input, text_input_ids, text_input_mask],
outputs=outputs,
name='multimodal_document_page_classifier'
)
return multimodal_model
def train():
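    """End-to-end training entry point: derive the image size from the
    metadata, build the datasets and the multimodal model, then fit it."""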
metadata_df: DataFrame[PageMetadata] = pd.read_csv(METADATA_FILEPATH)
median_height = int(metadata_df['height'].median())
median_width = int(metadata_df['width'].median())
img_size: ImageSize = (median_height, median_width)
img_input_shape: ImageInputShape = img_size + (3,)
label_names: List[str] = sorted(
[d.name for d in PROCESSED_DATA_DIR.iterdir() if d.is_dir()]
)
num_classes = len(label_names)
    train_ds, val_ds, test_ds = prepare_data(metadata_df, img_size)
multimodal_model = build_multimodal_model(num_classes, img_input_shape)
multimodal_model.summary()
multimodal_model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
    # The datasets are already batched, so `batch_size` must not be
    # passed to `fit`
    multimodal_model.fit(
        train_ds,
        epochs=EPOCHS,
        validation_data=val_ds,
    )
    # Report final performance on the held-out test split
    multimodal_model.evaluate(test_ds)


if __name__ == '__main__':
    train()