My-AI-Projects commited on
Commit
9378dc5
Β·
verified Β·
1 Parent(s): 86eb10c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -28
app.py CHANGED
@@ -12,9 +12,8 @@ import gradio as gr
12
  # Download the model files
13
  ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
14
 
15
- # Function to load models
16
  def load_models():
17
- # Load models on demand to reduce initial memory footprint
18
  text_encoder = ChatGLMModel.from_pretrained(
19
  os.path.join(ckpt_dir, 'text_encoder'),
20
  torch_dtype=torch.float16).half()
@@ -23,17 +22,16 @@ def load_models():
23
  scheduler = EulerDiscreteScheduler.from_pretrained(os.path.join(ckpt_dir, "scheduler"))
24
  unet = UNet2DConditionModel.from_pretrained(os.path.join(ckpt_dir, "unet"), revision=None).half()
25
 
26
- pipe = StableDiffusionXLPipeline(
27
- vae=vae,
28
- text_encoder=text_encoder,
29
- tokenizer=tokenizer,
30
- unet=unet,
31
- scheduler=scheduler,
32
- force_zeros_for_empty_prompt=False)
33
- pipe = pipe.to("cuda")
34
-
35
- return pipe
36
 
 
37
  pipe = load_models()
38
 
39
  @spaces.GPU(duration=200)
@@ -43,9 +41,10 @@ def generate_image(prompt, negative_prompt, height, width, num_inference_steps,
43
  else:
44
  seed = int(seed) # Ensure seed is an integer
45
 
46
- # Move the model to the GPU for inference
47
  with torch.no_grad():
48
- image = pipe(
 
49
  prompt=prompt,
50
  negative_prompt=negative_prompt,
51
  height=height,
@@ -53,20 +52,13 @@ def generate_image(prompt, negative_prompt, height, width, num_inference_steps,
53
  num_inference_steps=num_inference_steps,
54
  guidance_scale=guidance_scale,
55
  num_images_per_prompt=num_images_per_prompt,
56
- generator=torch.Generator(pipe.device).manual_seed(seed)
57
- ).images
58
-
 
59
  return image, seed
60
 
61
- description = """
62
- <p align="center">Effective Training of Diffusion Model for Photorealistic Text-to-Image Synthesis</p>
63
- <p><center>
64
- <a href="https://kolors.kuaishou.com/" target="_blank">[Official Website]</a>
65
- <a href="https://github.com/Kwai-Kolors/Kolors/blob/master/imgs/Kolors_paper.pdf" target="_blank">[Tech Report]</a>
66
- <a href="https://huggingface.co/Kwai-Kolors/Kolors" target="_blank">[Model Page]</a>
67
- <a href="https://github.com/Kwai-Kolors/Kolors" target="_blank">[Github]</a>
68
- </center></p>
69
- """
70
 
71
  # Gradio interface
72
  iface = gr.Interface(
@@ -90,8 +82,7 @@ iface = gr.Interface(
90
  gr.Number(label="Seed Used")
91
  ],
92
  title="Kolors",
93
- description=description,
94
  theme='bethecloud/storj_theme',
95
  )
96
 
97
- iface.launch() # Set debug=False for production
 
12
  # Download the model files
13
  ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
14
 
15
+ # Function to load models on demand
16
  def load_models():
 
17
  text_encoder = ChatGLMModel.from_pretrained(
18
  os.path.join(ckpt_dir, 'text_encoder'),
19
  torch_dtype=torch.float16).half()
 
22
  scheduler = EulerDiscreteScheduler.from_pretrained(os.path.join(ckpt_dir, "scheduler"))
23
  unet = UNet2DConditionModel.from_pretrained(os.path.join(ckpt_dir, "unet"), revision=None).half()
24
 
25
+ return StableDiffusionXLPipeline(
26
+ vae=vae,
27
+ text_encoder=text_encoder,
28
+ tokenizer=tokenizer,
29
+ unet=unet,
30
+ scheduler=scheduler,
31
+ force_zeros_for_empty_prompt=False
32
+ ).to("cuda")
 
 
33
 
34
+ # Create a global variable to hold the pipeline
35
  pipe = load_models()
36
 
37
  @spaces.GPU(duration=200)
 
41
  else:
42
  seed = int(seed) # Ensure seed is an integer
43
 
44
+ # Move the model to the GPU for inference and clear unnecessary variables
45
  with torch.no_grad():
46
+ generator = torch.Generator(pipe.device).manual_seed(seed)
47
+ result = pipe(
48
  prompt=prompt,
49
  negative_prompt=negative_prompt,
50
  height=height,
 
52
  num_inference_steps=num_inference_steps,
53
  guidance_scale=guidance_scale,
54
  num_images_per_prompt=num_images_per_prompt,
55
+ generator=generator
56
+ )
57
+ image = result.images
58
+
59
  return image, seed
60
 
61
+
 
 
 
 
 
 
 
 
62
 
63
  # Gradio interface
64
  iface = gr.Interface(
 
82
  gr.Number(label="Seed Used")
83
  ],
84
  title="Kolors",
 
85
  theme='bethecloud/storj_theme',
86
  )
87
 
88
+ iface.launch()