prithivMLmods committed on
Commit
b45fde6
·
verified ·
1 Parent(s): f264bf1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -12
app.py CHANGED
@@ -40,9 +40,8 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
40
 
41
  from diffusers import DiffusionPipeline
42
 
43
- # Use torch.float16 for better stability on NVIDIA GPUs
44
  base_model = "black-forest-labs/FLUX.1-dev"
45
- pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16)
46
  lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
47
  trigger_word = "Super Realism" # Leave blank if no trigger word is needed.
48
  pipe.load_lora_weights(lora_repo)
@@ -90,17 +89,23 @@ def generate_image_flux(
90
  positive_prompt = apply_style(style_name, prompt)
91
  if trigger_word:
92
  positive_prompt = f"{trigger_word} {positive_prompt}"
93
- # Clear GPU cache before generating to help avoid memory fragmentation errors.
 
94
  torch.cuda.empty_cache()
95
- images = pipe(
96
- prompt=positive_prompt,
97
- width=width,
98
- height=height,
99
- guidance_scale=guidance_scale,
100
- num_inference_steps=28,
101
- num_images_per_prompt=1,
102
- output_type="pil",
103
- ).images
 
 
 
 
 
104
  image_paths = [save_image(img) for img in images]
105
  return image_paths, seed
106
 
 
40
 
41
  from diffusers import DiffusionPipeline
42
 
 
43
  base_model = "black-forest-labs/FLUX.1-dev"
44
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
45
  lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
46
  trigger_word = "Super Realism" # Leave blank if no trigger word is needed.
47
  pipe.load_lora_weights(lora_repo)
 
89
  positive_prompt = apply_style(style_name, prompt)
90
  if trigger_word:
91
  positive_prompt = f"{trigger_word} {positive_prompt}"
92
+
93
+ # Clear cache before generation
94
  torch.cuda.empty_cache()
95
+
96
+ # Wrap the pipeline call in no_grad and autocast contexts
97
+ with torch.no_grad():
98
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
99
+ outputs = pipe(
100
+ prompt=positive_prompt,
101
+ width=width,
102
+ height=height,
103
+ guidance_scale=guidance_scale,
104
+ num_inference_steps=28,
105
+ num_images_per_prompt=1,
106
+ output_type="pil",
107
+ )
108
+ images = outputs.images
109
  image_paths = [save_image(img) for img in images]
110
  return image_paths, seed
111