Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Upload 3 files
Browse files- model.py +2 -2
- multit2i.py +26 -5
    	
        model.py
    CHANGED
    
    | @@ -14,9 +14,9 @@ models = [ | |
| 14 | 
             
                'digiplay/majicMIX_realistic_v7',
         | 
| 15 | 
             
                'votepurchase/counterfeitV30_v30',
         | 
| 16 | 
             
                'Meina/MeinaMix_V11',
         | 
| 17 | 
            -
                'John6666/cute-illustration-style-reinforced-model-v61-sd15',
         | 
| 18 | 
             
                'KBlueLeaf/Kohaku-XL-Epsilon-rev3',
         | 
| 19 | 
            -
                ' | 
|  | |
| 20 | 
             
                'Eugeoter/artiwaifu-diffusion-1.0',
         | 
| 21 | 
             
                'Raelina/Rae-Diffusion-XL-V2',
         | 
| 22 | 
             
                'Raelina/Raemu-XL-V4',
         | 
|  | |
| 14 | 
             
                'digiplay/majicMIX_realistic_v7',
         | 
| 15 | 
             
                'votepurchase/counterfeitV30_v30',
         | 
| 16 | 
             
                'Meina/MeinaMix_V11',
         | 
|  | |
| 17 | 
             
                'KBlueLeaf/Kohaku-XL-Epsilon-rev3',
         | 
| 18 | 
            +
                'KBlueLeaf/Kohaku-XL-Zeta',
         | 
| 19 | 
            +
                'kayfahaarukku/UrangDiffusion-1.2',
         | 
| 20 | 
             
                'Eugeoter/artiwaifu-diffusion-1.0',
         | 
| 21 | 
             
                'Raelina/Rae-Diffusion-XL-V2',
         | 
| 22 | 
             
                'Raelina/Raemu-XL-V4',
         | 
    	
        multit2i.py
    CHANGED
    
    | @@ -33,22 +33,43 @@ def is_repo_name(s): | |
| 33 | 
             
                return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
         | 
| 34 |  | 
| 35 |  | 
| 36 | 
            -
            def  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 37 | 
             
                from huggingface_hub import HfApi
         | 
| 38 | 
             
                api = HfApi()
         | 
| 39 | 
             
                default_tags = ["diffusers"]
         | 
| 40 | 
             
                if not sort: sort = "last_modified"
         | 
|  | |
| 41 | 
             
                models = []
         | 
| 42 | 
             
                try:
         | 
| 43 | 
            -
                    model_infos = api.list_models(author=author,  | 
| 44 | 
            -
                                                   tags=list_uniq(default_tags + tags), cardData=True, sort=sort, limit=limit | 
| 45 | 
             
                except Exception as e:
         | 
| 46 | 
             
                    print(f"Error: Failed to list models.")
         | 
| 47 | 
             
                    print(e)
         | 
| 48 | 
             
                    return models
         | 
| 49 | 
             
                for model in model_infos:
         | 
| 50 | 
            -
                    if not model.private and not model.gated | 
| 51 | 
            -
                        | 
|  | |
| 52 | 
             
                       models.append(model.id)
         | 
| 53 | 
             
                       if len(models) == limit: break
         | 
| 54 | 
             
                return models
         | 
|  | |
| 33 | 
             
                return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
         | 
| 34 |  | 
| 35 |  | 
| 36 | 
            +
            def get_status(model_name: str):
         | 
| 37 | 
            +
                from huggingface_hub import InferenceClient
         | 
| 38 | 
            +
                client = InferenceClient(timeout=10)
         | 
| 39 | 
            +
                return client.get_model_status(model_name)
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
            def is_loadable(model_name: str, force_gpu: bool = False):
         | 
| 43 | 
            +
                try:
         | 
| 44 | 
            +
                    status = get_status(model_name)
         | 
| 45 | 
            +
                except Exception as e:
         | 
| 46 | 
            +
                    print(e)
         | 
| 47 | 
            +
                    print(f"Couldn't load {model_name}.")
         | 
| 48 | 
            +
                    return False
         | 
| 49 | 
            +
                gpu_state = isinstance(status.compute_type, dict) and "gpu" in status.compute_type.keys()
         | 
| 50 | 
            +
                if status is None or status.state not in ["Loadable", "Loaded"] or (force_gpu and not gpu_state):
         | 
| 51 | 
            +
                    print(f"Couldn't load {model_name}. Model state:'{status.state}', GPU:{gpu_state}")
         | 
| 52 | 
            +
                return status is not None and status.state in ["Loadable", "Loaded"] and (not force_gpu or gpu_state)
         | 
| 53 | 
            +
             | 
| 54 | 
            +
             | 
| 55 | 
            +
            def find_model_list(author: str="", tags: list[str]=[], not_tag="", sort: str="last_modified", limit: int=30, force_gpu=False, check_status=False):
         | 
| 56 | 
             
                from huggingface_hub import HfApi
         | 
| 57 | 
             
                api = HfApi()
         | 
| 58 | 
             
                default_tags = ["diffusers"]
         | 
| 59 | 
             
                if not sort: sort = "last_modified"
         | 
| 60 | 
            +
                limit = limit * 20 if check_status and force_gpu else limit * 5
         | 
| 61 | 
             
                models = []
         | 
| 62 | 
             
                try:
         | 
| 63 | 
            +
                    model_infos = api.list_models(author=author, task="text-to-image",
         | 
| 64 | 
            +
                                                   tags=list_uniq(default_tags + tags), cardData=True, sort=sort, limit=limit)
         | 
| 65 | 
             
                except Exception as e:
         | 
| 66 | 
             
                    print(f"Error: Failed to list models.")
         | 
| 67 | 
             
                    print(e)
         | 
| 68 | 
             
                    return models
         | 
| 69 | 
             
                for model in model_infos:
         | 
| 70 | 
            +
                    if not model.private and not model.gated:
         | 
| 71 | 
            +
                       loadable = is_loadable(model.id, force_gpu) if check_status else True
         | 
| 72 | 
            +
                       if not_tag and not_tag in model.tags or not loadable: continue
         | 
| 73 | 
             
                       models.append(model.id)
         | 
| 74 | 
             
                       if len(models) == limit: break
         | 
| 75 | 
             
                return models
         | 
