frogleo commited on
Commit
b82c0ac
·
1 Parent(s): 4eacf35

继续完善逻辑

Browse files
__pycache__/config.cpython-310.pyc CHANGED
Binary files a/__pycache__/config.cpython-310.pyc and b/__pycache__/config.cpython-310.pyc differ
 
__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/__pycache__/utils.cpython-310.pyc and b/__pycache__/utils.cpython-310.pyc differ
 
app.py CHANGED
@@ -121,6 +121,25 @@ def generate(
121
  torch.cuda.empty_cache()
122
  gc.collect()
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  return None
125
  except GenerationError as e:
126
  logger.warning(f"Generation validation error: {str(e)}")
 
121
  torch.cuda.empty_cache()
122
  gc.collect()
123
 
124
+ # Input validation
125
+ prompt = validate_prompt(prompt)
126
+ if negative_prompt:
127
+ negative_prompt = negative_prompt.encode('utf-8').decode('utf-8')
128
+
129
+ validate_dimensions(width, height)
130
+
131
+ # Set up generation
132
+ generator = utils.seed_everything(seed)
133
+
134
+ width, height = utils.preprocess_image_dimensions(width, height)
135
+
136
+ # Set up pipeline
137
+ backup_scheduler = pipe.scheduler
138
+ pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, scheduler)
139
+
140
+ upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
141
+
142
+
143
  return None
144
  except GenerationError as e:
145
  logger.warning(f"Generation validation error: {str(e)}")
utils.py CHANGED
@@ -42,3 +42,37 @@ def load_pipeline(model_name: str, device: torch.device, hf_token: Optional[str]
42
  except Exception as e:
43
  logging.error(f"Failed to load pipeline: {str(e)}", exc_info=True)
44
  raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  except Exception as e:
43
  logging.error(f"Failed to load pipeline: {str(e)}", exc_info=True)
44
  raise
45
+
46
+ def seed_everything(seed: int) -> torch.Generator:
47
+ torch.manual_seed(seed)
48
+ torch.cuda.manual_seed_all(seed)
49
+ np.random.seed(seed)
50
+ generator = torch.Generator()
51
+ generator.manual_seed(seed)
52
+ return generator
53
+
54
+ def preprocess_image_dimensions(width, height):
55
+ if width % 8 != 0:
56
+ width = width - (width % 8)
57
+ if height % 8 != 0:
58
+ height = height - (height % 8)
59
+ return width, height
60
+
61
+ def get_scheduler(scheduler_config: Dict, name: str) -> Optional[Callable]:
62
+ scheduler_factory_map = {
63
+ "DPM++ 2M Karras": lambda: DPMSolverMultistepScheduler.from_config(
64
+ scheduler_config, use_karras_sigmas=True
65
+ ),
66
+ "DPM++ SDE Karras": lambda: DPMSolverSinglestepScheduler.from_config(
67
+ scheduler_config, use_karras_sigmas=True
68
+ ),
69
+ "DPM++ 2M SDE Karras": lambda: DPMSolverMultistepScheduler.from_config(
70
+ scheduler_config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
71
+ ),
72
+ "Euler": lambda: EulerDiscreteScheduler.from_config(scheduler_config),
73
+ "Euler a": lambda: EulerAncestralDiscreteScheduler.from_config(
74
+ scheduler_config
75
+ ),
76
+ "DDIM": lambda: DDIMScheduler.from_config(scheduler_config),
77
+ }
78
+ return scheduler_factory_map.get(name, lambda: None)()