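# NOTE(review): page residue captured along with this file: "Spaces: Runtime error".
# It is not part of the program; presumably the hosting page's status banner.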
"""Gradio demo that visualizes the attention of a PatchConvNet.

The script loads the three saved sub-models (stem, trunk and attention
pooling) from the paths declared in ``utilities.config``, assembles them
into a ``PatchConvNet``, wraps that network in a ``PlotAttention`` callable
and serves the callable through a Gradio ``Interface`` that takes one image
and returns one image.
"""

# import the necessary packages
import gradio as gr
from tensorflow import keras

from utilities import config
from utilities import model
from utilities import visualization

# load the models from disk; compile=False because the models are only used
# for inference here, so no optimizer/loss state needs to be restored
conv_stem = keras.models.load_model(
    config.IMAGENETTE_STEM_PATH,
    compile=False,
)
conv_trunk = keras.models.load_model(
    config.IMAGENETTE_TRUNK_PATH,
    compile=False,
)
conv_attn = keras.models.load_model(
    config.IMAGENETTE_ATTN_PATH,
    compile=False,
)

# create the patch conv net
patch_conv_net = model.PatchConvNet(
    stem=conv_stem,
    trunk=conv_trunk,
    attention_pooling=conv_attn,
)

# get the plot attention function
plot_attention = visualization.PlotAttention(model=patch_conv_net)

# The `gr.inputs` namespace was deprecated in Gradio 3 and removed in
# Gradio 4. Prefer the top-level component and fall back to the legacy one
# only on releases that predate it, so the script runs on either API.
_image_input = getattr(gr, "Image", None) or gr.inputs.Image

# Build the interface first and launch it afterwards, so that `iface` names
# the Interface object itself rather than whatever `launch()` returns.
iface = gr.Interface(
    fn=plot_attention,
    inputs=[_image_input(label="Input Image")],
    outputs="image",
)

# Launched at module level on purpose: the original script started the app
# as a side effect of being run, and that behavior is kept.
iface.launch()