"""Streamlit app that estimates building footprints in satellite imagery.

An uploaded image is split into fixed-size tiles, each tile is segmented by a
ReUNet-CBAM model, the per-tile masks are stitched back together, and the
result is shown next to an overlay annotated with building areas (sqft).
"""

import csv
import os
import shutil
import sys
import tempfile
import time
from datetime import datetime

import cv2
import numpy as np
import streamlit as st
import torch
from PIL import Image

# Adjust import paths as needed
sys.path.append('Utils')
sys.path.append('model')

from model.CBAM.reunet_cbam import reunet_cbam
from model.transform import transforms
from model.unet import UNET
from Utils.area import pixel_to_sqft, process_and_overlay_image
from Utils.convert import read_pansharpened_rgb

# Define base directory for Hugging Face Spaces
BASE_DIR = "/Data"

# Define subdirectories
UPLOAD_DIR = os.path.join(BASE_DIR, "uploaded_images")
MASK_DIR = os.path.join(BASE_DIR, "generated_masks")
PATCHES_DIR = os.path.join(BASE_DIR, "patches")
PRED_PATCHES_DIR = os.path.join(BASE_DIR, "pred_patches")
CSV_LOG_PATH = os.path.join(BASE_DIR, "image_log.csv")

# Create directories
for directory in [UPLOAD_DIR, MASK_DIR, PATCHES_DIR, PRED_PATCHES_DIR]:
    os.makedirs(directory, exist_ok=True)


@st.cache_resource
def load_model():
    """Build the ReUNet-CBAM network and load its weights from ``latest.pth``.

    Wrapped in ``st.cache_resource`` so the model is built once and reused on
    every Streamlit rerun instead of being reloaded per interaction.
    """
    model = reunet_cbam()
    checkpoint = torch.load('latest.pth', map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model


model = load_model()


def predict(image, net=None):
    """Run the segmentation model on one transformed image tensor.

    Args:
        image: Tensor produced by ``transforms``; a batch dimension is added here.
        net: Model to run. Defaults to the module-level ``model``.

    Returns:
        The squeezed model output as a NumPy array (presumably an (H, W)
        probability map -- confirm against the model's output head).
    """
    if net is None:
        net = model
    with torch.no_grad():
        output = net(image.unsqueeze(0))
    return output.squeeze().cpu().numpy()


def split_image(image, patch_size=512):
    """Split an image into non-overlapping tiles of at most patch_size pixels.

    Tiles on the right/bottom border are cropped to the image, so they may be
    smaller than patch_size.

    Returns:
        A list of ``(filename, patch)`` tuples, where filename is
        ``patch_{y}_{x}.png`` and (y, x) is the tile's top-left pixel.
    """
    h, w = image.shape[:2]
    patches = []
    for y in range(0, h, patch_size):
        for x in range(0, w, patch_size):
            patch = image[y:min(y + patch_size, h), x:min(x + patch_size, w)]
            patches.append((f"patch_{y}_{x}.png", patch))
    return patches


def _parse_patch_coords(filename):
    """Return the (y, x) offset encoded in a ``patch_{y}_{x}.png`` filename."""
    x = y = None
    for part in filename.split("_"):
        if part.endswith(".png"):
            x = int(part.split(".")[0])
        elif part.isdigit():
            y = int(part)
    if x is None or y is None:
        raise ValueError(f"Invalid filename: {filename}")
    return y, x


def merge(patch_folder, dest_image='out.png', image_shape=None):
    """Stitch the ``patch_{y}_{x}.png`` tiles in patch_folder into one image.

    Args:
        patch_folder: Directory holding the tiles to merge.
        dest_image: Path the merged image is written to.
        image_shape: Shape of the source image; only (H, W) is used and the
            output always has 3 channels.

    Returns:
        The merged uint8 (H, W, 3) array.

    Raises:
        ValueError: if image_shape is missing, a tile name cannot be parsed,
            or a tile cannot be read.
    """
    if image_shape is None:
        raise ValueError("image_shape is required to size the merged image")
    height, width = image_shape[:2]
    merged = np.zeros((height, width, 3), dtype=np.uint8)

    for filename in os.listdir(patch_folder):
        if not filename.endswith(".png"):
            continue
        patch = cv2.imread(os.path.join(patch_folder, filename))
        if patch is None:
            raise ValueError(f"Failed to read patch {filename}")
        patch_height, patch_width = patch.shape[:2]
        y, x = _parse_patch_coords(filename)

        # Shift a tile that would overrun the canvas back so it ends at the
        # border, instead of failing on a shape mismatch in the assignment.
        if x + patch_width > width:
            x = width - patch_width
        if y + patch_height > height:
            y = height - patch_height

        merged[y:y + patch_height, x:x + patch_width, :] = patch

    cv2.imwrite(dest_image, merged)
    return merged


def process_large_image(model, image_path, patch_size=512):
    """Segment an image tile by tile and return the stitched full-size mask.

    Tiles are written to a private temporary directory under PRED_PATCHES_DIR,
    so leftovers from an earlier failed run (or another session) are never
    merged into this mask. The directory is removed afterwards, even on error.

    Args:
        model: Segmentation network applied to every tile.
        image_path: Path of an image readable by ``cv2.imread``.
        patch_size: Tile edge length in pixels.

    Returns:
        uint8 (H, W, 3) mask that is 255 where the model output exceeds 0.5.

    Raises:
        ValueError: if the image cannot be read.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Failed to read image from {image_path}")

    h, w, _ = img.shape
    st.write(f"Processing image of size {w}x{h}")

    patch_dir = tempfile.mkdtemp(prefix="run_", dir=PRED_PATCHES_DIR)
    try:
        for filename, patch in split_image(img, patch_size):
            patch_pil = Image.fromarray(cv2.cvtColor(patch, cv2.COLOR_BGR2RGB))
            prediction = predict(transforms(patch_pil), model)
            mask = (prediction > 0.5).astype(np.uint8) * 255
            cv2.imwrite(os.path.join(patch_dir, filename), mask)

        return merge(patch_dir, dest_image='merged_mask.png', image_shape=img.shape)
    finally:
        shutil.rmtree(patch_dir, ignore_errors=True)


def log_image_details(image_id, image_filename, mask_filename):
    """Append one row describing a processed image to the CSV log.

    The header row is written when the log is missing or empty. S.No is the
    1-based index of the new data row.
    """
    current_time = datetime.now()
    date_str = current_time.strftime('%Y-%m-%d')
    time_str = current_time.strftime('%H:%M:%S')

    # An existing but empty file needs a header too, and must not yield S.No 0.
    has_rows = os.path.exists(CSV_LOG_PATH) and os.path.getsize(CSV_LOG_PATH) > 0
    if has_rows:
        # The existing rows include the header, so their count equals the
        # next 1-based serial number.
        with open(CSV_LOG_PATH, mode='r', newline='') as f:
            sno = sum(1 for _ in csv.reader(f))
    else:
        sno = 1

    with open(CSV_LOG_PATH, mode='a', newline='') as file:
        writer = csv.writer(file)
        if not has_rows:
            writer.writerow(['S.No', 'Date', 'Time', 'Image ID', 'Image Filename', 'Mask Filename'])
        writer.writerow([sno, date_str, time_str, image_id, image_filename, mask_filename])


def upload_page():
    """Render the upload view: save, convert, segment and log a new image.

    ``st.session_state.file_uploaded`` prevents reruns from re-processing the
    same upload until the user goes back from the result page.
    """
    for key, default in (('file_uploaded', False), ('filename', None), ('mask_filename', None)):
        if key not in st.session_state:
            st.session_state[key] = default

    image = st.file_uploader('Choose a satellite image', type=['jpg', 'png', 'jpeg', 'tiff', 'tif'])

    if image is not None and not st.session_state.file_uploaded:
        try:
            bytes_data = image.getvalue()
            timestamp = int(time.time())
            file_extension = os.path.splitext(image.name)[1].lower()
            is_tiff = file_extension in ('.tiff', '.tif')

            if is_tiff:
                filename = f"image_{timestamp}.tif"
                converted_filename = f"image_{timestamp}_converted.png"
            else:
                filename = f"image_{timestamp}.png"
                converted_filename = filename

            filepath = os.path.join(UPLOAD_DIR, filename)
            converted_filepath = os.path.join(UPLOAD_DIR, converted_filename)

            with open(filepath, "wb") as f:
                f.write(bytes_data)
            st.success(f"Image saved to {filepath}")

            if is_tiff:
                # GeoTIFFs are converted to an 8-bit RGB PNG before segmentation.
                st.info('Processing GeoTIFF image...')
                rgb_image = read_pansharpened_rgb(filepath)
                cv2.imwrite(converted_filepath, cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR))
                st.success(f'GeoTIFF converted to 8-bit image and saved as {converted_filename}')
                img = Image.open(converted_filepath)
            else:
                img = Image.open(filepath)

            st.image(img, caption='Uploaded Image', use_column_width=True)
            st.success(f'Image processed and saved as {converted_filename}')

            # Store the converted image's filename (relative to UPLOAD_DIR).
            st.session_state.filename = converted_filename

            st.write("Processing image...")
            with st.spinner('Analyzing...'):
                # Temporary tile files are cleaned up inside process_large_image.
                full_mask = process_large_image(model, converted_filepath)

            mask_filename = f"mask_{timestamp}.png"
            cv2.imwrite(os.path.join(MASK_DIR, mask_filename), full_mask)
            st.session_state.mask_filename = mask_filename
            st.success("Image processed successfully")

            log_image_details(timestamp, converted_filename, mask_filename)
            st.session_state.file_uploaded = True
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            st.error("Please check the logs for more details.")
            print(f"Error in upload_page: {str(e)}")  # This will appear in the Streamlit logs

    if st.session_state.file_uploaded and st.button('View result'):
        if st.session_state.filename is None:
            st.error("Please upload an image before viewing the result.")
        else:
            st.success('Image analyzed')
            st.session_state.page = 'result'
            st.rerun()


def _reset_to_upload():
    """Clear the per-upload session state and switch back to the upload page."""
    st.session_state.page = 'upload'
    st.session_state.file_uploaded = False
    st.session_state.filename = None
    st.session_state.mask_filename = None
    st.rerun()


def result_page():
    """Render the original image, the predicted mask and the area overlay."""
    st.title('Analysis Result')

    # upload_page initialises these keys to None, so test the values rather
    # than key presence; os.path.join(..., None) would raise otherwise.
    if st.session_state.get('filename') is None or st.session_state.get('mask_filename') is None:
        st.error("No image or mask file found. Please upload and process an image first.")
        if st.button('Back to Upload'):
            _reset_to_upload()
        return

    col1, col2 = st.columns(2)

    original_img_path = os.path.join(UPLOAD_DIR, st.session_state.filename)
    if os.path.exists(original_img_path):
        col1.image(Image.open(original_img_path), caption='Original Image', use_column_width=True)
    else:
        col1.error(f"Original image file not found: {original_img_path}")

    mask_path = os.path.join(MASK_DIR, st.session_state.mask_filename)
    if os.path.exists(mask_path):
        col2.image(Image.open(mask_path), caption='Predicted Mask', use_column_width=True)
    else:
        col2.error(f"Predicted mask file not found: {mask_path}")

    st.subheader("Overlay with Area of Buildings (sqft)")

    if os.path.exists(original_img_path) and os.path.exists(mask_path):
        original_np = cv2.imread(original_img_path)
        mask_np = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        # Ensure mask is binary
        _, mask_np = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)

        if original_np.shape[:2] != mask_np.shape[:2]:
            # cv2.resize takes (width, height); nearest-neighbour keeps the
            # mask strictly binary instead of blending edge pixels.
            mask_np = cv2.resize(
                mask_np,
                (original_np.shape[1], original_np.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )

        overlay_img = process_and_overlay_image(original_np, mask_np, 'output.png')
        st.image(overlay_img, caption='Overlay Image', use_column_width=True)
    else:
        st.error("Image or mask file not found for overlay.")

    if st.button('Back to Upload'):
        _reset_to_upload()


def main():
    """Entry point: route to the upload or result page based on session state."""
    st.title('Building area estimation')

    if 'page' not in st.session_state:
        st.session_state.page = 'upload'

    if st.session_state.page == 'upload':
        upload_page()
    elif st.session_state.page == 'result':
        result_page()


if __name__ == '__main__':
    main()