Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -98,11 +98,11 @@ def clean_chat_history(chat_history):
|
|
| 98 |
# ============================================
|
| 99 |
|
| 100 |
# Environment variables and parameters for Stable Diffusion XL
|
| 101 |
-
MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") #
|
| 102 |
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
|
| 103 |
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
| 104 |
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
| 105 |
-
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1")) # For
|
| 106 |
|
| 107 |
# Load the SDXL pipeline
|
| 108 |
sd_pipe = StableDiffusionXLPipeline.from_pretrained(
|
|
@@ -113,7 +113,11 @@ sd_pipe = StableDiffusionXLPipeline.from_pretrained(
|
|
| 113 |
).to(device)
|
| 114 |
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
|
| 115 |
|
| 116 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
if USE_TORCH_COMPILE:
|
| 118 |
sd_pipe.compile()
|
| 119 |
|
|
@@ -191,16 +195,16 @@ def generate(
|
|
| 191 |
repetition_penalty: float = 1.2,
|
| 192 |
):
|
| 193 |
"""
|
| 194 |
-
Generates chatbot responses with support for multimodal input, TTS, and
|
| 195 |
-
|
| 196 |
-
- "@tts1" or "@tts2"
|
| 197 |
-
- "@image"
|
| 198 |
"""
|
| 199 |
text = input_dict["text"]
|
| 200 |
files = input_dict.get("files", [])
|
| 201 |
|
| 202 |
# ----------------------------
|
| 203 |
-
#
|
| 204 |
# ----------------------------
|
| 205 |
if text.strip().lower().startswith("@image"):
|
| 206 |
# Remove the "@image" tag and use the rest as prompt
|
|
@@ -343,4 +347,5 @@ demo = gr.ChatInterface(
|
|
| 343 |
)
|
| 344 |
|
| 345 |
if __name__ == "__main__":
|
|
|
|
| 346 |
demo.queue(max_size=20).launch(share=True)
|
|
|
|
| 98 |
# ============================================
|
| 99 |
|
| 100 |
# Environment variables and parameters for Stable Diffusion XL
|
| 101 |
+
MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # SDXL Model repository path via env variable
|
| 102 |
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
|
| 103 |
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
| 104 |
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
| 105 |
+
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1")) # For batched image generation
|
| 106 |
|
| 107 |
# Load the SDXL pipeline
|
| 108 |
sd_pipe = StableDiffusionXLPipeline.from_pretrained(
|
|
|
|
| 113 |
).to(device)
|
| 114 |
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
|
| 115 |
|
| 116 |
+
# Fix for dtype mismatch in the text encoder: cast it to half precision when CUDA is available.
|
| 117 |
+
if torch.cuda.is_available():
|
| 118 |
+
sd_pipe.text_encoder = sd_pipe.text_encoder.half()
|
| 119 |
+
|
| 120 |
+
# Optional: compile the model for speedup if enabled
|
| 121 |
if USE_TORCH_COMPILE:
|
| 122 |
sd_pipe.compile()
|
| 123 |
|
|
|
|
| 195 |
repetition_penalty: float = 1.2,
|
| 196 |
):
|
| 197 |
"""
|
| 198 |
+
Generates chatbot responses with support for multimodal input, TTS, and image generation.
|
| 199 |
+
Special commands:
|
| 200 |
+
- "@tts1" or "@tts2": triggers text-to-speech.
|
| 201 |
+
- "@image": triggers image generation using the SDXL pipeline.
|
| 202 |
"""
|
| 203 |
text = input_dict["text"]
|
| 204 |
files = input_dict.get("files", [])
|
| 205 |
|
| 206 |
# ----------------------------
|
| 207 |
+
# IMAGE GENERATION BRANCH
|
| 208 |
# ----------------------------
|
| 209 |
if text.strip().lower().startswith("@image"):
|
| 210 |
# Remove the "@image" tag and use the rest as prompt
|
|
|
|
| 347 |
)
|
| 348 |
|
| 349 |
if __name__ == "__main__":
|
| 350 |
+
# share=True makes launch() create a public link.
|
| 351 |
demo.queue(max_size=20).launch(share=True)
|