Sync from GitHub repo

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.
src/f5_tts/infer/utils_infer.py CHANGED
```diff
@@ -135,12 +135,10 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
 asr_pipe = None
 
 
-def initialize_asr_pipeline(device=device, dtype=None):
+def initialize_asr_pipeline(device: str = device, dtype=None):
     if dtype is None:
         dtype = (
-            torch.float16
-            if torch.cuda.is_available() and torch.cuda.get_device_properties(device).major >= 6
-            else torch.float32
+            torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
         )
     global asr_pipe
     asr_pipe = pipeline(
@@ -170,12 +168,10 @@ def transcribe(ref_audio, language=None):
 # load model checkpoint for inference
 
 
-def load_checkpoint(model, ckpt_path, device, dtype=None, use_ema=True):
+def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
     if dtype is None:
         dtype = (
-            torch.float16
-            if torch.cuda.is_available() and torch.cuda.get_device_properties(device).major >= 6
-            else torch.float32
+            torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
         )
     model = model.to(dtype)
```
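The substance of the change: the half-precision default is now keyed off the target device string rather than `torch.cuda.is_available()`. With the old guard, a CUDA-equipped machine asked to run on `device="cpu"` still evaluated `torch.cuda.get_device_properties("cpu")`, which only accepts CUDA devices, so the dtype probe itself could fail. Below is a minimal sketch of the updated selection logic as a standalone helper; `select_dtype` is a hypothetical name for illustration — in the patched file the expression is inlined in `initialize_asr_pipeline()` and `load_checkpoint()`.

```python
import torch

def select_dtype(device: str) -> torch.dtype:
    # Hypothetical helper mirroring the patched selection logic (the real
    # code inlines this expression). float16 is chosen only when the target
    # device itself is CUDA *and* the GPU's compute capability is >= 6
    # (Pascal or newer); cpu/mps targets fall back to float32 without ever
    # touching the CUDA runtime.
    if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6:
        return torch.float16
    return torch.float32

print(select_dtype("cpu"))  # torch.float32, even on a machine with a GPU
```

The new `device: str` annotations are load-bearing: `"cuda" in device` is a substring test on a string, so passing a `torch.device` object would raise a `TypeError`; callers are expected to hand in strings like `"cpu"` or `"cuda:0"`.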
