resolverkatla's picture
Upload 11 files
a3fdab1 verified
raw
history blame
1.43 kB
import streamlit as st
from PIL import Image
import torch
from torchvision import transforms
from models.cnn import CNNModel
from utils.transforms import get_transforms
@st.cache_resource
def load_model(model_path='saved_models/cnn_model.pth'):
    """Restore the trained CNN from a checkpoint and ready it for inference.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the loaded
    model instead of reading the checkpoint from disk each time.

    Args:
        model_path: Location of a checkpoint written with ``torch.save`` that
            contains the keys ``'class_names'`` and ``'model_state_dict'``.

    Returns:
        Tuple ``(model, class_names, device)``: the ``CNNModel`` moved to
        ``device`` and switched to eval mode, the class labels stored in the
        checkpoint (one per output class), and the ``torch.device`` chosen.
    """
    use_gpu = torch.cuda.is_available()
    device = torch.device("cuda" if use_gpu else "cpu")

    # map_location remaps stored tensors onto the selected device, so a
    # checkpoint saved on a GPU can still be opened on a CPU-only host.
    state = torch.load(model_path, map_location=device)

    labels = state['class_names']
    # The classifier head is sized from the saved label list.
    net = CNNModel(num_classes=len(labels))
    net.load_state_dict(state['model_state_dict'])
    net.to(device)
    # Inference mode: disables training-only behaviour (dropout, batch-norm
    # statistic updates) if CNNModel contains such layers.
    net.eval()
    return net, labels, device
st.title("📸 Intel Image Classification")
st.write(
    "Upload an image to classify it into one of the image categories: "
    "buildings, forest, glacier, mountain, sea, or street."
)

# load_model is wrapped in st.cache_resource, so after the first run this
# returns the already-loaded model instead of re-reading the checkpoint.
model, class_names, device = load_model()

uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if uploaded_file:
    try:
        # convert("RGB") forces a full decode and normalises grayscale, RGBA
        # and palette images to 3-channel RGB before the transform pipeline.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # The uploader only filters by extension, so a renamed or truncated
        # file can reach here; Pillow raises UnidentifiedImageError (an
        # OSError subclass) or OSError. Report it instead of crashing the app.
        st.error("Could not read the uploaded file as an image. Please upload a valid JPG or PNG file.")
    else:
        st.image(image, caption="Uploaded Image", use_container_width=True)

        transform = get_transforms(train=False)
        # unsqueeze(0) adds a leading batch dimension: the model sees a batch of one.
        image_tensor = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():  # inference only; skip autograd graph construction
            output = model(image_tensor)

        # NOTE(review): argmax over dim 1 assumes output is (batch, num_classes)
        # scores whose column order matches class_names — confirm in CNNModel.
        predicted_idx = torch.argmax(output, 1).item()
        predicted_class = class_names[predicted_idx]
        st.success(f"Predicted class: {predicted_class}")