gokaygokay commited on
Commit
92211ba
·
verified ·
1 Parent(s): 25e3823

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -32,8 +32,11 @@ vlm_processor = PaliGemmaProcessor.from_pretrained("gokaygokay/sd3-long-captione
32
  enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
33
  enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
34
 
35
- # SD3
36
- sd3_pipe = StableDiffusion3Pipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
 
 
 
37
 
38
  MAX_SEED = np.iinfo(np.int32).max
39
  MAX_IMAGE_SIZE = 1344
@@ -90,8 +93,10 @@ def generate_image(prompt, negative_prompt, seed, randomize_seed, width, height,
90
  seed = random.randint(0, MAX_SEED)
91
 
92
  generator = torch.Generator().manual_seed(seed)
 
 
93
 
94
- image = sd3_pipe(
95
  prompt=prompt,
96
  negative_prompt=negative_prompt,
97
  guidance_scale=guidance_scale,
@@ -143,7 +148,7 @@ def img2img_generate(
143
 
144
  generator = torch.Generator().manual_seed(seed)
145
 
146
- img2img_pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
147
 
148
  init_image = init_image.resize((768, 768))
149
 
 
32
  enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
33
  enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
34
 
35
+ def load_pipeline(pipeline_type):
36
+ if pipeline_type == "text2img":
37
+ return StableDiffusion3Pipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
38
+ elif pipeline_type == "img2img":
39
+ return StableDiffusion3Img2ImgPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
40
 
41
  MAX_SEED = np.iinfo(np.int32).max
42
  MAX_IMAGE_SIZE = 1344
 
93
  seed = random.randint(0, MAX_SEED)
94
 
95
  generator = torch.Generator().manual_seed(seed)
96
+
97
+ pipe = load_pipeline("text2img")
98
 
99
+ image = pipe(
100
  prompt=prompt,
101
  negative_prompt=negative_prompt,
102
  guidance_scale=guidance_scale,
 
148
 
149
  generator = torch.Generator().manual_seed(seed)
150
 
151
+ img2img_pipe = load_pipeline("img2img")
152
 
153
  init_image = init_image.resize((768, 768))
154