|
import streamlit as st |
|
from transformers import pipeline |
|
from PIL import Image |
|
from datasets import load_dataset, Image, list_datasets |
|
from PIL import Image |
|
|
|
# Hugging Face Hub ids of the image-classification models offered in the UI.
MODELS = ["google/vit-base-patch16-224", "nateraw/vit-age-classifier"]

# Hugging Face Hub ids of the image datasets offered in the UI.
DATASETS = ["Nunt/testedata", "Nunt/backup_leonardo_2024-02-01"]

# Presumably the number of top labels to report per image (inferred from the
# name; confirm).
MAX_N_LABELS = 5

# Name of the dataset split whose images are classified.
SPLIT_TO_CLASSIFY = "pasta"
|
|
|
|
|
|
|
|
|
|
|
|
|
# Page layout. Streamlit renders elements in the order they are created, so
# the containers are created first (top to bottom) and the two columns are
# then created *inside* CONTAINER_BODY. Creating them with a bare
# st.columns() before the containers would put them above the title.
CONTAINER_TOP = st.container()

CONTAINER_BODY = st.container()

CONTAINER_FULL = st.container()

CONTAINER_LOOP = st.container()

# Wide left column (3) for the controls, narrow right column (1).
COLS = CONTAINER_BODY.columns([3, 1])

# COL1/COL2 used to be empty-string placeholders, which cannot be used as
# containers; bind them to the real columns instead.
COL1, COL2 = COLS
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_one_image(classifier_model, dataset_to_classify):
    """Placeholder for single-image classification.

    Neither argument is used yet; the function only reports completion.

    Args:
        classifier_model: Unused.
        dataset_to_classify: Unused.

    Returns:
        The literal string ``"done"``.
    """
    status = "done"
    return status
|
|
|
|
|
|
|
def classify_full_dataset(shosen_dataset_name, chosen_model_name,
                          split=SPLIT_TO_CLASSIFY,
                          config_name="testedata_readme"):
    """Classify every image in one split of a Hugging Face dataset.

    Loads ``shosen_dataset_name`` with configuration ``config_name`` and builds
    an image-classification pipeline for ``chosen_model_name`` once. Then, for
    each record of ``split``, it shows the record's ``"image"`` and the top
    ``MAX_N_LABELS`` predictions inside ``CONTAINER_LOOP``.

    Args:
        shosen_dataset_name: Hub id of the dataset to classify.
        chosen_model_name: Hub id of the image-classification model.
        split: Split to iterate over (defaults to ``SPLIT_TO_CLASSIFY``).
        config_name: Dataset configuration passed to ``load_dataset``.

    Returns:
        The number of images classified.
    """
    dataset = load_dataset(shosen_dataset_name, config_name)

    # Build the pipeline once; reloading model weights per image is very slow.
    classifier_pipeline = pipeline("image-classification", model=chosen_model_name)

    image_count = 0
    with CONTAINER_LOOP:
        for record in dataset[split]:
            image_object = record["image"]
            st.image(image_object, caption="Uploaded Image", width=300)

            classification_result = classifier_pipeline(image_object, top_k=MAX_N_LABELS)
            st.write(classification_result)

            image_count += 1
            st.write(f"Image count: {image_count}")

    return image_count
|
|
|
|
|
def make_template():
    """Enter the page's layout containers in their (presumably) intended nesting.

    NOTE(review): the containers and columns are created at module level, and
    Streamlit positions elements when they are created. Entering them here
    therefore does not appear to move or nest them. The bare container
    expressions below have no layout effect of their own; if Streamlit "magic"
    is active they may be passed to st.write. TODO confirm whether this
    function is needed at all.
    """
    with CONTAINER_FULL:
        # Presumably meant to place the top and body sections inside the
        # full-page container.
        CONTAINER_TOP
        CONTAINER_BODY
    with CONTAINER_BODY:
        with COLS[1]:
            # Presumably meant to place the classification output in the
            # narrow right-hand column.
            CONTAINER_LOOP
|
|
|
|
|
def main():
    """Render the bulk image-classification demo page.

    Shows the title, a model picker, a dataset picker and a "Classify images"
    button. When the button is pressed, it classifies the selected dataset
    with the selected model and writes the number of images processed to
    ``CONTAINER_LOOP``.
    """
    make_template()

    with CONTAINER_TOP:
        st.write("# Bulk Image Classification DEMO")

    # The wide left column of the [3, 1] layout hosts the controls. Inside its
    # `with` block, plain st.* calls render into that column.
    left_col = COLS[0]

    with CONTAINER_BODY:
        with left_col:
            st.markdown("This app uses several 🤗 models to classify images stored in 🤗 datasets.")
            st.write("Soon we will have a dataset template")

            chosen_model_name = st.selectbox("Select the model to use", MODELS, index=0)
            if chosen_model_name is not None:
                st.write("You selected", chosen_model_name)

            shosen_dataset_name = st.selectbox("Select the dataset to use", DATASETS, index=0)
            if shosen_dataset_name is not None:
                st.write("You selected", shosen_dataset_name)

            if chosen_model_name is not None and shosen_dataset_name is not None:
                if st.button("Classify images"):
                    classification_result = classify_full_dataset(shosen_dataset_name, chosen_model_name)
                    CONTAINER_LOOP.write(f"Classification result: {classification_result}")
|
|
|
|
|
|
|
|
|
# Entry point: render the app when this file is executed as the main script
# (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()