import gradio as gr
import kornia as K
from kornia.contrib import ImageStitcher
import kornia.feature as KF
import torch
import numpy as np
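# Gradio demo: stitch two overlapping photos into a panorama using Kornia's
# ImageStitcher, with LoFTR feature matching and RANSAC homography estimation.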

def preprocess_image(img):
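    """Convert a Gradio-provided image (numpy HWC array or torch tensor) into a
    float32 (B, 3, H, W) tensor with values in [0, 1]."""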
    print(f"Input image type: {type(img)}")
    print(f"Input image shape: {img.shape if hasattr(img, 'shape') else 'No shape attribute'}")
    
    # Convert numpy array to Tensor and ensure correct shape
    if isinstance(img, np.ndarray):
        img = K.image_to_tensor(img, keepdim=False).float() / 255.0
    elif isinstance(img, torch.Tensor):
        img = img.float()
        if img.max() > 1.0:
            img = img / 255.0
    else:
        raise ValueError(f"Unsupported image type: {type(img)}")
    
    print(f"After conversion to tensor - shape: {img.shape}")
    
    # Ensure a 4D tensor (B, C, H, W)
    if img.ndim == 2:
        # (H, W) grayscale -> (1, 1, H, W)
        img = img.unsqueeze(0).unsqueeze(0)
    elif img.ndim == 3:
        if img.shape[0] in (1, 3):
            # (C, H, W) -> (1, C, H, W)
            img = img.unsqueeze(0)
        elif img.shape[-1] in (1, 3):
            # (H, W, C) channels-last -> (1, C, H, W)
            img = img.permute(2, 0, 1).unsqueeze(0)
        else:
            # (B, H, W) batch of grayscale -> (B, 1, H, W)
            img = img.unsqueeze(1)
    elif img.ndim == 4 and img.shape[1] not in (1, 3):
        # (B, H, W, C) channels-last -> (B, C, H, W)
        img = img.permute(0, 3, 1, 2)
    
    print(f"After ensuring 4D - shape: {img.shape}")
    
    # Ensure 3 channel image
    if img.shape[1] == 1:
        img = img.repeat(1, 3, 1, 1)
    elif img.shape[1] > 3:
        img = img[:, :3]  # Take only the first 3 channels if more than 3
    
    print(f"Final tensor shape: {img.shape}")
    return img

# Build the matcher and stitcher once so the pretrained LoFTR weights are not
# reloaded on every request. RANSAC is used to estimate the homography.
matcher = KF.LoFTR(pretrained='outdoor')
stitcher = ImageStitcher(matcher, estimator='ransac')

def inference(img_1, img_2):
    """Stitch two overlapping images and return the panorama as an HWC numpy array."""
    # Preprocess both inputs to (B, 3, H, W) float tensors in [0, 1]
    img_1 = preprocess_image(img_1)
    img_2 = preprocess_image(img_2)

    with torch.no_grad():
        result = stitcher(img_1, img_2)

    # Drop the batch dimension and convert back to HWC for display in Gradio
    return K.tensor_to_image(result[0])

examples = [
    ['examples/foto1B.jpg', 'examples/foto1A.jpg'],
]
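
# Build the Gradio UI: two image inputs, one output image, and a button that
# runs inference(); the example pair is clickable to prefill the inputs.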

with gr.Blocks(theme='huggingface') as demo_app:
    gr.Markdown("# Image Stitching using Kornia and LoFTR")
    with gr.Row():
        input_image1 = gr.Image(label="Input Image 1")
        input_image2 = gr.Image(label="Input Image 2")
    output_image = gr.Image(label="Output Image")
    stitch_button = gr.Button("Stitch Images")
    stitch_button.click(fn=inference, inputs=[input_image1, input_image2], outputs=output_image)
    gr.Examples(examples=examples, inputs=[input_image1, input_image2])

if __name__ == "__main__":
    demo_app.launch(share=True)