# NOTE(review): removed a "Spaces: Sleeping" status banner that had been
# accidentally pasted at the top of the file — it was not valid Python.
from typing import List, Tuple

import pandas as pd
import tensorflow as tf
import tf_keras as keras
import wandb  # was missing: used by prepare_data but never imported
from pandera.typing import DataFrame
from PIL import Image
from tf_keras import layers, models
from transformers import TFBertModel

from constants import (BATCH_SIZE,
                       BERT_BASE,
                       EPOCHS,
                       FilePath,
                       ImageInputShape,
                       ImageSize,
                       MAX_SEQUENCE_LENGHT,
                       METADATA_FILEPATH,
                       PROCESSED_DATA_DIR,
                       PROJECT_NAME,
                       PageMetadata)
# Allow for unlimited image size, some documents are pretty big...
# (Disables PIL's decompression-bomb guard; only safe because the images
# come from our own processed dataset, not untrusted uploads.)
Image.MAX_IMAGE_PIXELS = None
def stratified_split(
    df: pd.DataFrame,
    train_frac: float,
    val_frac: float,
    test_frac: float,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Split *df* into train/val/test frames, stratified on the `label` column.

    Each label group is sliced proportionally, so every split keeps roughly
    the same class distribution as the input. Rows are taken in their
    existing order within each group (no shuffling here — the tf.data
    pipeline shuffles later).

    Args:
        df: frame with a `label` column to stratify on.
        train_frac: fraction of each label group for training.
        val_frac: fraction of each label group for validation.
        test_frac: fraction for testing; the test split receives whatever
            remains after the train/val slices, so this is used only for
            validation of the inputs.

    Returns:
        (train_df, val_df, test_df), each with a fresh RangeIndex.

    Raises:
        ValueError: if the three fractions do not sum to 1.
    """
    # Previously `test_frac` was accepted but silently ignored; validate it
    # so inconsistent fraction sets fail loudly instead of skewing splits.
    if abs(train_frac + val_frac + test_frac - 1.0) > 1e-9:
        raise ValueError(
            'train_frac, val_frac and test_frac must sum to 1, got '
            f'{train_frac + val_frac + test_frac}'
        )
    train_dfs, val_dfs, test_dfs = [], [], []
    for label, group in df.groupby('label'):
        n = len(group)
        train_end = int(n * train_frac)
        val_end = train_end + int(n * val_frac)
        train_dfs.append(group.iloc[:train_end])
        val_dfs.append(group.iloc[train_end:val_end])
        # Remainder goes to test, so no row is ever dropped by rounding.
        test_dfs.append(group.iloc[val_end:])
    train_df = pd.concat(train_dfs).reset_index(drop=True)
    val_df = pd.concat(val_dfs).reset_index(drop=True)
    test_df = pd.concat(test_dfs).reset_index(drop=True)
    return train_df, val_df, test_df
def dataset_from_dataframe(df: pd.DataFrame) -> tf.data.Dataset:
    """Turn metadata rows into a dataset of (path, ids, mask, label) tuples."""
    column_order = ('img_filepath', 'input_ids', 'attention_mask', 'label')
    tensors = tuple(df[column].values for column in column_order)
    return tf.data.Dataset.from_tensor_slices(tensors)
def load_image(image_path: FilePath, image_size: ImageSize) -> tf.Tensor:
    """Read a JPEG from disk, resize it and scale pixel values to [0, 1].

    Args:
        image_path: path to a JPEG file.
        image_size: (height, width) target size — this matches the
            (median_height, median_width) tuple built in `train`.

    Returns:
        A float tensor of shape (height, width, 3) with values in [0, 1].
    """
    # Fix: the original unpacked this as (img_width, img_height) even though
    # callers build it as (height, width); the values ended up in the right
    # resize slots only because of the double swap. The return annotation
    # also claimed PIL `Image` while a tf.Tensor is returned.
    img_height, img_width = image_size
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)
    # tf.image.resize expects [new_height, new_width].
    image = tf.image.resize(image, [img_height, img_width])
    # Normalize to [0, 1] for the CNN.
    image /= 255.0
    return image
def prepare_dataset(
    ds: tf.data.Dataset,
    image_size: ImageSize,
    batch_size=32,
    buffer_size=1000
) -> tf.data.Dataset:
    """Decode images, then shuffle, batch and prefetch the dataset.

    Args:
        ds: dataset of (img_filepath, input_ids, attention_mask, label)
            tuples, as produced by `dataset_from_dataframe`.
        image_size: (height, width) passed through to `load_image`.
        batch_size: batch size applied after shuffling.
        buffer_size: shuffle buffer size.

    Returns:
        A dataset of ((image, input_ids, attention_mask), label) batches —
        the input structure the multimodal Keras model expects.
    """
    def load_image_and_format_tensor_shape(
        img_path: FilePath,
        input_ids: List[int],
        attention_mask: List[int],
        label: str
    ):
        # Replace the file path with the decoded tensor and regroup the
        # model inputs under a single tuple, separate from the label.
        image = load_image(img_path, image_size)
        return ((image, input_ids, attention_mask), label)

    # tf.data.AUTOTUNE replaces the deprecated tf.data.experimental.AUTOTUNE.
    return (
        ds.map(
            load_image_and_format_tensor_shape,
            num_parallel_calls=tf.data.AUTOTUNE,
        )
        .shuffle(buffer_size=buffer_size)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )
def prepare_data(
    df: DataFrame[PageMetadata]
) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
    """Stratified-split the page metadata and build batched tf.data pipelines.

    The three split tables are logged to Weights & Biases as a dataset
    artifact before the pipelines are built.

    Args:
        df: page metadata with at least the columns `label`, `img_filepath`,
            `input_ids`, `attention_mask`, `height` and `width`.

    Returns:
        (train_ds, val_ds, test_ds), each shuffled, batched and prefetched.
    """
    print('Splitting the DataFrame into training, validation and test')
    train_df, val_df, test_df = stratified_split(
        df,
        train_frac=0.7,
        val_frac=0.15,
        test_frac=0.15,
    )
    # Fix: wandb.init takes `project=`, not `project_name=`.
    run = wandb.init(project=PROJECT_NAME, name='split-dataset')
    split_dataset_artifact = wandb.Artifact('split-dataset-metadata', type='dataset')
    split_dataset_artifact.add(wandb.Table(dataframe=train_df), name='train_metadata')
    split_dataset_artifact.add(wandb.Table(dataframe=val_df), name='val_metadata')
    split_dataset_artifact.add(wandb.Table(dataframe=test_df), name='test_metadata')
    run.log_artifact(split_dataset_artifact)
    run.finish()
    # Fix: `img_size` was an undefined name here (NameError at runtime).
    # Derive it from the metadata medians exactly as `train` does for the
    # model input shape, so the two stay consistent.
    img_size: ImageSize = (
        int(df['height'].median()),
        int(df['width'].median()),
    )
    print('Batching and shuffling the datasets')
    train_ds = prepare_dataset(
        dataset_from_dataframe(train_df), img_size, batch_size=BATCH_SIZE
    )
    val_ds = prepare_dataset(
        dataset_from_dataframe(val_df), img_size, batch_size=BATCH_SIZE
    )
    test_ds = prepare_dataset(
        dataset_from_dataframe(test_df), img_size, batch_size=BATCH_SIZE
    )
    return train_ds, val_ds, test_ds
def build_image_model(input_shape: ImageInputShape) -> keras.Model:
    """Build the CNN feature extractor: four conv/pool stages, 512-d output."""
    conv_filters = (32, 64, 128, 128)
    model_layers = [layers.Input(shape=input_shape)]
    for n_filters in conv_filters:
        model_layers.append(layers.Conv2D(n_filters, (3, 3), activation='relu'))
        model_layers.append(layers.MaxPooling2D((2, 2)))
    model_layers.append(layers.Flatten())
    model_layers.append(layers.Dense(512, activation='relu'))
    img_model = models.Sequential(model_layers, name='image_classification')
    img_model.summary()
    return img_model
def build_text_model() -> keras.Model: | |
bert_model = TFBertModel.from_pretrained(BERT_BASE) | |
input_ids = layers.Input( | |
shape=(MAX_SEQUENCE_LENGHT,), dtype=tf.int32, name='input_ids' | |
) | |
attention_mask = layers.Input( | |
shape=(MAX_SEQUENCE_LENGHT,), dtype=tf.int32, name='attention_mask' | |
) | |
# The second element of the BERT output is the pooled output i.e. the | |
# representation of the [CLS] token | |
outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)[1] | |
text_model = models.Model( | |
inputs=[input_ids, attention_mask], | |
outputs=outputs, | |
name='bert' | |
) | |
text_model.summary() | |
return text_model | |
def build_multimodal_model(
    num_classes: int,
    img_input_shape: ImageInputShape
) -> keras.Model:
    """Combine CNN image features and BERT text features into one classifier.

    Args:
        num_classes: size of the softmax output layer.
        img_input_shape: (height, width, channels) shape for the image branch.

    Returns:
        A model taking [img_input, text_input_ids, text_input_mask] and
        producing class probabilities.
    """
    img_model = build_image_model(img_input_shape)
    text_model = build_text_model()
    img_input = layers.Input(shape=img_input_shape, name='img_input')
    text_input_ids = layers.Input(
        shape=(MAX_SEQUENCE_LENGHT,), dtype=tf.int32, name='text_input_ids'
    )
    text_input_mask = layers.Input(
        shape=(MAX_SEQUENCE_LENGHT,), dtype=tf.int32, name='text_input_mask'
    )
    img_features = img_model(img_input)
    text_features = text_model([text_input_ids, text_input_mask])
    # Fix: the original built these layers from `tf.keras` while the rest of
    # the file uses `tf_keras` — mixing the two packages can fail when they
    # resolve to different Keras versions. Use the file's aliases throughout.
    classification_layers = keras.Sequential([
        layers.Dense(512, activation='relu'),
        layers.Dense(num_classes, activation='softmax'),
    ], name='classification_layers')
    concat_features = layers.concatenate(
        [img_features, text_features], name='concatenate_features'
    )
    outputs = classification_layers(concat_features)
    multimodal_model = models.Model(
        inputs=[img_input, text_input_ids, text_input_mask],
        outputs=outputs,
        name='multimodal_document_page_classifier',
    )
    return multimodal_model
def train():
    """End-to-end training: load metadata, build the datasets, fit the model."""
    metadata_df: DataFrame[PageMetadata] = pd.read_csv(METADATA_FILEPATH)
    # Use the median page dimensions as the common model input size.
    median_height = int(metadata_df['height'].median())
    median_width = int(metadata_df['width'].median())
    img_size: ImageSize = (median_height, median_width)
    img_input_shape: ImageInputShape = img_size + (3,)
    # One class per sub-directory of the processed data folder.
    label_names: List[str] = sorted(
        [d.name for d in PROCESSED_DATA_DIR.iterdir() if d.is_dir()]
    )
    num_classes = len(label_names)
    train_ds, val_ds, test_ds = prepare_data(metadata_df)
    multimodal_model = build_multimodal_model(num_classes, img_input_shape)
    multimodal_model.summary()
    # NOTE(review): sparse_categorical_crossentropy expects integer labels;
    # confirm the `label` column is integer-encoded upstream.
    multimodal_model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    # Fix: the datasets are already batched by `prepare_dataset`; Keras
    # raises a ValueError if `batch_size` is also passed to `fit` with a
    # tf.data.Dataset, so the redundant argument was removed.
    multimodal_model.fit(
        train_ds,
        epochs=EPOCHS,
        validation_data=val_ds,
    )
# Fix: the original guard used assignment (`=`) instead of equality (`==`),
# which is a SyntaxError.
if __name__ == '__main__':
    train()
    # NOTE(review): `evaluate` is neither defined nor imported in this file —
    # confirm it exists elsewhere (or add the import / remove the call).
    evaluate()