import pandas as pd
import tensorflow as tf
import tf_keras as keras
import wandb
from pandera.typing import DataFrame
from PIL import Image
from tf_keras import layers, models
from transformers import TFBertModel
from typing import List, Tuple

from constants import (
    PROCESSED_DATA_DIR,
    METADATA_FILEPATH,
    BATCH_SIZE,
    EPOCHS,
    BERT_BASE,
    MAX_SEQUENCE_LENGHT,
    PROJECT_NAME,
    FilePath,
    PageMetadata,
    ImageSize,
    ImageInputShape,
)

# Allow for unlimited image size, some documents are pretty big...
Image.MAX_IMAGE_PIXELS = None


def stratified_split(
    df: pd.DataFrame,
    train_frac: float,
    val_frac: float,
    test_frac: float,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Split ``df`` into train/val/test frames, stratified by ``label``.

    Rows are sliced within each ``label`` group so every class appears in
    every split with approximately the requested fractions.

    NOTE: ``test_frac`` is not used directly — the test split receives
    whatever remains in each group after the train and validation slices.
    It is kept in the signature so callers document their intent.
    """
    train_dfs, val_dfs, test_dfs = [], [], []
    for _, group in df.groupby('label'):
        n = len(group)
        train_end = int(n * train_frac)
        val_end = train_end + int(n * val_frac)
        train_dfs.append(group.iloc[:train_end])
        val_dfs.append(group.iloc[train_end:val_end])
        test_dfs.append(group.iloc[val_end:])
    train_df = pd.concat(train_dfs).reset_index(drop=True)
    val_df = pd.concat(val_dfs).reset_index(drop=True)
    test_df = pd.concat(test_dfs).reset_index(drop=True)
    return train_df, val_df, test_df


def dataset_from_dataframe(df: pd.DataFrame) -> tf.data.Dataset:
    """Build a dataset of (img_filepath, input_ids, attention_mask, label)."""
    return tf.data.Dataset.from_tensor_slices((
        df['img_filepath'].values,
        df['input_ids'].values,
        df['attention_mask'].values,
        df['label'].values,
    ))


def load_image(image_path: FilePath, image_size: ImageSize) -> tf.Tensor:
    """Read a JPEG from ``image_path``, resize it and scale it into [0, 1].

    ``image_size`` is ``(height, width)`` — the order produced by
    :func:`train` from the metadata medians.  (The original code unpacked
    these as ``img_width, img_height``; the values were still passed to
    ``resize`` in the right order, only the names were swapped.)
    """
    img_height, img_width = image_size
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)
    # tf.image.resize expects [height, width].
    image = tf.image.resize(image, [img_height, img_width])
    image /= 255.0
    return image


def prepare_dataset(
    ds: tf.data.Dataset,
    image_size: ImageSize,
    batch_size: int = 32,
    buffer_size: int = 1000,
) -> tf.data.Dataset:
    """Map image loading over ``ds``, then shuffle, batch and prefetch."""

    def load_image_and_format_tensor_shape(
        img_path: FilePath,
        input_ids: List[int],
        attention_mask: List[int],
        label: str,
    ):
        # The model consumes a single (image, input_ids, attention_mask)
        # input tuple; the label is the training target.
        image = load_image(img_path, image_size)
        return ((image, input_ids, attention_mask), label)

    return ds.map(
        load_image_and_format_tensor_shape,
        num_parallel_calls=tf.data.AUTOTUNE,
    ) \
        .shuffle(buffer_size=buffer_size) \
        .batch(batch_size) \
        .prefetch(tf.data.AUTOTUNE)


def prepare_data(
    df: DataFrame[PageMetadata],
    image_size: ImageSize,
) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
    """Split ``df`` 70/15/15, log the splits to W&B, and build the datasets.

    Fixes vs. original: this function referenced an undefined ``img_size``
    (now the explicit ``image_size`` parameter), ``wandb`` was never
    imported, and ``wandb.init`` was called with ``project_name`` instead
    of the correct ``project`` keyword.
    """
    print('Splitting the DataFrame into training, validation and test')
    train_df, val_df, test_df = stratified_split(
        df, train_frac=0.7, val_frac=0.15, test_frac=0.15,
    )

    # Log the split metadata as a W&B artifact for reproducibility.
    run = wandb.init(project=PROJECT_NAME, name='split-dataset')
    split_dataset_artifact = wandb.Artifact('split-dataset-metadata', type='dataset')
    split_dataset_artifact.add(wandb.Table(dataframe=train_df), name='train_metadata')
    split_dataset_artifact.add(wandb.Table(dataframe=val_df), name='val_metadata')
    split_dataset_artifact.add(wandb.Table(dataframe=test_df), name='test_metadata')
    run.log_artifact(split_dataset_artifact)
    run.finish()

    print('Batching and shuffling the datasets')
    train_ds = prepare_dataset(
        dataset_from_dataframe(train_df), image_size, batch_size=BATCH_SIZE
    )
    val_ds = prepare_dataset(
        dataset_from_dataframe(val_df), image_size, batch_size=BATCH_SIZE
    )
    test_ds = prepare_dataset(
        dataset_from_dataframe(test_df), image_size, batch_size=BATCH_SIZE
    )
    return train_ds, val_ds, test_ds


def build_image_model(input_shape: ImageInputShape) -> keras.Model:
    """Build the CNN branch: 4 conv/pool stages into a 512-d dense feature."""
    img_model = models.Sequential([
        layers.Input(shape=input_shape),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(512, activation='relu'),
    ], name='image_classification')
    img_model.summary()
    return img_model


def build_text_model() -> keras.Model:
    """Build the BERT branch: token ids + mask in, pooled [CLS] vector out."""
    bert_model = TFBertModel.from_pretrained(BERT_BASE)
    input_ids = layers.Input(
        shape=(MAX_SEQUENCE_LENGHT,), dtype=tf.int32, name='input_ids'
    )
    attention_mask = layers.Input(
        shape=(MAX_SEQUENCE_LENGHT,), dtype=tf.int32, name='attention_mask'
    )
    # The second element of the BERT output is the pooled output i.e. the
    # representation of the [CLS] token.
    outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)[1]
    text_model = models.Model(
        inputs=[input_ids, attention_mask], outputs=outputs, name='bert'
    )
    text_model.summary()
    return text_model


def build_multimodal_model(
    num_classes: int, img_input_shape: ImageInputShape
) -> keras.Model:
    """Fuse the CNN and BERT branches and add a softmax classification head.

    Fix vs. original: the classification head used ``tf.keras.layers``
    while the rest of the file uses ``tf_keras`` — mixing the two Keras
    packages in one model is unsafe, so the head now uses ``layers`` too.
    """
    img_model = build_image_model(img_input_shape)
    text_model = build_text_model()

    img_input = layers.Input(shape=img_input_shape, name='img_input')
    text_input_ids = layers.Input(
        shape=(MAX_SEQUENCE_LENGHT,), dtype=tf.int32, name='text_input_ids'
    )
    text_input_mask = layers.Input(
        shape=(MAX_SEQUENCE_LENGHT,), dtype=tf.int32, name='text_input_mask'
    )

    img_features = img_model(img_input)
    text_features = text_model([text_input_ids, text_input_mask])

    classification_layers = keras.Sequential([
        layers.Dense(512, activation='relu'),
        layers.Dense(num_classes, activation='softmax'),
    ], name='classification_layers')

    concat_features = layers.concatenate(
        [img_features, text_features], name='concatenate_features'
    )
    outputs = classification_layers(concat_features)

    multimodal_model = models.Model(
        inputs=[img_input, text_input_ids, text_input_mask],
        outputs=outputs,
        name='multimodal_document_page_classifier',
    )
    return multimodal_model


def train() -> None:
    """End-to-end entry point: load metadata, build datasets, train model."""
    metadata_df: DataFrame[PageMetadata] = pd.read_csv(METADATA_FILEPATH)

    # Use the median page dimensions so resizing distorts typical pages least.
    median_height = int(metadata_df['height'].median())
    median_width = int(metadata_df['width'].median())
    img_size: ImageSize = (median_height, median_width)
    img_input_shape: ImageInputShape = img_size + (3,)

    # One class per sub-directory of the processed data directory.
    label_names: List[str] = sorted(
        d.name for d in PROCESSED_DATA_DIR.iterdir() if d.is_dir()
    )
    num_classes = len(label_names)

    # ``img_size`` is threaded through explicitly (the original
    # prepare_data read an undefined global).
    train_ds, val_ds, test_ds = prepare_data(metadata_df, img_size)

    multimodal_model = build_multimodal_model(num_classes, img_input_shape)
    multimodal_model.summary()
    multimodal_model.compile(
        optimizer='adam',
        # NOTE(review): sparse_categorical_crossentropy expects integer
        # labels — confirm 'label' is label-encoded upstream, not raw strings.
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    # ``batch_size`` must not be passed alongside an already-batched
    # tf.data.Dataset — Keras rejects that combination, and the datasets
    # are batched in prepare_dataset.
    multimodal_model.fit(
        train_ds,
        epochs=EPOCHS,
        validation_data=val_ds,
    )
    # TODO(review): test_ds is built but never evaluated — the original
    # called an undefined ``evaluate()``; add a real evaluation step here.


if __name__ == '__main__':
    train()