Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	
		lzyhha
		
	commited on
		
		
					Commit 
							
							·
						
						24d1968
	
1
								Parent(s):
							
							e2a1c5c
								
space
Browse files- visualcloze.py +16 -16
    	
        visualcloze.py
    CHANGED
    
    | @@ -91,26 +91,26 @@ class VisualClozeModel: | |
| 91 | 
             
                    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         | 
| 92 | 
             
                    self.dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[self.precision]
         | 
| 93 |  | 
| 94 | 
            -
                    #  | 
| 95 | 
            -
                     | 
| 96 | 
            -
                     | 
| 97 |  | 
| 98 | 
            -
                    #  | 
| 99 | 
            -
                     | 
| 100 | 
            -
                     | 
| 101 | 
            -
                     | 
| 102 |  | 
| 103 | 
            -
                    #  | 
| 104 | 
            -
                     | 
| 105 | 
            -
                     | 
| 106 | 
            -
                     | 
| 107 |  | 
| 108 | 
            -
                     | 
| 109 |  | 
| 110 | 
            -
                    #  | 
| 111 | 
            -
                     | 
| 112 | 
            -
                     | 
| 113 | 
            -
                     | 
| 114 |  | 
| 115 | 
             
                    # Initialize sampler
         | 
| 116 | 
             
                    transport = create_transport(
         | 
|  | |
| 91 | 
             
                    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         | 
| 92 | 
             
                    self.dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[self.precision]
         | 
| 93 |  | 
| 94 | 
            +
                    # Initialize model
         | 
| 95 | 
            +
                    print("Initializing model...")
         | 
| 96 | 
            +
                    self.model = load_flow_model(model_name, device=self.device, lora_rank=self.lora_rank)
         | 
| 97 |  | 
| 98 | 
            +
                    # Initialize VAE
         | 
| 99 | 
            +
                    print("Initializing VAE...")
         | 
| 100 | 
            +
                    self.ae = AutoencoderKL.from_pretrained(f"black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=self.dtype).to(self.device)
         | 
| 101 | 
            +
                    self.ae.requires_grad_(False)
         | 
| 102 |  | 
| 103 | 
            +
                    # Initialize text encoders
         | 
| 104 | 
            +
                    print("Initializing text encoders...")
         | 
| 105 | 
            +
                    self.t5 = load_t5(self.device, max_length=self.max_length)
         | 
| 106 | 
            +
                    self.clip = load_clip(self.device)
         | 
| 107 |  | 
| 108 | 
            +
                    self.model.eval().to(self.device, dtype=self.dtype)
         | 
| 109 |  | 
| 110 | 
            +
                    # Load model weights
         | 
| 111 | 
            +
                    ckpt = torch.load(model_path)
         | 
| 112 | 
            +
                    self.model.load_state_dict(ckpt, strict=False)
         | 
| 113 | 
            +
                    del ckpt
         | 
| 114 |  | 
| 115 | 
             
                    # Initialize sampler
         | 
| 116 | 
             
                    transport = create_transport(
         | 
 
			
