cai-qi commited on
Commit
8168e43
·
verified ·
1 Parent(s): 4bd52e1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +181 -0
app.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from hi_diffusers import HiDreamImagePipeline
4
+ from hi_diffusers import HiDreamImageTransformer2DModel
5
+ from hi_diffusers.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
6
+ from hi_diffusers.schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler
7
+ from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
8
+
9
+ MODEL_PREFIX = "HiDream-ai"
10
+ LLAMA_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
11
+
12
+ # Model configurations
13
+ MODEL_CONFIGS = {
14
+ "dev": {
15
+ "path": f"{MODEL_PREFIX}/HiDream-I1-Dev",
16
+ "guidance_scale": 0.0,
17
+ "num_inference_steps": 28,
18
+ "shift": 6.0,
19
+ "scheduler": FlashFlowMatchEulerDiscreteScheduler
20
+ },
21
+ "full": {
22
+ "path": f"{MODEL_PREFIX}/HiDream-I1-Full",
23
+ "guidance_scale": 5.0,
24
+ "num_inference_steps": 50,
25
+ "shift": 3.0,
26
+ "scheduler": FlowUniPCMultistepScheduler
27
+ },
28
+ "fast": {
29
+ "path": f"{MODEL_PREFIX}/HiDream-I1-Fast",
30
+ "guidance_scale": 0.0,
31
+ "num_inference_steps": 16,
32
+ "shift": 3.0,
33
+ "scheduler": FlashFlowMatchEulerDiscreteScheduler
34
+ }
35
+ }
36
+
37
+ # Resolution options
38
+ RESOLUTION_OPTIONS = [
39
+ "1024 × 1024 (Square)",
40
+ "768 × 1360 (Portrait)",
41
+ "1360 × 768 (Landscape)",
42
+ "880 × 1168 (Portrait)",
43
+ "1168 × 880 (Landscape)",
44
+ "1248 × 832 (Landscape)",
45
+ "832 × 1248 (Portrait)"
46
+ ]
47
+
48
+ # Load models
49
+ def load_models(model_type):
50
+ config = MODEL_CONFIGS[model_type]
51
+ pretrained_model_name_or_path = config["path"]
52
+ scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=1000, shift=config["shift"], use_dynamic_shifting=False)
53
+
54
+ tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(
55
+ LLAMA_MODEL_NAME,
56
+ use_fast=False)
57
+
58
+ text_encoder_4 = LlamaForCausalLM.from_pretrained(
59
+ LLAMA_MODEL_NAME,
60
+ output_hidden_states=True,
61
+ output_attentions=True,
62
+ torch_dtype=torch.bfloat16).to("cuda")
63
+
64
+ transformer = HiDreamImageTransformer2DModel.from_pretrained(
65
+ pretrained_model_name_or_path,
66
+ subfolder="transformer",
67
+ torch_dtype=torch.bfloat16).to("cuda")
68
+
69
+ pipe = HiDreamImagePipeline.from_pretrained(
70
+ pretrained_model_name_or_path,
71
+ scheduler=scheduler,
72
+ tokenizer_4=tokenizer_4,
73
+ text_encoder_4=text_encoder_4,
74
+ torch_dtype=torch.bfloat16
75
+ ).to("cuda", torch.bfloat16)
76
+ pipe.transformer = transformer
77
+
78
+ return pipe, config
79
+
80
+ # Parse resolution string to get height and width
81
+ def parse_resolution(resolution_str):
82
+ if "1024 × 1024" in resolution_str:
83
+ return 1024, 1024
84
+ elif "768 × 1360" in resolution_str:
85
+ return 768, 1360
86
+ elif "1360 × 768" in resolution_str:
87
+ return 1360, 768
88
+ elif "880 × 1168" in resolution_str:
89
+ return 880, 1168
90
+ elif "1168 × 880" in resolution_str:
91
+ return 1168, 880
92
+ elif "1248 × 832" in resolution_str:
93
+ return 1248, 832
94
+ elif "832 × 1248" in resolution_str:
95
+ return 832, 1248
96
+ else:
97
+ return 1024, 1024 # Default fallback
98
+
99
+ # Generate image function
100
+ def generate_image(model_type, prompt, resolution, seed):
101
+ global pipe, current_model
102
+
103
+ # Get configuration for current model
104
+ config = MODEL_CONFIGS[model_type]
105
+ guidance_scale = config["guidance_scale"]
106
+ num_inference_steps = config["num_inference_steps"]
107
+
108
+ # Parse resolution
109
+ height, width = parse_resolution(resolution)
110
+
111
+ # Handle seed
112
+ if seed == -1:
113
+ seed = torch.randint(0, 1000000, (1,)).item()
114
+
115
+ generator = torch.Generator("cuda").manual_seed(seed)
116
+
117
+ images = pipe(
118
+ prompt,
119
+ height=height,
120
+ width=width,
121
+ guidance_scale=guidance_scale,
122
+ num_inference_steps=num_inference_steps,
123
+ num_images_per_prompt=1,
124
+ generator=generator
125
+ ).images
126
+
127
+ return images[0], seed
128
+
129
+ # Initialize with default model
130
+ print("Loading default model (full)...")
131
+ current_model = "fast"
132
+ pipe, _ = load_models(current_model)
133
+ print("Model loaded successfully!")
134
+
135
+ # Create Gradio interface
136
+ with gr.Blocks(title="HiDream Image Generator") as demo:
137
+ gr.Markdown("# HiDream Image Generator")
138
+
139
+ with gr.Row():
140
+ with gr.Column():
141
+ model_type = gr.Radio(
142
+ choices=list(MODEL_CONFIGS.keys()),
143
+ value="full",
144
+ label="Model Type",
145
+ info="Select model variant"
146
+ )
147
+
148
+ prompt = gr.Textbox(
149
+ label="Prompt",
150
+ placeholder="A cat holding a sign that says \"Hi-Dreams.ai\".",
151
+ lines=3
152
+ )
153
+
154
+ resolution = gr.Radio(
155
+ choices=RESOLUTION_OPTIONS,
156
+ value=RESOLUTION_OPTIONS[0],
157
+ label="Resolution",
158
+ info="Select image resolution"
159
+ )
160
+
161
+ seed = gr.Number(
162
+ label="Seed (use -1 for random)",
163
+ value=-1,
164
+ precision=0
165
+ )
166
+
167
+ generate_btn = gr.Button("Generate Image")
168
+ seed_used = gr.Number(label="Seed Used", interactive=False)
169
+
170
+ with gr.Column():
171
+ output_image = gr.Image(label="Generated Image", type="pil")
172
+
173
+ generate_btn.click(
174
+ fn=generate_image,
175
+ inputs=[model_type, prompt, resolution, seed],
176
+ outputs=[output_image, seed_used]
177
+ )
178
+
179
+ # Launch app
180
+ if __name__ == "__main__":
181
+ demo.launch()