# app.py - Streamlit app for Intel Image Classification.
# Snapshot of commit 2384571 ("Update app.py") by resolverkatla; 1.45 kB.
import streamlit as st
import os
from PIL import Image
import torch
from torchvision import transforms
from models.cnn import CNNModel
from utils.transforms import get_transforms
# Point STREAMLIT_ROOT at a writable temp directory.
# NOTE(review): this executes after `import streamlit` above (and, under
# `streamlit run`, after Streamlit is already running), so it may have no
# effect on Streamlit itself - confirm what actually reads this variable.
os.environ.update({"STREAMLIT_ROOT": "/tmp/.streamlit"})
@st.cache_resource
def load_model(model_path='saved_models/cnn_model.pth'):
    """Rebuild the CNN classifier from a saved checkpoint, ready for inference.

    ``st.cache_resource`` lets Streamlit keep the returned objects between
    script reruns, so the checkpoint is read once per distinct ``model_path``
    rather than on every widget interaction.

    Args:
        model_path: Location of the checkpoint file. The object loaded from
            it must support ``['class_names']`` and ``['model_state_dict']``
            lookups.

    Returns:
        Tuple ``(model, class_names, device)``:
            * the ``CNNModel`` with weights restored, moved to ``device`` and
              switched to eval mode;
            * ``checkpoint['class_names']`` exactly as stored (presumably a
              list of label strings indexed by class id - confirm against
              the training script);
            * the ``torch.device`` in use: CUDA when available, else CPU.
    """
    if torch.cuda.is_available():
        target_device = torch.device("cuda")
    else:
        target_device = torch.device("cpu")

    # map_location remaps stored tensors onto the selected device, so a
    # checkpoint saved on a GPU can still be loaded on a CPU-only host.
    ckpt = torch.load(model_path, map_location=target_device)
    labels = ckpt['class_names']

    # Size the classifier from the stored labels so it agrees with the saved
    # weights (presumably the width of the final layer - see models.cnn).
    net = CNNModel(num_classes=len(labels))
    net.load_state_dict(ckpt['model_state_dict'])
    # nn.Module.to() returns the module itself, so eval() can be chained.
    net.to(target_device).eval()

    return net, labels, target_device
# "\U0001F4F8" is the camera-with-flash emoji. The literal was previously
# stored as encoding mojibake; the escape keeps the source ASCII-safe.
st.title("\U0001F4F8 Intel Image Classification")
st.write("Upload an image to classify it into one of the image categories: buildings, forest, glacier, mountain, sea, or street.")

# load_model is wrapped in @st.cache_resource, so reruns reuse the same
# model instead of reloading the checkpoint.
model, class_names, device = load_model()

uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        # convert() forces Pillow to decode the pixel data, so a corrupt or
        # non-image upload fails here instead of deeper in the pipeline.
        image = Image.open(uploaded_file).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # UnidentifiedImageError and truncated-file errors subclass OSError;
        # DecompressionBombError guards against absurdly large uploads.
        st.error("Could not read the uploaded file as an image. Please upload a valid JPG or PNG file.")
    else:
        st.image(image, caption="Uploaded Image", use_container_width=True)

        # Inference-time preprocessing (presumably deterministic, without
        # training augmentation - confirm in utils.transforms).
        transform = get_transforms(train=False)
        # unsqueeze(0) adds a leading batch dimension of size 1; the tensor
        # is then moved to the same device as the model.
        image_tensor = transform(image).unsqueeze(0).to(device)

        # No gradients are needed for a single forward pass.
        with torch.no_grad():
            output = model(image_tensor)
        # Highest-scoring class index along dim 1 for the single image.
        predicted_idx = torch.argmax(output, 1).item()
        predicted_class = class_names[predicted_idx]
        st.success(f"Predicted class: {predicted_class}")