zhzluke96 committed

Commit 8f52106 · 1 Parent(s): 0129fb6

update

modules/models.py  (+35 -8)  CHANGED
@@ -1,18 +1,23 @@
+import threading
 import torch
 from modules.ChatTTS import ChatTTS
 from modules import config
 from modules.devices import devices
 
 import logging
+import gc
 
 logger = logging.getLogger(__name__)
+
 chat_tts = None
+load_event = threading.Event()
 
 
-def load_chat_tts():
+def load_chat_tts_in_thread():
     global chat_tts
     if chat_tts:
-        return chat_tts
+        load_event.set()  # already loaded, just set the event
+        return
 
     chat_tts = ChatTTS.Chat()
     chat_tts.load_models(
@@ -28,18 +33,40 @@ def load_chat_tts():
     )
 
     devices.torch_gc()
+    load_event.set()  # set the event to signal that loading is complete
+
+
+def initialize_chat_tts():
+    model_thread = threading.Thread(target=load_chat_tts_in_thread)
+    model_thread.start()
 
+
+def load_chat_tts():
+    if chat_tts is None:
+        initialize_chat_tts()
+    load_event.wait()
     return chat_tts
 
 
-def …
-    logging.info("…")
+def unload_chat_tts():
+    logging.info("Unloading ChatTTS models")
     global chat_tts
+
     if chat_tts:
+        for model_name, model in chat_tts.pretrain_models.items():
+            if isinstance(model, torch.nn.Module):
+                model.cpu()
+                del model
         if torch.cuda.is_available():
-            for model_name, model in chat_tts.pretrain_models.items():
-                if isinstance(model, torch.nn.Module):
-                    model.cpu()
             torch.cuda.empty_cache()
+    gc.collect()
     chat_tts = None
-    …
+    logger.info("ChatTTS models unloaded")
+
+
+def reload_chat_tts():
+    logging.info("Reloading ChatTTS models")
+    unload_chat_tts()
+    instance = load_chat_tts()
+    logger.info("ChatTTS models reloaded")
+    return instance
