	Update app.py
This commit moves device = "cuda" to module scope, moves the pipeline and the CUDA-graph static inputs onto the GPU explicitly, and comments out the @spaces.GPU decorator.

app.py  (CHANGED)
@@ -20,7 +20,9 @@ DEFAULT_INFERENCE_STEPS = 1
 
 # Device and model setup
 dtype = torch.float16
-#
+device = "cuda"  # Explicitly set device to CUDA
+
+# Download the LoRA weights
 lora_weights_path = hf_hub_download(
     repo_id="hugovntr/flux-schnell-realism",
     filename="schnell-realism_v2.3.safetensors",
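Note (not part of the diff): hf_hub_download fetches a single file from the Hub into the local cache and returns its filesystem path, so lora_weights_path holds a path string, not loaded weights. A minimal sketch with the same repo and filename as above:

from huggingface_hub import hf_hub_download

# Downloads the file on first use, reuses the local cache afterwards,
# and returns the local path as a string.
lora_weights_path = hf_hub_download(
    repo_id="hugovntr/flux-schnell-realism",
    filename="schnell-realism_v2.3.safetensors",
)
print(lora_weights_path)  # a path inside the Hugging Face cache directory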
@@ -30,17 +32,17 @@ pipe = FluxWithCFGPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
 )
 pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
-pipe.to(
+pipe.to(device)  # Move the pipeline to CUDA
 
-# Load the LoRA weights
+# Load the LoRA weights
 pipe.load_lora_weights(lora_weights_path, adapter_name="better")
 pipe.set_adapters(["better"], adapter_weights=[1.0])
 pipe.fuse_lora(adapter_name=["better"], lora_scale=1.0)
 pipe.unload_lora_weights()
 
 # Memory optimizations
-pipe.transformer.to(memory_format=torch.channels_last)
-pipe.enable_xformers_memory_efficient_attention()
+pipe.transformer.to(memory_format=torch.channels_last)
+pipe.enable_xformers_memory_efficient_attention()
 
 # CUDA Graph setup
 static_inputs = None
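Note (not part of the diff): the load, set_adapters, fuse_lora, unload_lora_weights sequence bakes the adapter deltas into the base weights and then drops the adapter bookkeeping, so steady-state inference carries no per-step LoRA overhead. A hedged sketch of the same pattern on a stock diffusers Flux pipeline (the app's FluxWithCFGPipeline is custom; also, recent diffusers releases spell the fuse keyword adapter_names, so the adapter_name=[...] form in the diff may raise a TypeError):

import torch
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download

lora_weights_path = hf_hub_download(
    repo_id="hugovntr/flux-schnell-realism",
    filename="schnell-realism_v2.3.safetensors",
)

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16
)
pipe.to("cuda")

# Attach the adapter, set its weight, fuse the deltas into the base
# weights, then discard the now-redundant adapter state.
pipe.load_lora_weights(lora_weights_path, adapter_name="better")
pipe.set_adapters(["better"], adapter_weights=[1.0])
pipe.fuse_lora(adapter_names=["better"], lora_scale=1.0)
pipe.unload_lora_weights()

# The memory optimizations from the diff: channels_last can improve
# throughput on some kernels; the xformers call requires the xformers
# package and may be unsupported for some architectures.
pipe.transformer.to(memory_format=torch.channels_last)
pipe.enable_xformers_memory_efficient_attention()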
@@ -51,7 +53,6 @@ def setup_cuda_graph(prompt, height, width, num_inference_steps):
     global static_inputs, static_model, graph
 
     batch_size = 1 if isinstance(prompt, str) else len(prompt)
-    device = "cuda"
     num_images_per_prompt = 1
 
     prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
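Note (not part of the diff): the deleted local device = "cuda" would have shadowed the new module-level global; after this hunk the function body, including the guidance line below, resolves device through the global. For orientation, a hedged sketch of the Flux-style encode_prompt call this hunk leads into (signature as in diffusers' FluxPipeline; the custom pipeline may differ):

# Assumes pipe is a Flux pipeline and device = "cuda" at module scope.
prompt = "a photo of an astronaut riding a horse"  # illustrative
prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
    prompt=prompt,
    prompt_2=None,        # falls back to prompt when None
    device=device,        # the module-level global, not a local
    num_images_per_prompt=1,
)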
@@ -91,11 +92,11 @@ def setup_cuda_graph(prompt, height, width, num_inference_steps):
     guidance = torch.full([1], 3.5, device=device, dtype=torch.float16).expand(latents.shape[0]) if pipe.transformer.config.guidance_embeds else None
 
     static_inputs = {
-        "hidden_states": latents,
-        "timestep": timesteps,
-        "guidance": guidance,
-        "pooled_projections": pooled_prompt_embeds,
-        "encoder_hidden_states": prompt_embeds,
+        "hidden_states": latents.to(device),
+        "timestep": timesteps.to(device),
+        "guidance": guidance.to(device) if guidance is not None else None,
+        "pooled_projections": pooled_prompt_embeds.to(device),
+        "encoder_hidden_states": prompt_embeds.to(device),
         "txt_ids": text_ids,
         "img_ids": latent_image_ids,
         "joint_attention_kwargs": None,
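Note (not part of the diff): CUDA graphs replay a recorded kernel sequence against fixed memory addresses, so every captured input tensor must already live on the GPU and must be updated in place between replays rather than rebound, which is why these dict entries gain .to(device). A generic capture/replay sketch following the documented PyTorch pattern (toy model and shapes, not the app's transformer):

import torch

model = torch.nn.Linear(64, 64).half().cuda()
static_x = torch.randn(8, 64, device="cuda", dtype=torch.float16)

# Warm up on a side stream before capture, as the PyTorch docs advise.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_x)
torch.cuda.current_stream().wait_stream(s)

# Capture one forward pass into a graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_y = model(static_x)

# Replay with new data: overwrite the static buffer in place.
static_x.copy_(torch.randn_like(static_x))
graph.replay()  # static_y now holds the result for the new input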
@@ -108,7 +109,7 @@ def setup_cuda_graph(prompt, height, width, num_inference_steps):
         static_output = static_model(**static_inputs)
 
 # Inference function
-@spaces.GPU(duration=25)
+# @spaces.GPU(duration=25) # Remove decorator
 def generate_image(prompt, seed=24, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT, randomize_seed=False, num_inference_steps=2, progress=gr.Progress(track_tqdm=True)):
     global static_inputs, graph
 
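Note (not part of the diff): @spaces.GPU is the ZeroGPU scheduling decorator; on ZeroGPU hardware, CUDA is only usable inside functions it wraps, so removing it while calling pipe.to(device) at import time is only safe on a Space with a dedicated GPU, and is a plausible source of runtime errors on ZeroGPU. For reference, the pattern the commit disables looks roughly like this (sketch; the spaces package is preinstalled on Hugging Face Spaces):

import spaces

@spaces.GPU(duration=25)  # request a ZeroGPU worker for up to ~25 s per call
def generate_image(prompt, seed=24):
    # On ZeroGPU, CUDA initialization happens inside the decorated call.
    ...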