"""Streamlit front end for the Intel Image Classification CNN.

Loads a trained ``CNNModel`` checkpoint (cached across reruns), lets the user
upload an image, and displays the predicted class.
"""
import os

import streamlit as st
import torch
from PIL import Image
from torchvision import transforms  # noqa: F401  (kept; currently unused here)

from models.cnn import CNNModel
from utils.transforms import get_transforms

# NOTE(review): this runs after `import streamlit`; if Streamlit only reads
# this variable at import time, setting it here has no effect — confirm.
os.environ["STREAMLIT_ROOT"] = "/tmp/.streamlit"

DEFAULT_MODEL_PATH = 'saved_models/cnn_model.pth'


@st.cache_resource
def load_model(model_path=DEFAULT_MODEL_PATH):
    """Load the trained CNN checkpoint and prepare it for inference.

    Decorated with ``st.cache_resource`` so the checkpoint is loaded once per
    distinct ``model_path`` instead of on every script rerun.

    Args:
        model_path: Path to a checkpoint dict with ``'class_names'`` and
            ``'model_state_dict'`` entries.

    Returns:
        ``(model, class_names, device)`` — the model is moved to ``device``
        and switched to eval mode.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
        KeyError: if the checkpoint lacks one of the expected keys.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device)
    class_names = checkpoint['class_names']
    model = CNNModel(num_classes=len(class_names))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    return model, class_names, device


def _predict(image, model, class_names, device):
    """Return the predicted class name for a single RGB PIL image."""
    transform = get_transforms(train=False)
    # unsqueeze(0) adds the batch dimension the model expects.
    image_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image_tensor)
    predicted_idx = torch.argmax(output, 1).item()
    return class_names[predicted_idx]


st.title("📸 Intel Image Classification")
st.write("Upload an image to classify it into one of the image categories: buildings, forest, glacier, mountain, sea, or street.")

try:
    model, class_names, device = load_model()
except FileNotFoundError:
    # Without a checkpoint nothing below can run; show a readable message
    # instead of a raw traceback.
    st.error(f"Model checkpoint not found at '{DEFAULT_MODEL_PATH}'. Train and save the model first.")
    st.stop()

uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        # convert() forces a full decode, so truncated files fail here too.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # Covers PIL.UnidentifiedImageError (corrupt or mislabelled file)
        # and truncated-image errors.
        st.error("Could not read the uploaded file as an image.")
    else:
        st.image(image, caption="Uploaded Image", use_container_width=True)
        predicted_class = _predict(image, model, class_names, device)
        st.success(f"Predicted class: {predicted_class}")