DonImages committed
Commit 4a6aac0 · verified · Parent: 509d782

Update app.py

Files changed (1)
app.py +14 -15
app.py CHANGED
@@ -17,7 +17,7 @@ pipeline = None
 
 # Function to load the Stable Diffusion pipeline (called only ONCE)
 def load_pipeline():
-    global pipeline  # Use the global keyword to modify the global variable
+    global pipeline
     try:
         pipeline = StableDiffusion3Pipeline.from_pretrained(
             model_id,
@@ -27,31 +27,31 @@ def load_pipeline():
         )
     except Exception as e:
         print(f"Error loading model: {e}")
-        return f"Error loading model: {e}"  # Return error message
+        return f"Error loading model: {e}"
 
     pipeline.enable_model_cpu_offload()
     pipeline.enable_attention_slicing()
-    return "Model loaded successfully"  # Return success message
+    return "Model loaded successfully"
 
-# Function for image generation (now decorated)
 @GPU(duration=65)  # Use GPU decorator (ONLY if in HF Space)
 def generate_image(prompt):
     global pipeline
-    if pipeline is None:  # Check if pipeline is loaded
-        return "Model not loaded. Please wait."  # Return message if not loaded
+    if pipeline is None:
+        return "Model not loaded. Please wait."
 
-    # Load and apply LoRA (file is already in the Space)
-    lora_filename = "lora_trained_model.safetensors"  # Name of your LoRA file
-    lora_path = os.path.join("./", lora_filename)  # Construct the path
+    lora_filename = "lora_trained_model.safetensors"
+    lora_path = os.path.join("./", lora_filename)
     print(f"Loading LoRA from: {lora_path}")
 
     try:
-        if os.path.exists(lora_path):  # check if the file exists
+        if os.path.exists(lora_path):
             lora_weights = load_file(lora_path)
             text_encoder = pipeline.text_encoder
             text_encoder.load_state_dict(lora_weights, strict=False)
-    except Exception as e:
-        return f"Error loading LoRA: {e}"
+        else:
+            return f"Error: LoRA file not found at {lora_path}"
+    except Exception as e:
+        return f"Error loading LoRA: {e}"
 
     try:
         image = pipeline(prompt).images[0]
@@ -60,14 +60,13 @@ def generate_image(prompt):
         return f"Error generating image: {e}"
 
 
-# Create the Gradio interface
 with gr.Blocks() as demo:
     prompt_input = gr.Textbox(label="Prompt")
     image_output = gr.Image(label="Generated Image")
    generate_button = gr.Button("Generate")
-    load_model_button = gr.Button("Load Model")  # Button to load model
+    load_model_button = gr.Button("Load Model")
 
-    load_model_button.click(fn=load_pipeline, outputs=load_model_button)  # Call load_pipeline
+    load_model_button.click(fn=load_pipeline, outputs=load_model_button)
 
     generate_button.click(
         fn=generate_image,
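
For context, the hunks above depend on module-level setup near the top of app.py that this commit does not touch. The following is a minimal sketch of that assumed preamble: the imports are inferred from the names used in the diff (os, gr, StableDiffusion3Pipeline, load_file, GPU), the GPU decorator is assumed to come from the Hugging Face spaces package, and the model_id value is a placeholder rather than the repository's actual setting.

# Assumed preamble (not part of this commit); names inferred from the hunks above.
import os

import gradio as gr
from diffusers import StableDiffusion3Pipeline
from safetensors.torch import load_file
from spaces import GPU  # assumed source of the @GPU(duration=65) decorator

model_id = "stabilityai/stable-diffusion-3-medium-diffusers"  # placeholder id, not taken from the repo
pipeline = None  # created once by load_pipeline(); see the @@ -17,7 +17,7 @@ hunk header

Design note on the new code path: generate_image applies the LoRA by loading the safetensors state dict directly into pipeline.text_encoder with strict=False, and the added else branch now surfaces a missing lora_trained_model.safetensors as an error message instead of silently generating without the LoRA weights.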