kayfahaarukku committed on
Commit
5272b56
·
verified ·
1 Parent(s): 48a4b4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -5
app.py CHANGED
@@ -6,6 +6,7 @@ import gradio as gr
6
  import random
7
  import tqdm
8
  from huggingface_hub import hf_hub_download
 
9
 
10
  # Enable TQDM progress tracking
11
  tqdm.monitor_interval = 0
@@ -17,11 +18,19 @@ def load_model():
17
  filename="AkashicPulse-v1.0-ft-ft.safetensors"
18
  )
19
 
20
- # Initialize standard SD pipeline
 
 
 
 
21
  pipe = StableDiffusionPipeline.from_single_file(
22
  model_path,
23
  torch_dtype=torch.float16,
24
- use_safetensors=True
 
 
 
 
25
  )
26
 
27
  pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
@@ -47,6 +56,10 @@ def generate_image(prompt, negative_prompt, use_defaults, resolution, guidance_s
47
  return
48
 
49
  width, height = map(int, resolution.split('x'))
 
 
 
 
50
  image = pipe(
51
  prompt,
52
  negative_prompt=negative_prompt,
@@ -56,7 +69,8 @@ def generate_image(prompt, negative_prompt, use_defaults, resolution, guidance_s
56
  num_inference_steps=num_inference_steps,
57
  generator=generator,
58
  callback=callback,
59
- callback_steps=1
 
60
  ).images[0]
61
 
62
  torch.cuda.empty_cache()
@@ -67,8 +81,12 @@ def generate_image(prompt, negative_prompt, use_defaults, resolution, guidance_s
67
 
68
  # Define Gradio interface
69
def interface_fn(prompt, negative_prompt, use_defaults, resolution, guidance_scale, num_inference_steps, seed, randomize_seed, progress=gr.Progress()):
    """Gradio entry point: run the generator and surface the result plus metadata."""
    result_image, result_seed, metadata_text = generate_image(
        prompt, negative_prompt, use_defaults, resolution,
        guidance_scale, num_inference_steps, seed, randomize_seed, progress,
    )
    return result_image, result_seed, gr.update(value=metadata_text)
 
 
 
 
72
 
73
def reset_inputs():
    """Reset every UI control to its default value (one gr.update per component)."""
    # Defaults, in component order: prompt, negative prompt, use-defaults flag,
    # resolution, guidance scale, steps, seed, randomize-seed flag, metadata text.
    defaults = ('', '', True, '832x1216', 7, 28, 0, True, '')
    return tuple(gr.update(value=v) for v in defaults)
 
6
  import random
7
  import tqdm
8
  from huggingface_hub import hf_hub_download
9
+ from transformers import CLIPTextModel, CLIPTokenizer
10
 
11
  # Enable TQDM progress tracking
12
  tqdm.monitor_interval = 0
 
18
  filename="AkashicPulse-v1.0-ft-ft.safetensors"
19
  )
20
 
21
+ # Initialize tokenizer and text encoder from standard SD 1.5
22
+ tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
23
+ text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
24
+
25
+ # Initialize pipeline with text encoder and tokenizer
26
  pipe = StableDiffusionPipeline.from_single_file(
27
  model_path,
28
  torch_dtype=torch.float16,
29
+ use_safetensors=True,
30
+ tokenizer=tokenizer,
31
+ text_encoder=text_encoder,
32
+ requires_safety_checker=False,
33
+ safety_checker=None
34
  )
35
 
36
  pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
 
56
  return
57
 
58
  width, height = map(int, resolution.split('x'))
59
+
60
+ # Add empty dict for additional kwargs
61
+ added_cond_kwargs = {"text_embeds": None, "time_ids": None}
62
+
63
  image = pipe(
64
  prompt,
65
  negative_prompt=negative_prompt,
 
69
  num_inference_steps=num_inference_steps,
70
  generator=generator,
71
  callback=callback,
72
+ callback_steps=1,
73
+ added_cond_kwargs=added_cond_kwargs
74
  ).images[0]
75
 
76
  torch.cuda.empty_cache()
 
81
 
82
  # Define Gradio interface
83
def interface_fn(prompt, negative_prompt, use_defaults, resolution, guidance_scale, num_inference_steps, seed, randomize_seed, progress=gr.Progress()):
    """Gradio callback wrapping generate_image.

    Returns (image, seed, gr.update with the generation metadata text).
    Any failure is logged to stdout and re-raised so Gradio can display it;
    catching broad Exception is acceptable only because this is the top-level
    UI boundary.
    """
    try:
        image, seed, metadata_text = generate_image(
            prompt, negative_prompt, use_defaults, resolution,
            guidance_scale, num_inference_steps, seed, randomize_seed, progress,
        )
        return image, seed, gr.update(value=metadata_text)
    except Exception as e:
        print(f"Error generating image: {str(e)}")
        # Bare `raise` re-raises the active exception with its original
        # traceback intact; `raise e` would add a redundant re-raise frame.
        raise
90
 
91
def reset_inputs():
    """Return gr.update values restoring each input widget to its default."""
    return (
        gr.update(value=''),          # prompt
        gr.update(value=''),          # negative prompt
        gr.update(value=True),        # use defaults
        gr.update(value='832x1216'),  # resolution
        gr.update(value=7),           # guidance scale
        gr.update(value=28),          # inference steps
        gr.update(value=0),           # seed
        gr.update(value=True),        # randomize seed
        gr.update(value=''),          # metadata text
    )