patch-conv-net / app.py
ariG23498's picture
ariG23498 HF staff
fix: reformat code for imagenette
310a06c
raw
history blame
809 Bytes
# import the necessary packages
from utilities import config
from utilities import model
from utilities import visualization
from tensorflow import keras
import gradio as gr
# Restore the three pretrained ImageNette sub-networks (stem, trunk and
# attention-pooling head), in that order, from the paths defined in
# `utilities.config`. `compile=False` loads the architecture and weights
# without compiling the model (no optimizer/loss configuration is restored);
# nothing in this script trains the models.
conv_stem, conv_trunk, conv_attn = (
    keras.models.load_model(saved_path, compile=False)
    for saved_path in (
        config.IMAGENETTE_STEM_PATH,
        config.IMAGENETTE_TRUNK_PATH,
        config.IMAGENETTE_ATTN_PATH,
    )
)
# Assemble the full PatchConvNet from the three pretrained sub-models
# loaded above.
patch_conv_net = model.PatchConvNet(
    stem=conv_stem,
    trunk=conv_trunk,
    attention_pooling=conv_attn,
)

# `PlotAttention` is passed directly as the Interface `fn`, so it is
# presumably a callable mapping an input image to an attention plot
# (NOTE(review): confirm against utilities.visualization).
plot_attention = visualization.PlotAttention(model=patch_conv_net)

# Gradio >= 3 exposes components at the top level (`gr.Image`). The legacy
# `gr.inputs` namespace was deprecated in 3.x and removed in 4.x, so use it
# only as a fallback on older Gradio releases that lack `gr.Image`.
_image_input_cls = getattr(gr, "Image", None) or gr.inputs.Image

# Build the interface first and launch it separately. Chaining `.launch()`
# onto the constructor would bind `iface` to whatever `launch()` returns
# instead of to the Interface object itself.
iface = gr.Interface(
    fn=plot_attention,
    inputs=[_image_input_cls(label="Input Image")],
    outputs="image",
)
iface.launch()