"""Streamlit app: train (or reload) a VGG16-based classifier on a Tiny-ImageNet sample.

Streamlit re-executes this whole script on every widget interaction. Each run:
loads a random sample of ``zh-plus/tiny-imagenet``, splits it 80/20, loads
``Silva.h5`` if it exists or otherwise trains a new head on a frozen VGG16 and
saves it, reports validation metrics and plots, and offers download and
Hugging Face upload buttons.
"""
import os

# Point the Hugging Face caches at /app/.cache BEFORE importing `datasets` and
# `huggingface_hub`: both libraries resolve their cache locations from these
# variables when they are imported, so assigning them after the imports has
# no effect.
os.environ["HF_HOME"] = "/app/.cache"
os.environ["HF_DATASETS_CACHE"] = "/app/.cache"

import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import streamlit as st
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_datasets as tfds
from datasets import load_dataset
from huggingface_hub import HfApi
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model, load_model

MODEL_PATH = "Silva.h5"
NUM_SAMPLES = 300
IMAGE_SIZE = (64, 64)
HF_REPO_ID = "scontess/trainigVVG16"

HF_TOKEN = os.getenv("HF_TOKEN")

# Hugging Face authentication. `api` stays None unless a token is present and
# accepted, so the upload step can refuse cleanly instead of failing on an
# undefined name.
api = None
if HF_TOKEN:
    hf_api = HfApi(token=HF_TOKEN)
    try:
        user_info = hf_api.whoami()
    except Exception as err:  # UI boundary: bad/expired token or network error
        st.error(f"❌ Autenticazione Hugging Face fallita: {err}")
    else:
        api = hf_api
        st.write(f"✅ Autenticato come {user_info.get('name', 'Utente sconosciuto')}")
else:
    st.warning("⚠️ Nessun token API trovato! Verifica il Secret nello Space.")


@st.cache_data(show_spinner=False)
def _load_samples(num_samples=NUM_SAMPLES, seed=42):
    """Return ``(X, y)`` for a seeded random sample of the tiny-imagenet train split.

    ``X`` is float32 with shape ``(count, 64, 64, 3)`` scaled to [0, 1], and
    ``y`` holds the integer class ids (int64). The result is cached by Streamlit
    so that widget reruns do not reload and re-decode the images.
    """
    dataset = load_dataset("zh-plus/tiny-imagenet", split="train")
    # Shuffle before sampling: if the split is stored in class order (not
    # verified here), its first rows would cover only one or two classes.
    count = min(num_samples, len(dataset))
    subset = dataset.shuffle(seed=seed).select(range(count))

    images, labels = [], []
    for sample in subset:
        # Assumes the `image` column decodes to a PIL image (the datasets
        # default for Image features). Force 3 channels so grayscale/palette
        # images do not produce rank-2 arrays that tf.image.resize rejects.
        rgb = np.asarray(sample["image"].convert("RGB"))
        images.append((tf.image.resize(rgb, IMAGE_SIZE) / 255.0).numpy())
        labels.append(int(sample["label"]))
    return np.asarray(images, dtype=np.float32), np.asarray(labels, dtype=np.int64)


# Dataset loading.
st.write(f"🔄 Caricamento di {NUM_SAMPLES} immagini da `tiny-imagenet`...")
X, y = _load_samples()

# Split: 80% training, 20% validation.
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

st.write(f"📊 **Training:** {X_train.shape[0]} immagini")
st.write(f"📊 **Validation:** {X_val.shape[0]} immagini")

# Training curves are kept in session_state so they survive reruns. A model
# loaded from disk in a fresh session has no history to plot.
history = st.session_state.get("history")

# Model: reuse Silva.h5 when present, otherwise train a new classification head.
if os.path.exists(MODEL_PATH):
    model = load_model(MODEL_PATH)
    st.write("✅ Modello `Silva.h5` caricato, nessun nuovo training necessario!")
else:
    st.write("🚀 Training in corso perché `Silva.h5` non esiste...")
    base_model = VGG16(weights="imagenet", include_top=False, input_shape=(*IMAGE_SIZE, 3))
    base_model.trainable = False  # freeze the whole convolutional backbone

    # sparse_categorical_crossentropy requires every label < number of output
    # units. Class ids in a small sample need not be contiguous from 0, and
    # some may appear only in validation, so size the softmax by the largest
    # id instead of by the count of distinct training labels.
    num_classes = int(y.max()) + 1

    x = Flatten()(base_model.output)
    x = Dense(256, activation="relu")(x)
    x = Dense(128, activation="relu")(x)
    output = Dense(num_classes, activation="softmax")(x)
    model = Model(inputs=base_model.input, outputs=output)
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

    # Train while monitoring the validation set.
    history = model.fit(X_train, y_train, epochs=10, validation_data=(X_val, y_val)).history
    st.session_state["history"] = history
    st.write("✅ Addestramento completato!")

    model.save(MODEL_PATH)
    st.write("✅ Modello salvato come `Silva.h5`!")

# Validation metrics.
y_pred_val = np.argmax(model.predict(X_val), axis=1)
accuracy_val = np.mean(y_pred_val == y_val)
# NOTE: RMSE between class ids is not a meaningful classification metric; it
# is kept only because the app already reports it.
rmse_val = np.sqrt(np.mean((y_pred_val - y_val) ** 2))
# zero_division=0 yields the same 0.0 scores as the default, without the
# warning for classes that are never predicted.
report_val = classification_report(y_val, y_pred_val, output_dict=True, zero_division=0)
recall_val = report_val["weighted avg"]["recall"]
precision_val = report_val["weighted avg"]["precision"]
f1_score_val = report_val["weighted avg"]["f1-score"]

st.write(f"📊 **Validation Accuracy:** {accuracy_val:.4f}")
st.write(f"📊 **Validation RMSE:** {rmse_val:.4f}")
st.write(f"📊 **Validation Precision:** {precision_val:.4f}")
st.write(f"📊 **Validation Recall:** {recall_val:.4f}")
st.write(f"📊 **Validation F1-Score:** {f1_score_val:.4f}")

# Button: confusion matrix on the validation set.
if st.button("🔎 Genera matrice di confusione per validazione"):
    conf_matrix_val = confusion_matrix(y_val, y_pred_val)
    fig, ax = plt.subplots(figsize=(10, 7))
    sns.heatmap(conf_matrix_val, annot=True, cmap="Blues", fmt="d", ax=ax)
    st.pyplot(fig)
    plt.close(fig)  # release the figure; reruns would otherwise accumulate them
    st.write("✅ Matrice di confusione generata!")

# Loss / accuracy curves (only available when this session trained the model).
if history is not None:
    fig, ax = plt.subplots(1, 2, figsize=(12, 5))
    ax[0].plot(history["loss"], label="Training Loss")
    ax[0].plot(history["val_loss"], label="Validation Loss")
    ax[1].plot(history["accuracy"], label="Training Accuracy")
    ax[1].plot(history["val_accuracy"], label="Validation Accuracy")
    ax[0].set_title("Loss durante il training e validazione")
    ax[1].set_title("Accuracy durante il training e validazione")
    ax[0].legend()
    ax[1].legend()
    st.pyplot(fig)
    plt.close(fig)
else:
    st.info("ℹ️ Nessuno storico di training disponibile: il modello è stato caricato da `Silva.h5`.")

# Button: download the saved model file.
if os.path.exists(MODEL_PATH):
    with open(MODEL_PATH, "rb") as f:
        st.download_button(
            label="📥 Scarica il modello Silva.h5",
            data=f,
            file_name="Silva.h5",
            mime="application/octet-stream",
        )


def upload_model():
    """Upload Silva.h5 to the Hugging Face repo ``scontess/trainigVVG16``.

    Shows an error instead of uploading when no authenticated ``api`` client
    is available (missing or rejected HF_TOKEN).
    """
    if api is None:
        st.error("❌ Impossibile caricare: nessun token Hugging Face valido.")
        return
    # NOTE(review): repo_type="space" uploads into a Space repository although
    # the button mentions a "Model Store" — confirm this is the intended target.
    api.upload_file(
        path_or_fileobj=MODEL_PATH,
        path_in_repo="Silva.h5",
        repo_id=HF_REPO_ID,
        repo_type="space",
    )
    st.success("✅ Modello 'Silva.h5' caricato su Hugging Face!")


# Button: upload the model to Hugging Face.
st.write("📥 Carica il modello Silva su Hugging Face")
if st.button("🚀 Carica Silva su Model Store"):
    upload_model()