DonImages commited on
Commit
7b1a432
·
verified ·
1 Parent(s): 5bcd142

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ from diffusers import StableDiffusion3Pipeline
5
+ from safetensors.torch import load_file
6
+
7
+ # Access HF_TOKEN from environment variables
8
+ hf_token = os.getenv("HF_TOKEN")
9
+
10
+ # Specify the pre-trained model ID
11
+ model_id = "stabilityai/stable-diffusion-3.5-large"
12
+
13
+ # Lazy pipeline initialization
14
+ pipeline = None
15
+
16
+ # Function for image generation
17
+ @gr.GPU(duration=65)
18
+ def generate_image(prompt): # Remove lora_file input
19
+ global pipeline
20
+ if pipeline is None:
21
+ try:
22
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
23
+ model_id,
24
+ use_auth_token=hf_token,
25
+ torch_dtype=torch.float16,
26
+ cache_dir="./model_cache"
27
+ )
28
+ except Exception as e:
29
+ print(f"Error loading from cache: {e}")
30
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
31
+ model_id,
32
+ use_auth_token=hf_token,
33
+ torch_dtype=torch.float16,
34
+ local_files_only=False
35
+ )
36
+ pipeline.enable_model_cpu_offload()
37
+ pipeline.enable_attention_slicing()
38
+
39
+ # Load and apply LoRA (file is already in the Space)
40
+ lora_filename = "lora_trained_model.safetensors" # Name of your LoRA file
41
+ lora_path = os.path.join("./", lora_filename) # Construct the path
42
+ print(f"Loading LoRA from: {lora_path}")
43
+
44
+ try:
45
+ if os.path.exists(lora_path): # check if the file exists
46
+ lora_weights = load_file(lora_path)
47
+ text_encoder = pipeline.text_encoder
48
+ text_encoder.load_state_dict(lora_weights, strict=False)
49
+ else:
50
+ return f"Error: LoRA file not found at {lora_path}"
51
+ except Exception as e:
52
+ return f"Error loading LoRA: {e}"
53
+
54
+ try:
55
+ image = pipeline(prompt).images[0]
56
+ return image
57
+ except Exception as e:
58
+ return f"Error generating image: {e}"
59
+
60
+
61
+ # Create the Gradio interface (remove lora_upload)
62
+ with gr.Blocks() as demo:
63
+ prompt_input = gr.Textbox(label="Prompt")
64
+ image_output = gr.Image(label="Generated Image")
65
+ generate_button = gr.Button("Generate")
66
+
67
+ generate_button.click(
68
+ fn=generate_image,
69
+ inputs=prompt_input, # Only prompt input now
70
+ outputs=image_output,
71
+ )
72
+
73
+ demo.launch()