prithivMLmods commited on
Commit
7d0e511
·
verified ·
1 Parent(s): b449e17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -15
app.py CHANGED
@@ -23,7 +23,7 @@ subprocess.run(
23
  )
24
 
25
  # -------------------------------
26
- # FLUX.1 IMAGE GENERATION SETUP
27
  # -------------------------------
28
  MAX_SEED = np.iinfo(np.int32).max
29
 
@@ -38,14 +38,21 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
38
  seed = random.randint(0, MAX_SEED)
39
  return seed
40
 
 
 
 
 
 
 
 
41
  from diffusers import DiffusionPipeline
42
 
43
  base_model = "black-forest-labs/FLUX.1-dev"
44
- pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
45
  lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
46
  trigger_word = "Super Realism" # Leave blank if no trigger word is needed.
47
  pipe.load_lora_weights(lora_repo)
48
- pipe.enable_model_cpu_offload() # Enable CPU offload to manage GPU memory efficiently
49
 
50
  # Define style prompts for Flux.1
51
  style_list = [
@@ -90,15 +97,18 @@ def generate_image_flux(
90
  positive_prompt = apply_style(style_name, prompt)
91
  if trigger_word:
92
  positive_prompt = f"{trigger_word} {positive_prompt}"
93
- images = pipe(
94
- prompt=positive_prompt,
95
- width=width,
96
- height=height,
97
- guidance_scale=guidance_scale,
98
- num_inference_steps=28,
99
- num_images_per_prompt=1,
100
- output_type="pil",
101
- ).images
 
 
 
102
  image_paths = [save_image(img) for img in images]
103
  return image_paths, seed
104
 
@@ -111,7 +121,7 @@ smol_processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Inst
111
  smol_model = AutoModelForImageTextToText.from_pretrained(
112
  "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
113
  _attn_implementation="flash_attention_2",
114
- torch_dtype=torch.bfloat16
115
  ).to("cuda:0")
116
 
117
  # -------------------------------
@@ -270,9 +280,9 @@ def generate(
270
  return_dict=True,
271
  return_tensors="pt",
272
  )
273
- # Explicitly cast pixel values to bfloat16 to match model weights.
274
  if "pixel_values" in inputs:
275
- inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
276
  inputs = inputs.to(smol_model.device)
277
 
278
  streamer = TextIteratorStreamer(smol_processor, skip_prompt=True, skip_special_tokens=True)
 
23
  )
24
 
25
  # -------------------------------
26
+ # CONFIGURATION & UTILITY FUNCTIONS
27
  # -------------------------------
28
  MAX_SEED = np.iinfo(np.int32).max
29
 
 
38
  seed = random.randint(0, MAX_SEED)
39
  return seed
40
 
41
+ # Determine preferred torch dtype based on GPU support.
42
+ bf16_supported = torch.cuda.is_bf16_supported()
43
+ preferred_dtype = torch.bfloat16 if bf16_supported else torch.float16
44
+
45
+ # -------------------------------
46
+ # FLUX.1 IMAGE GENERATION SETUP
47
+ # -------------------------------
48
  from diffusers import DiffusionPipeline
49
 
50
  base_model = "black-forest-labs/FLUX.1-dev"
51
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=preferred_dtype)
52
  lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
53
  trigger_word = "Super Realism" # Leave blank if no trigger word is needed.
54
  pipe.load_lora_weights(lora_repo)
55
+ pipe.to("cuda")
56
 
57
  # Define style prompts for Flux.1
58
  style_list = [
 
97
  positive_prompt = apply_style(style_name, prompt)
98
  if trigger_word:
99
  positive_prompt = f"{trigger_word} {positive_prompt}"
100
+ # Wrap the diffusion call in no_grad to avoid unnecessary gradient state.
101
+ with torch.no_grad():
102
+ images = pipe(
103
+ prompt=positive_prompt,
104
+ width=width,
105
+ height=height,
106
+ guidance_scale=guidance_scale,
107
+ num_inference_steps=28,
108
+ num_images_per_prompt=1,
109
+ output_type="pil",
110
+ ).images
111
+ torch.cuda.synchronize() # Ensure all CUDA operations have completed
112
  image_paths = [save_image(img) for img in images]
113
  return image_paths, seed
114
 
 
121
  smol_model = AutoModelForImageTextToText.from_pretrained(
122
  "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
123
  _attn_implementation="flash_attention_2",
124
+ torch_dtype=preferred_dtype
125
  ).to("cuda:0")
126
 
127
  # -------------------------------
 
280
  return_dict=True,
281
  return_tensors="pt",
282
  )
283
+ # Explicitly cast pixel values to the preferred dtype to match model weights.
284
  if "pixel_values" in inputs:
285
+ inputs["pixel_values"] = inputs["pixel_values"].to(preferred_dtype)
286
  inputs = inputs.to(smol_model.device)
287
 
288
  streamer = TextIteratorStreamer(smol_processor, skip_prompt=True, skip_special_tokens=True)