frogleo commited on
Commit
4eacf35
·
1 Parent(s): a278c66

着手增加逻辑

Browse files
Files changed (1) hide show
  1. app.py +72 -16
app.py CHANGED
@@ -1,11 +1,20 @@
1
- import spaces
 
2
  import gradio as gr
3
  import numpy as np
4
  import torch
 
 
5
  import random
6
- import logging
7
  import utils
 
 
 
8
  from diffusers.models import AutoencoderKL
 
 
 
9
  from config import (
10
  MODEL,
11
  MIN_IMAGE_SIZE,
@@ -52,6 +61,37 @@ else:
52
  pipe = None
53
 
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  @spaces.GPU
57
  def generate(
@@ -70,23 +110,39 @@ def generate(
70
  ):
71
  if randomize_seed:
72
  seed = random.randint(0, MAX_SEED)
 
 
 
 
 
73
 
74
- # generator = torch.Generator().manual_seed(seed)
75
-
76
- # image = pipe(
77
- # prompt=prompt,
78
- # negative_prompt=negative_prompt,
79
- # guidance_scale=guidance_scale,
80
- # num_inference_steps=num_inference_steps,
81
- # width=width,
82
- # height=height,
83
- # generator=generator,
84
- # ).images[0]
85
 
86
- # return image, seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
- return None, seed
89
 
 
90
 
91
 
92
 
@@ -207,7 +263,7 @@ with gr.Blocks(css=custom_css).queue() as demo:
207
  seed,randomize_seed,
208
  guidance_scale,num_inference_steps
209
  ],
210
- outputs=[result, seed],
211
  )
212
 
213
  if __name__ == "__main__":
 
1
+ import os
2
+ import gc
3
  import gradio as gr
4
  import numpy as np
5
  import torch
6
+ import json
7
+ import spaces
8
  import random
9
+ import config
10
  import utils
11
+ import logging
12
+ from PIL import Image, PngImagePlugin
13
+ from datetime import datetime
14
  from diffusers.models import AutoencoderKL
15
+ from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
16
+ import time
17
+ from typing import List, Dict, Tuple, Optional
18
  from config import (
19
  MODEL,
20
  MIN_IMAGE_SIZE,
 
61
  pipe = None
62
 
63
 
64
+ class GenerationError(Exception):
65
+ """Custom exception for generation errors"""
66
+ pass
67
+
68
+ def validate_prompt(prompt: str) -> str:
69
+ """Validate and clean up the input prompt."""
70
+ if not isinstance(prompt, str):
71
+ raise GenerationError("Prompt must be a string")
72
+ try:
73
+ # Ensure proper UTF-8 encoding/decoding
74
+ prompt = prompt.encode('utf-8').decode('utf-8')
75
+ # Add space between ! and ,
76
+ prompt = prompt.replace("!,", "! ,")
77
+ except UnicodeError:
78
+ raise GenerationError("Invalid characters in prompt")
79
+
80
+ # Only check if the prompt is completely empty or only whitespace
81
+ if not prompt or prompt.isspace():
82
+ raise GenerationError("Prompt cannot be empty")
83
+ return prompt.strip()
84
+
85
+ def validate_dimensions(width: int, height: int) -> None:
86
+ """Validate image dimensions."""
87
+ if not MIN_IMAGE_SIZE <= width <= MAX_IMAGE_SIZE:
88
+ raise GenerationError(f"Width must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}")
89
+
90
+ if not MIN_IMAGE_SIZE <= height <= MAX_IMAGE_SIZE:
91
+ raise GenerationError(f"Height must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}")
92
+
93
+
94
+
95
 
96
  @spaces.GPU
97
  def generate(
 
110
  ):
111
  if randomize_seed:
112
  seed = random.randint(0, MAX_SEED)
113
+
114
+ """Generate images based on the given parameters."""
115
+ start_time = time.time()
116
+ upscaler_pipe = None
117
+ backup_scheduler = None
118
 
119
+ try:
120
+ # Memory management
121
+ torch.cuda.empty_cache()
122
+ gc.collect()
 
 
 
 
 
 
 
123
 
124
+ return None
125
+ except GenerationError as e:
126
+ logger.warning(f"Generation validation error: {str(e)}")
127
+ raise gr.Error(str(e))
128
+ except Exception as e:
129
+ logger.exception("Unexpected error during generation")
130
+ raise gr.Error(f"Generation failed: {str(e)}")
131
+ finally:
132
+ # Cleanup
133
+ torch.cuda.empty_cache()
134
+ gc.collect()
135
+
136
+ if upscaler_pipe is not None:
137
+ del upscaler_pipe
138
+
139
+ if backup_scheduler is not None and pipe is not None:
140
+ pipe.scheduler = backup_scheduler
141
+
142
+ utils.free_memory()
143
 
 
144
 
145
+
146
 
147
 
148
 
 
263
  seed,randomize_seed,
264
  guidance_scale,num_inference_steps
265
  ],
266
+ outputs=[result],
267
  )
268
 
269
  if __name__ == "__main__":