File size: 4,591 Bytes
ece05f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import spaces
import streamlit as st
import torch
from huggingface_hub import snapshot_download
from txt2panoimg import Text2360PanoramaImagePipeline
from img2panoimg import Image2360PanoramaImagePipeline
from PIL import Image
from streamlit_pannellum import streamlit_pannellum

# Custom CSS to make the UI more attractive
st.markdown("""
<style>
    .stApp {
        max-width: 1200px;
        margin: 0 auto;
    }
    .main {
        background-color: #f0f2f6;
    }
    h1 {
        color: #1E3A8A;
        text-align: center;
        padding: 20px 0;
        font-size: 2.5rem;
    }
    .stTabs {
        background-color: white;
        padding: 20px;
        border-radius: 10px;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    }
    .stButton>button {
        background-color: #1E3A8A;
        color: white;
        font-weight: bold;
    }
    .viewer-column {
        background-color: white;
        padding: 20px;
        border-radius: 10px;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    }
</style>
""", unsafe_allow_html=True)


# Streamlit re-executes this whole script on every widget interaction, so the
# heavy setup below is wrapped in st.cache_resource: the model snapshot is
# fetched and both fp16 pipelines are built once per server process instead of
# on every rerun.
@st.cache_resource
def _load_pipelines():
    """Download the model snapshot and build both panorama pipelines.

    Returns:
        A ``(model_path, text_pipeline, image_pipeline)`` tuple.
    """
    path = snapshot_download("archerfmy0831/sd-t2i-360panoimage")
    text_pipeline = Text2360PanoramaImagePipeline(path, torch_dtype=torch.float16)
    image_pipeline = Image2360PanoramaImagePipeline(path, torch_dtype=torch.float16)
    return path, text_pipeline, image_pipeline


@st.cache_resource
def _load_default_mask():
    """Load the bundled image-to-panorama mask as an RGB image."""
    # `with` closes the underlying file; convert() returns a fully loaded copy.
    with Image.open("i2p-mask.jpg") as mask_file:
        return mask_file.convert("RGB")


# Download the model and initialize pipelines (cached across reruns)
model_path, txt2panoimg, img2panoimg = _load_pipelines()

# Load the default mask image
default_mask = _load_default_mask()

@spaces.GPU(duration=200)
def text_to_pano(prompt, upscale):
    """Run the text-to-panorama pipeline.

    Args:
        prompt: Text describing the scene to generate.
        upscale: Forwarded to the pipeline as its ``upscale`` option.

    Returns:
        Whatever ``txt2panoimg`` returns for the request (presumably a PIL
        image — TODO confirm against the pipeline implementation).
    """
    return txt2panoimg({'prompt': prompt, 'upscale': upscale})

@spaces.GPU(duration=200)
def image_to_pano(image, mask, prompt, upscale):
    """Run the image-to-panorama pipeline on a single input view.

    Args:
        image: PIL image to extend into a panorama.
        mask: Optional PIL mask image; when None the bundled default mask
            (``default_mask``) is used.
        prompt: Text prompt passed to the pipeline.
        upscale: Forwarded to the pipeline as its ``upscale`` option.

    Returns:
        Whatever ``img2panoimg`` returns for the request (presumably a PIL
        image — TODO confirm against the pipeline implementation).
    """
    size = (512, 512)
    # Uploaded PNGs may be RGBA or palette images. Normalize them to RGB, the
    # same mode the default mask is loaded in, so the pipeline always gets
    # 3-channel inputs regardless of what the user uploaded.
    image = image.convert("RGB").resize(size)
    source_mask = default_mask if mask is None else mask.convert("RGB")
    input_data = {
        'prompt': prompt,
        'image': image,
        'mask': source_mask.resize(size),
        'upscale': upscale,
    }
    return img2panoimg(input_data)

# Page heading, then one tab per generation mode. The tab handles are used
# by the two UI sections further down the script.
st.title("360° Panorama Image Generation")
TAB_LABELS = ["Text to 360° Panorama", "Image to 360° Panorama"]
tab1, tab2 = st.tabs(TAB_LABELS)

def panorama_data_url(image):
    """Return a value usable as pannellum's ``panorama`` image source.

    Strings (plain URLs or existing data URLs) are returned unchanged. Any
    other value is treated as an image. If it has no ``save`` method, it is
    wrapped with ``Image.fromarray`` (presumably a numpy array — TODO confirm
    what the pipelines return). The image is then encoded as a base64 JPEG
    data URL.

    Streamlit component arguments are sent to the browser as JSON, so an
    image object cannot be placed in the viewer config directly.
    """
    if isinstance(image, str):
        return image
    import base64
    import io

    if not hasattr(image, "save"):
        image = Image.fromarray(image)
    buffer = io.BytesIO()
    # JPEG has no alpha channel, so flatten to RGB before encoding.
    image.convert("RGB").save(buffer, format="JPEG", quality=95)
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return f"data:image/jpeg;base64,{encoded}"


def display_panorama(image):
    """Render *image* in an interactive equirectangular pannellum viewer.

    Args:
        image: A URL/data-URL string, or an image object that
            ``panorama_data_url`` can encode.
    """
    streamlit_pannellum(
        config={
            "default": {
                "firstScene": "generated",
            },
            "scenes": {
                "generated": {
                    "title": "Generated Panorama",
                    "type": "equirectangular",
                    "panorama": panorama_data_url(image),
                    "autoLoad": True,
                }
            }
        }
    )

with tab1:
    col1, col2 = st.columns([1, 1])

    with col1:
        st.subheader("Input")
        # Explicit keys keep these widgets distinct from the identically
        # labelled ones in the image-to-panorama tab. Without a key, Streamlit
        # derives the widget ID from the label and parameters, so two
        # `st.text_area("Enter your prompt", height=100)` calls collide and
        # raise a duplicate-widget error.
        t2p_input = st.text_area("Enter your prompt", height=100, key="t2p_prompt")
        t2p_upscale = st.checkbox("Upscale (requires >16GB GPU)", key="t2p_upscale")
        generate_button = st.button("Generate Panorama", key="t2p_generate")

    with col2:
        st.subheader("Output")
        output_placeholder = st.empty()
        viewer_placeholder = st.empty()

    if generate_button:
        # Validate before spending GPU time, mirroring the image tab's check.
        if not t2p_input.strip():
            st.error("Please enter a prompt.")
        else:
            with st.spinner("Generating your 360° panorama..."):
                output = text_to_pano(t2p_input, t2p_upscale)
                output_placeholder.image(output, caption="Generated 360° Panorama", use_column_width=True)
                with viewer_placeholder.container():
                    display_panorama(output)

with tab2:
    col1, col2 = st.columns([1, 1])

    with col1:
        st.subheader("Input")
        i2p_image = st.file_uploader("Upload Input Image", type=["png", "jpg", "jpeg"])
        i2p_mask = st.file_uploader("Upload Mask Image (Optional)", type=["png", "jpg", "jpeg"])
        # The key prevents a duplicate-widget-ID clash with the identically
        # labelled prompt box in the text-to-panorama tab.
        i2p_prompt = st.text_area("Enter your prompt", height=100, key="i2p_prompt")
        i2p_upscale = st.checkbox("Upscale (requires >16GB GPU)", key="i2p_upscale")
        generate_button = st.button("Generate Panorama", key="i2p_generate")

    with col2:
        st.subheader("Output")
        output_placeholder = st.empty()
        viewer_placeholder = st.empty()

    if generate_button:
        if i2p_image is None:
            st.error("Please upload an input image.")
        else:
            with st.spinner("Generating your 360° panorama..."):
                image = Image.open(i2p_image)
                mask = Image.open(i2p_mask) if i2p_mask is not None else None
                output = image_to_pano(image, mask, i2p_prompt, i2p_upscale)
                output_placeholder.image(output, caption="Generated 360° Panorama", use_column_width=True)
                with viewer_placeholder.container():
                    display_panorama(output)