Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,591 Bytes
ece05f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import spaces
import streamlit as st
import torch
from huggingface_hub import snapshot_download
from txt2panoimg import Text2360PanoramaImagePipeline
from img2panoimg import Image2360PanoramaImagePipeline
from PIL import Image
from streamlit_pannellum import streamlit_pannellum
# Page-wide styling for the app. The stylesheet is kept in a named constant
# and handed to st.markdown with unsafe_allow_html=True so the <style> tag is
# emitted as markup instead of being shown as text.
_CUSTOM_CSS = """
<style>
.stApp {
max-width: 1200px;
margin: 0 auto;
}
.main {
background-color: #f0f2f6;
}
h1 {
color: #1E3A8A;
text-align: center;
padding: 20px 0;
font-size: 2.5rem;
}
.stTabs {
background-color: white;
padding: 20px;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.stButton>button {
background-color: #1E3A8A;
color: white;
font-weight: bold;
}
.viewer-column {
background-color: white;
padding: 20px;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
@st.cache_resource
def _load_pipelines():
    """Download the model snapshot and build both panorama pipelines once.

    Streamlit re-executes this script from the top on every widget
    interaction. Without caching, each rerun would call snapshot_download
    and construct both pipelines again; st.cache_resource keeps a single
    shared result for the lifetime of the server process.

    Returns:
        A ``(model_path, text_pipeline, image_pipeline)`` tuple.
    """
    path = snapshot_download("archerfmy0831/sd-t2i-360panoimage")
    text_pipeline = Text2360PanoramaImagePipeline(path, torch_dtype=torch.float16)
    image_pipeline = Image2360PanoramaImagePipeline(path, torch_dtype=torch.float16)
    return path, text_pipeline, image_pipeline


# Module-level names are unchanged so the rest of the script keeps working.
model_path, txt2panoimg, img2panoimg = _load_pipelines()

# Load the default mask image. convert() produces an in-memory image, so the
# underlying file handle can be closed as soon as the conversion is done.
with Image.open("i2p-mask.jpg") as _mask_file:
    default_mask = _mask_file.convert("RGB")
@spaces.GPU(duration=200)
def text_to_pano(prompt, upscale):
    """Generate a panorama from a text prompt.

    Bundles ``prompt`` and ``upscale`` into the dict passed to the
    module-level ``txt2panoimg`` pipeline and returns its result unchanged.
    NOTE(review): the ``spaces.GPU(duration=200)`` decorator presumably
    requests a GPU slot for up to 200 seconds; confirm against the
    ``spaces`` package documentation.
    """
    return txt2panoimg(dict(prompt=prompt, upscale=upscale))
@spaces.GPU(duration=200)
def image_to_pano(image, mask, prompt, upscale):
    """Generate a panorama from an input image, a mask and a prompt.

    Args:
        image: PIL image supplied by the caller.
        mask: PIL mask image, or ``None`` to fall back to ``default_mask``.
        prompt: Text prompt forwarded to the pipeline.
        upscale: Upscale flag forwarded to the pipeline.

    Returns:
        Whatever the module-level ``img2panoimg`` pipeline returns.
    """
    if mask is None:
        mask = default_mask

    # Uploaded PNGs can be RGBA, palette or grayscale. The bundled default
    # mask is loaded as RGB, so caller-supplied images are normalized to the
    # same mode before resizing instead of being forwarded in whatever mode
    # the upload happened to use.
    # NOTE(review): assumes the pipeline expects 3-channel RGB inputs, as the
    # RGB default mask suggests; confirm against the pipeline implementation.
    input_data = {
        'prompt': prompt,
        'image': image.convert("RGB").resize((512, 512)),
        'mask': mask.convert("RGB").resize((512, 512)),
        'upscale': upscale,
    }
    return img2panoimg(input_data)
# Page heading, followed by one tab per generation mode.
st.title("360° Panorama Image Generation")
_TAB_LABELS = ["Text to 360° Panorama", "Image to 360° Panorama"]
tab1, tab2 = st.tabs(_TAB_LABELS)
def _panorama_source(image):
    """Return a string that can be used as the scene's ``panorama`` value.

    Strings (URLs or data URIs) are passed through untouched. Any other
    object is treated as a PIL-style image and encoded as a base64 JPEG
    data URI, because the component config must be JSON-serializable and an
    image object is not.
    """
    if isinstance(image, str):
        return image

    import base64
    import io

    buffer = io.BytesIO()
    # JPEG has no alpha channel, so normalize the mode before saving.
    image.convert("RGB").save(buffer, format="JPEG", quality=95)
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return f"data:image/jpeg;base64,{encoded}"


def display_panorama(image):
    """Render ``image`` in the pannellum 360° viewer.

    Args:
        image: Either a URL / data-URI string, or an image object exposing
            ``convert`` and ``save`` (such as a PIL image).
    """
    streamlit_pannellum(
        config={
            "default": {
                "firstScene": "generated",
            },
            "scenes": {
                "generated": {
                    "title": "Generated Panorama",
                    "type": "equirectangular",
                    "panorama": _panorama_source(image),
                    "autoLoad": True,
                }
            },
        }
    )
# Text-to-panorama tab: inputs on the left, generated result on the right.
with tab1:
    col1, col2 = st.columns([1, 1])
    with col1:
        st.subheader("Input")
        t2p_input = st.text_area("Enter your prompt", height=100)
        t2p_upscale = st.checkbox("Upscale (requires >16GB GPU)")
        generate_button = st.button("Generate Panorama")
    with col2:
        st.subheader("Output")
        # Placeholders are created up front so the result lands in the
        # output column even though generation is triggered further down.
        output_placeholder = st.empty()
        viewer_placeholder = st.empty()
    if generate_button:
        if not t2p_input.strip():
            # Reject blank prompts before starting a generation job, in the
            # same way the image tab rejects a missing input image.
            st.error("Please enter a prompt.")
        else:
            with st.spinner("Generating your 360° panorama..."):
                output = text_to_pano(t2p_input, t2p_upscale)
                output_placeholder.image(output, caption="Generated 360° Panorama", use_column_width=True)
                with viewer_placeholder.container():
                    display_panorama(output)
# Image-to-panorama tab: inputs on the left, generated result on the right.
with tab2:
    col1, col2 = st.columns([1, 1])
    with col1:
        st.subheader("Input")
        i2p_image = st.file_uploader("Upload Input Image", type=["png", "jpg", "jpeg"])
        i2p_mask = st.file_uploader("Upload Mask Image (Optional)", type=["png", "jpg", "jpeg"])
        # The text tab declares a text_area with the same label and height.
        # Both tabs are built on every script run, so this widget needs its
        # own key to avoid a duplicate widget ID, like the checkbox and
        # button below.
        i2p_prompt = st.text_area("Enter your prompt", height=100, key="i2p_prompt")
        i2p_upscale = st.checkbox("Upscale (requires >16GB GPU)", key="i2p_upscale")
        generate_button = st.button("Generate Panorama", key="i2p_generate")
    with col2:
        st.subheader("Output")
        # Placeholders are created up front so the result lands in the
        # output column even though generation is triggered further down.
        output_placeholder = st.empty()
        viewer_placeholder = st.empty()
    if generate_button:
        if i2p_image is None:
            st.error("Please upload an input image.")
        else:
            with st.spinner("Generating your 360° panorama..."):
                image = Image.open(i2p_image)
                # The mask is optional; image_to_pano substitutes its
                # default mask when given None.
                mask = Image.open(i2p_mask) if i2p_mask is not None else None
                output = image_to_pano(image, mask, i2p_prompt, i2p_upscale)
                output_placeholder.image(output, caption="Generated 360° Panorama", use_column_width=True)
                with viewer_placeholder.container():
                    display_panorama(output)