Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Upload 13 files
Browse files- README.md +13 -12
- app.py +155 -0
- model.py +32 -0
- multit2i.py +311 -0
- requirements.txt +12 -0
- tagger/character_series_dict.csv +0 -0
- tagger/danbooru_e621.csv +0 -0
- tagger/fl2sd3longcap.py +74 -0
- tagger/output.py +16 -0
- tagger/tag_group.csv +0 -0
- tagger/tagger.py +546 -0
- tagger/utils.py +45 -0
- tagger/v2.py +260 -0
    	
        README.md
    CHANGED
    
    | @@ -1,12 +1,13 @@ | |
| 1 | 
            -
            ---
         | 
| 2 | 
            -
            title:  | 
| 3 | 
            -
            emoji:  | 
| 4 | 
            -
            colorFrom:  | 
| 5 | 
            -
            colorTo:  | 
| 6 | 
            -
            sdk: gradio
         | 
| 7 | 
            -
            sdk_version: 4.39.0
         | 
| 8 | 
            -
            app_file: app.py
         | 
| 9 | 
            -
             | 
| 10 | 
            -
             | 
| 11 | 
            -
             | 
| 12 | 
            -
             | 
|  | 
|  | |
| 1 | 
            +
            ---
         | 
| 2 | 
            +
            title: Free Multi Models Text-to-Image Heavy-Armed Demo
         | 
| 3 | 
            +
            emoji: 🌐🌊
         | 
| 4 | 
            +
            colorFrom: blue
         | 
| 5 | 
            +
            colorTo: purple
         | 
| 6 | 
            +
            sdk: gradio
         | 
| 7 | 
            +
            sdk_version: 4.39.0
         | 
| 8 | 
            +
            app_file: app.py
         | 
| 9 | 
            +
            short_description: Text-to-Image
         | 
| 10 | 
            +
            pinned: true
         | 
| 11 | 
            +
            ---
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
         | 
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,155 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import gradio as gr
         | 
| 2 | 
            +
            from multit2i import (
         | 
| 3 | 
            +
                load_models,
         | 
| 4 | 
            +
                infer_multi,
         | 
| 5 | 
            +
                infer_multi_random,
         | 
| 6 | 
            +
                save_gallery_images,
         | 
| 7 | 
            +
                change_model,
         | 
| 8 | 
            +
                get_model_info_md,
         | 
| 9 | 
            +
                loaded_models,
         | 
| 10 | 
            +
                get_positive_prefix,
         | 
| 11 | 
            +
                get_positive_suffix,
         | 
| 12 | 
            +
                get_negative_prefix,
         | 
| 13 | 
            +
                get_negative_suffix,
         | 
| 14 | 
            +
                get_recom_prompt_type,
         | 
| 15 | 
            +
                set_recom_prompt_preset,
         | 
| 16 | 
            +
            )
         | 
| 17 | 
            +
            from model import models
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            from tagger.tagger import (
         | 
| 20 | 
            +
                predict_tags_wd,
         | 
| 21 | 
            +
                remove_specific_prompt,
         | 
| 22 | 
            +
                convert_danbooru_to_e621_prompt,
         | 
| 23 | 
            +
                insert_recom_prompt,
         | 
| 24 | 
            +
            )
         | 
| 25 | 
            +
            from tagger.fl2sd3longcap import predict_tags_fl2_sd3
         | 
| 26 | 
            +
            from tagger.v2 import (
         | 
| 27 | 
            +
                V2_ALL_MODELS,
         | 
| 28 | 
            +
                v2_random_prompt,
         | 
| 29 | 
            +
            )
         | 
| 30 | 
            +
            from tagger.utils import (
         | 
| 31 | 
            +
                V2_ASPECT_RATIO_OPTIONS,
         | 
| 32 | 
            +
                V2_RATING_OPTIONS,
         | 
| 33 | 
            +
                V2_LENGTH_OPTIONS,
         | 
| 34 | 
            +
                V2_IDENTITY_OPTIONS,
         | 
| 35 | 
            +
            )
         | 
| 36 | 
            +
             | 
| 37 | 
            +
             | 
| 38 | 
            +
            load_models(models, 10)
         | 
| 39 | 
            +
            #load_models(models, 20) # Fetching 20 models at the same time. default: 5 *This option is not working so far.
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
            css = """
         | 
| 43 | 
            +
            #model_info { text-align: center; display:block; }
         | 
| 44 | 
            +
            """
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", css=css) as demo:
         | 
| 47 | 
            +
                with gr.Column(): 
         | 
| 48 | 
            +
                    with gr.Accordion("Advanced settings", open=True):
         | 
| 49 | 
            +
                        with gr.Accordion("Recommended Prompt", open=False):
         | 
| 50 | 
            +
                            recom_prompt_preset = gr.Radio(label="Set Presets", choices=get_recom_prompt_type(), value="Common")
         | 
| 51 | 
            +
                            with gr.Row():
         | 
| 52 | 
            +
                                positive_prefix = gr.CheckboxGroup(label="Use Positive Prefix", choices=get_positive_prefix(), value=[])
         | 
| 53 | 
            +
                                positive_suffix = gr.CheckboxGroup(label="Use Positive Suffix", choices=get_positive_suffix(), value=["Common"])
         | 
| 54 | 
            +
                                negative_prefix = gr.CheckboxGroup(label="Use Negative Prefix", choices=get_negative_prefix(), value=[], visible=False)
         | 
| 55 | 
            +
                                negative_suffix = gr.CheckboxGroup(label="Use Negative Suffix", choices=get_negative_suffix(), value=["Common"], visible=False)
         | 
| 56 | 
            +
                        with gr.Accordion("Prompt Transformer", open=False):
         | 
| 57 | 
            +
                            v2_rating = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="sfw")
         | 
| 58 | 
            +
                            v2_aspect_ratio = gr.Radio(label="Aspect ratio", info="The aspect ratio of the image.", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square", visible=False)
         | 
| 59 | 
            +
                            v2_length = gr.Radio(label="Length", info="The total length of the tags.", choices=list(V2_LENGTH_OPTIONS), value="long")
         | 
| 60 | 
            +
                            v2_identity = gr.Radio(label="Keep identity", info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.", choices=list(V2_IDENTITY_OPTIONS), value="lax")                    
         | 
| 61 | 
            +
                            v2_ban_tags = gr.Textbox(label="Ban tags", info="Tags to ban from the output.", placeholder="alternate costumen, ...", value="censored")
         | 
| 62 | 
            +
                            v2_model = gr.Dropdown(label="Model", choices=list(V2_ALL_MODELS.keys()), value=list(V2_ALL_MODELS.keys())[0])
         | 
| 63 | 
            +
                        with gr.Accordion("Model", open=True):
         | 
| 64 | 
            +
                            model_name = gr.Dropdown(label="Select Model", choices=list(loaded_models.keys()), value=list(loaded_models.keys())[0])
         | 
| 65 | 
            +
                            model_info = gr.Markdown(value=get_model_info_md(list(loaded_models.keys())[0]), elem_id="model_info")
         | 
| 66 | 
            +
                    with gr.Group():
         | 
| 67 | 
            +
                        with gr.Accordion("Prompt from Image File", open=False):
         | 
| 68 | 
            +
                            tagger_image = gr.Image(label="Input image", type="pil", sources=["upload", "clipboard"], height=256)
         | 
| 69 | 
            +
                            with gr.Accordion(label="Advanced options", open=False):
         | 
| 70 | 
            +
                                tagger_general_threshold = gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.01, interactive=True)
         | 
| 71 | 
            +
                                tagger_character_threshold = gr.Slider(label="Character threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.01, interactive=True)
         | 
| 72 | 
            +
                                tagger_tag_type = gr.Radio(label="Convert tags to", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru")
         | 
| 73 | 
            +
                                tagger_recom_prompt = gr.Radio(label="Insert reccomended prompt", choices=["None", "Animagine", "Pony"], value="None", interactive=True)  
         | 
| 74 | 
            +
                                tagger_keep_tags = gr.Radio(label="Remove tags leaving only the following", choices=["body", "dress", "all"], value="all")
         | 
| 75 | 
            +
                            tagger_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner"], label="Algorithms", value=["Use WD Tagger"])
         | 
| 76 | 
            +
                            tagger_generate_from_image = gr.Button(value="Generate Tags from Image")
         | 
| 77 | 
            +
                        with gr.Row():
         | 
| 78 | 
            +
                            v2_character = gr.Textbox(label="Character", placeholder="hatsune miku", scale=2)
         | 
| 79 | 
            +
                            v2_series = gr.Textbox(label="Series", placeholder="vocaloid", scale=2)
         | 
| 80 | 
            +
                            random_prompt = gr.Button(value="Extend Prompt 🎲", size="sm", scale=1)
         | 
| 81 | 
            +
                            clear_prompt = gr.Button(value="Clear Prompt 🗑️", size="sm", scale=1)
         | 
| 82 | 
            +
                        prompt = gr.Text(label="Prompt", lines=1, max_lines=8, placeholder="1girl, solo, ...", show_copy_button=True)
         | 
| 83 | 
            +
                        neg_prompt = gr.Text(label="Negative Prompt", lines=1, max_lines=8, placeholder="", visible=False)
         | 
| 84 | 
            +
                    with gr.Row():
         | 
| 85 | 
            +
                        run_button = gr.Button("Generate Image", scale=6)
         | 
| 86 | 
            +
                        random_button = gr.Button("Random Model 🎲", scale=3)
         | 
| 87 | 
            +
                        image_num = gr.Number(label="Count", minimum=1, maximum=16, value=1, step=1, interactive=True, scale=1)
         | 
| 88 | 
            +
                    results = gr.Gallery(label="Gallery", interactive=False, show_download_button=True, show_share_button=False,
         | 
| 89 | 
            +
                                          container=True, format="png", object_fit="contain")
         | 
| 90 | 
            +
                    image_files = gr.Files(label="Download", interactive=False)
         | 
| 91 | 
            +
                    clear_results = gr.Button("Clear Gallery / Download")
         | 
| 92 | 
            +
                examples = gr.Examples(
         | 
| 93 | 
            +
                    examples = [
         | 
| 94 | 
            +
                        ["souryuu asuka langley, 1girl, neon genesis evangelion, plugsuit, pilot suit, red bodysuit, sitting, crossing legs, black eye patch, cat hat, throne, symmetrical, looking down, from bottom, looking at viewer, outdoors"],
         | 
| 95 | 
            +
                        ["sailor moon, magical girl transformation, sparkles and ribbons, soft pastel colors, crescent moon motif, starry night sky background, shoujo manga style"],
         | 
| 96 | 
            +
                        ["kafuu chino, 1girl, solo"],
         | 
| 97 | 
            +
                        ["1girl"],
         | 
| 98 | 
            +
                        ["beautiful sunset"],
         | 
| 99 | 
            +
                    ],
         | 
| 100 | 
            +
                    inputs=[prompt],
         | 
| 101 | 
            +
                )
         | 
| 102 | 
            +
                gr.Markdown(
         | 
| 103 | 
            +
                    f"""This demo was created in reference to the following demos.
         | 
| 104 | 
            +
            - [Nymbo/Flood](https://huggingface.co/spaces/Nymbo/Flood).
         | 
| 105 | 
            +
            - [Yntec/ToyWorldXL](https://huggingface.co/spaces/Yntec/ToyWorldXL).
         | 
| 106 | 
            +
            <br>The first startup takes a mind-boggling amount of time, but not so much after the second.
         | 
| 107 | 
            +
            This is due to the time it takes for Gradio to generate an example image to cache.
         | 
| 108 | 
            +
                        """
         | 
| 109 | 
            +
                )
         | 
| 110 | 
            +
                gr.DuplicateButton(value="Duplicate Space")
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                model_name.change(change_model, [model_name], [model_info], queue=False, show_api=False)
         | 
| 113 | 
            +
                gr.on(
         | 
| 114 | 
            +
                    triggers=[run_button.click, prompt.submit],
         | 
| 115 | 
            +
                    fn=infer_multi,
         | 
| 116 | 
            +
                    inputs=[prompt, neg_prompt, results, image_num, model_name,
         | 
| 117 | 
            +
                             positive_prefix, positive_suffix, negative_prefix, negative_suffix],
         | 
| 118 | 
            +
                    outputs=[results],
         | 
| 119 | 
            +
                    queue=True,
         | 
| 120 | 
            +
                    show_progress="full",
         | 
| 121 | 
            +
                    show_api=True,
         | 
| 122 | 
            +
                ).success(save_gallery_images, [results], [results, image_files], queue=False, show_api=False)
         | 
| 123 | 
            +
                gr.on(
         | 
| 124 | 
            +
                    triggers=[random_button.click],
         | 
| 125 | 
            +
                    fn=infer_multi_random,
         | 
| 126 | 
            +
                    inputs=[prompt, neg_prompt, results, image_num,
         | 
| 127 | 
            +
                             positive_prefix, positive_suffix, negative_prefix, negative_suffix],
         | 
| 128 | 
            +
                    outputs=[results],
         | 
| 129 | 
            +
                    queue=True,
         | 
| 130 | 
            +
                    show_progress="full",
         | 
| 131 | 
            +
                    show_api=True,
         | 
| 132 | 
            +
                ).success(save_gallery_images, [results], [results, image_files], queue=False, show_api=False)
         | 
| 133 | 
            +
                clear_prompt.click(lambda: (None, None, None), None, [prompt, v2_series, v2_character], queue=False, show_api=False)
         | 
| 134 | 
            +
                clear_results.click(lambda: (None, None), None, [results, image_files], queue=False, show_api=False)
         | 
| 135 | 
            +
                recom_prompt_preset.change(set_recom_prompt_preset, [recom_prompt_preset],
         | 
| 136 | 
            +
                 [positive_prefix, positive_suffix, negative_prefix, negative_suffix], queue=False, show_api=False)
         | 
| 137 | 
            +
                random_prompt.click(v2_random_prompt, [prompt, v2_series, v2_character, v2_rating, v2_aspect_ratio, v2_length,
         | 
| 138 | 
            +
                                                       v2_identity, v2_ban_tags, v2_model], [prompt, v2_series, v2_character], queue=False, show_api=False)
         | 
| 139 | 
            +
                tagger_generate_from_image.click(
         | 
| 140 | 
            +
                    predict_tags_wd,
         | 
| 141 | 
            +
                    [tagger_image, prompt, tagger_algorithms, tagger_general_threshold, tagger_character_threshold],
         | 
| 142 | 
            +
                    [v2_series, v2_character, prompt, gr.Button(visible=False)],
         | 
| 143 | 
            +
                    show_api=False,
         | 
| 144 | 
            +
                ).success(
         | 
| 145 | 
            +
                    predict_tags_fl2_sd3, [tagger_image, prompt, tagger_algorithms], [prompt], show_api=False,
         | 
| 146 | 
            +
                ).success(
         | 
| 147 | 
            +
                    remove_specific_prompt, [prompt, tagger_keep_tags], [prompt], queue=False, show_api=False,
         | 
| 148 | 
            +
                ).success(
         | 
| 149 | 
            +
                    convert_danbooru_to_e621_prompt, [prompt, tagger_tag_type], [prompt], queue=False, show_api=False,
         | 
| 150 | 
            +
                ).success(
         | 
| 151 | 
            +
                    insert_recom_prompt, [prompt, neg_prompt, tagger_recom_prompt], [prompt, neg_prompt], queue=False, show_api=False,
         | 
| 152 | 
            +
                )
         | 
| 153 | 
            +
             | 
| 154 | 
            +
            demo.queue()
         | 
| 155 | 
            +
            demo.launch()
         | 
    	
        model.py
    ADDED
    
    | @@ -0,0 +1,32 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from multit2i import find_model_list
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            models = [
         | 
| 5 | 
            +
                'yodayo-ai/kivotos-xl-2.0',
         | 
| 6 | 
            +
                'yodayo-ai/holodayo-xl-2.1',
         | 
| 7 | 
            +
                'cagliostrolab/animagine-xl-3.1',
         | 
| 8 | 
            +
                'votepurchase/ponyDiffusionV6XL',
         | 
| 9 | 
            +
                'eienmojiki/Anything-XL',
         | 
| 10 | 
            +
                'eienmojiki/Starry-XL-v5.2',
         | 
| 11 | 
            +
                'digiplay/majicMIX_sombre_v2',
         | 
| 12 | 
            +
                'digiplay/majicMIX_realistic_v7',
         | 
| 13 | 
            +
                'votepurchase/counterfeitV30_v30',
         | 
| 14 | 
            +
                'Meina/MeinaMix_V11',
         | 
| 15 | 
            +
                'KBlueLeaf/Kohaku-XL-Epsilon-rev3',
         | 
| 16 | 
            +
                'kayfahaarukku/UrangDiffusion-1.1',
         | 
| 17 | 
            +
                'Raelina/Rae-Diffusion-XL-V2',
         | 
| 18 | 
            +
                'Raelina/Raemu-XL-V4',
         | 
| 19 | 
            +
            ]
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            models = ['yodayo-ai/kivotos-xl-2.0', 'Raelina/Rae-Diffusion-XL-V2']
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            # Examples:
         | 
| 26 | 
            +
            #models = ['yodayo-ai/kivotos-xl-2.0', 'yodayo-ai/holodayo-xl-2.1'] # specific models
         | 
| 27 | 
            +
            #models = find_model_list("John6666", [], "", "last_modified", 20) # John6666's latest 20 models
         | 
| 28 | 
            +
            #models = find_model_list("John6666", ["anime"], "", "last_modified", 20) # John6666's latest 20 models with 'anime' tag
         | 
| 29 | 
            +
            #models = find_model_list("John6666", [], "anime", "last_modified", 20) # John6666's latest 20 models without 'anime' tag
         | 
| 30 | 
            +
            #models = find_model_list("", [], "", "last_modified", 20) # latest 20 text-to-image models of huggingface
         | 
| 31 | 
            +
            #models = find_model_list("", [], "", "downloads", 20) # monthly most downloaded 20 text-to-image models of huggingface
         | 
| 32 | 
            +
             | 
    	
        multit2i.py
    ADDED
    
    | @@ -0,0 +1,311 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import gradio as gr
         | 
| 2 | 
            +
            import asyncio
         | 
| 3 | 
            +
            from threading import RLock, Thread
         | 
| 4 | 
            +
            from pathlib import Path
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            lock = RLock()
         | 
| 8 | 
            +
            loaded_models = {}
         | 
| 9 | 
            +
            model_info_dict = {}
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            def to_list(s):
         | 
| 13 | 
            +
                return [x.strip() for x in s.split(",")]
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            def list_sub(a, b):
         | 
| 17 | 
            +
                return [e for e in a if e not in b]
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            def list_uniq(l):
         | 
| 21 | 
            +
                    return sorted(set(l), key=l.index)
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            def is_repo_name(s):
         | 
| 25 | 
            +
                import re
         | 
| 26 | 
            +
                return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
         | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
| 29 | 
            +
            def find_model_list(author: str="", tags: list[str]=[], not_tag="", sort: str="last_modified", limit: int=30):
         | 
| 30 | 
            +
                from huggingface_hub import HfApi
         | 
| 31 | 
            +
                api = HfApi()
         | 
| 32 | 
            +
                default_tags = ["diffusers"]
         | 
| 33 | 
            +
                if not sort: sort = "last_modified"
         | 
| 34 | 
            +
                models = []
         | 
| 35 | 
            +
                try:
         | 
| 36 | 
            +
                    model_infos = api.list_models(author=author, pipeline_tag="text-to-image",
         | 
| 37 | 
            +
                                                   tags=list_uniq(default_tags + tags), cardData=True, sort=sort, limit=limit * 5)
         | 
| 38 | 
            +
                except Exception as e:
         | 
| 39 | 
            +
                    print(f"Error: Failed to list models.")
         | 
| 40 | 
            +
                    print(e)
         | 
| 41 | 
            +
                    return models
         | 
| 42 | 
            +
                for model in model_infos:
         | 
| 43 | 
            +
                    if not model.private and not model.gated:
         | 
| 44 | 
            +
                       if not_tag and not_tag in model.tags: continue
         | 
| 45 | 
            +
                       models.append(model.id)
         | 
| 46 | 
            +
                       if len(models) == limit: break
         | 
| 47 | 
            +
                return models
         | 
| 48 | 
            +
             | 
| 49 | 
            +
             | 
| 50 | 
            +
            def get_t2i_model_info_dict(repo_id: str):
         | 
| 51 | 
            +
                from huggingface_hub import HfApi
         | 
| 52 | 
            +
                api = HfApi()
         | 
| 53 | 
            +
                info = {"md": "None"}
         | 
| 54 | 
            +
                try:
         | 
| 55 | 
            +
                    if not is_repo_name(repo_id) or not api.repo_exists(repo_id=repo_id): return info
         | 
| 56 | 
            +
                    model = api.model_info(repo_id=repo_id)
         | 
| 57 | 
            +
                except Exception as e:
         | 
| 58 | 
            +
                    print(f"Error: Failed to get {repo_id}'s info.")
         | 
| 59 | 
            +
                    print(e)
         | 
| 60 | 
            +
                    return info
         | 
| 61 | 
            +
                if model.private or model.gated: return info
         | 
| 62 | 
            +
                try:
         | 
| 63 | 
            +
                    tags = model.tags
         | 
| 64 | 
            +
                except Exception as e:
         | 
| 65 | 
            +
                    print(e)
         | 
| 66 | 
            +
                    return info
         | 
| 67 | 
            +
                if not 'diffusers' in model.tags: return info
         | 
| 68 | 
            +
                if 'diffusers:StableDiffusionXLPipeline' in tags: info["ver"] = "SDXL"
         | 
| 69 | 
            +
                elif 'diffusers:StableDiffusionPipeline' in tags: info["ver"] = "SD1.5"
         | 
| 70 | 
            +
                elif 'diffusers:StableDiffusion3Pipeline' in tags: info["ver"] = "SD3"
         | 
| 71 | 
            +
                else: info["ver"] = "Other"
         | 
| 72 | 
            +
                info["url"] = f"https://huggingface.co/{repo_id}/"
         | 
| 73 | 
            +
                if model.card_data and model.card_data.tags:
         | 
| 74 | 
            +
                    info["tags"] = model.card_data.tags
         | 
| 75 | 
            +
                info["downloads"] = model.downloads
         | 
| 76 | 
            +
                info["likes"] = model.likes
         | 
| 77 | 
            +
                info["last_modified"] = model.last_modified.strftime("lastmod: %Y-%m-%d")
         | 
| 78 | 
            +
                un_tags = ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']
         | 
| 79 | 
            +
                descs = [info["ver"]] + list_sub(info["tags"], un_tags) + [f'DLs: {info["downloads"]}'] + [f'❤: {info["likes"]}'] + [info["last_modified"]]
         | 
| 80 | 
            +
                info["md"] = f'Model Info: {", ".join(descs)} [Model Repo]({info["url"]})'
         | 
| 81 | 
            +
                return info
         | 
| 82 | 
            +
             | 
| 83 | 
            +
             | 
| 84 | 
            +
            def save_gallery_images(images, progress=gr.Progress(track_tqdm=True)):
         | 
| 85 | 
            +
                from datetime import datetime, timezone, timedelta
         | 
| 86 | 
            +
                progress(0, desc="Updating gallery...")
         | 
| 87 | 
            +
                dt_now = datetime.now(timezone(timedelta(hours=9)))
         | 
| 88 | 
            +
                basename = dt_now.strftime('%Y%m%d_%H%M%S_')
         | 
| 89 | 
            +
                i = 1
         | 
| 90 | 
            +
                if not images: return images
         | 
| 91 | 
            +
                output_images = []
         | 
| 92 | 
            +
                output_paths = []
         | 
| 93 | 
            +
                for image in images:
         | 
| 94 | 
            +
                    filename = f'{image[1]}_{basename}{str(i)}.png'
         | 
| 95 | 
            +
                    i += 1
         | 
| 96 | 
            +
                    oldpath = Path(image[0])
         | 
| 97 | 
            +
                    newpath = oldpath
         | 
| 98 | 
            +
                    try:
         | 
| 99 | 
            +
                        if oldpath.stem == "image" and oldpath.exists():
         | 
| 100 | 
            +
                            newpath = oldpath.resolve().rename(Path(filename).resolve())
         | 
| 101 | 
            +
                    except Exception as e:
         | 
| 102 | 
            +
                       print(e)
         | 
| 103 | 
            +
                       pass
         | 
| 104 | 
            +
                    finally:
         | 
| 105 | 
            +
                        output_paths.append(str(newpath))
         | 
| 106 | 
            +
                        output_images.append((str(newpath), str(filename)))
         | 
| 107 | 
            +
                progress(1, desc="Gallery updated.")
         | 
| 108 | 
            +
                return gr.update(value=output_images), gr.update(value=output_paths)
         | 
| 109 | 
            +
             | 
| 110 | 
            +
             | 
| 111 | 
            +
            def load_model(model_name: str):
         | 
| 112 | 
            +
                global loaded_models
         | 
| 113 | 
            +
                global model_info_dict
         | 
| 114 | 
            +
                if model_name in loaded_models.keys(): return loaded_models[model_name]
         | 
| 115 | 
            +
                try:
         | 
| 116 | 
            +
                    with lock:
         | 
| 117 | 
            +
                        loaded_models[model_name] = gr.load(f'models/{model_name}')
         | 
| 118 | 
            +
                    print(f"Loaded: {model_name}")
         | 
| 119 | 
            +
                except Exception as e:
         | 
| 120 | 
            +
                    with lock:
         | 
| 121 | 
            +
                        if model_name in loaded_models.keys(): del loaded_models[model_name]
         | 
| 122 | 
            +
                    print(f"Failed to load: {model_name}")
         | 
| 123 | 
            +
                    print(e)
         | 
| 124 | 
            +
                    return None
         | 
| 125 | 
            +
                try:
         | 
| 126 | 
            +
                    with lock:
         | 
| 127 | 
            +
                        model_info_dict[model_name] = get_t2i_model_info_dict(model_name)
         | 
| 128 | 
            +
                except Exception as e:
         | 
| 129 | 
            +
                    with lock:
         | 
| 130 | 
            +
                        if model_name in model_info_dict.keys(): del model_info_dict[model_name]
         | 
| 131 | 
            +
                    print(e)
         | 
| 132 | 
            +
                return loaded_models[model_name]
         | 
| 133 | 
            +
             | 
| 134 | 
            +
             | 
| 135 | 
            +
            async def async_load_models(models: list, limit: int=5, wait=10):
         | 
| 136 | 
            +
                sem = asyncio.Semaphore(limit)
         | 
| 137 | 
            +
                async def async_load_model(model: str):
         | 
| 138 | 
            +
                    async with sem:
         | 
| 139 | 
            +
                       try:
         | 
| 140 | 
            +
                           return await asyncio.to_thread(load_model, model)
         | 
| 141 | 
            +
                       except Exception as e:
         | 
| 142 | 
            +
                           print(e)
         | 
| 143 | 
            +
                tasks = [asyncio.create_task(async_load_model(model)) for model in models]
         | 
| 144 | 
            +
                return await asyncio.gather(*tasks, return_exceptions=True)
         | 
| 145 | 
            +
             | 
| 146 | 
            +
             | 
| 147 | 
            +
            def load_models(models: list, limit: int=5):
         | 
| 148 | 
            +
                loop = asyncio.new_event_loop()
         | 
| 149 | 
            +
                try:
         | 
| 150 | 
            +
                    loop.run_until_complete(async_load_models(models, limit))
         | 
| 151 | 
            +
                except Exception as e:
         | 
| 152 | 
            +
                    print(e)
         | 
| 153 | 
            +
                    pass
         | 
| 154 | 
            +
                finally:
         | 
| 155 | 
            +
                    loop.close()
         | 
| 156 | 
            +
             | 
| 157 | 
            +
             | 
| 158 | 
            +
            positive_prefix = {
         | 
| 159 | 
            +
                "Pony": to_list("score_9, score_8_up, score_7_up"),
         | 
| 160 | 
            +
                "Pony Anime": to_list("source_anime, anime, score_9, score_8_up, score_7_up"),
         | 
| 161 | 
            +
            }
         | 
| 162 | 
            +
            positive_suffix = {
         | 
| 163 | 
            +
                "Common": to_list("highly detailed, masterpiece, best quality, very aesthetic, absurdres"),
         | 
| 164 | 
            +
                "Anime": to_list("anime artwork, anime style, studio anime, highly detailed"),
         | 
| 165 | 
            +
            }
         | 
| 166 | 
            +
            negative_prefix = {
         | 
| 167 | 
            +
                "Pony": to_list("score_6, score_5, score_4"),
         | 
| 168 | 
            +
                "Pony Anime": to_list("score_6, score_5, score_4, source_pony, source_furry, source_cartoon"),
         | 
| 169 | 
            +
                "Pony Real": to_list("score_6, score_5, score_4, source_anime, source_pony, source_furry, source_cartoon"),
         | 
| 170 | 
            +
            }
         | 
| 171 | 
            +
            negative_suffix = {
         | 
| 172 | 
            +
                "Common": to_list("lowres, (bad), bad hands, bad feet, text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]"),
         | 
| 173 | 
            +
                "Pony Anime": to_list("busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends"),
         | 
| 174 | 
            +
                "Pony Real": to_list("ugly, airbrushed, simple background, cgi, cartoon, anime"),
         | 
| 175 | 
            +
            }
         | 
| 176 | 
            +
            positive_all = negative_all = []
         | 
| 177 | 
            +
            for k, v in (positive_prefix | positive_suffix).items():
         | 
| 178 | 
            +
                positive_all = positive_all + v + [s.replace("_", " ") for s in v]
         | 
| 179 | 
            +
            positive_all = list_uniq(positive_all)
         | 
| 180 | 
            +
            for k, v in (negative_prefix | negative_suffix).items():
         | 
| 181 | 
            +
                negative_all = negative_all + v + [s.replace("_", " ") for s in v]
         | 
| 182 | 
            +
            positive_all = list_uniq(positive_all)
         | 
| 183 | 
            +
             | 
| 184 | 
            +
             | 
| 185 | 
            +
            def recom_prompt(prompt: str = "", neg_prompt: str = "", pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = []):
         | 
| 186 | 
            +
                def flatten(src):
         | 
| 187 | 
            +
                    return [item for row in src for item in row]
         | 
| 188 | 
            +
                prompts = to_list(prompt)
         | 
| 189 | 
            +
                neg_prompts = to_list(neg_prompt)
         | 
| 190 | 
            +
                prompts = list_sub(prompts, positive_all)
         | 
| 191 | 
            +
                neg_prompts = list_sub(neg_prompts, negative_all)
         | 
| 192 | 
            +
                last_empty_p = [""] if not prompts and type != "None" else []
         | 
| 193 | 
            +
                last_empty_np = [""] if not neg_prompts and type != "None" else []
         | 
| 194 | 
            +
                prefix_ps = flatten([positive_prefix.get(s, []) for s in pos_pre])
         | 
| 195 | 
            +
                suffix_ps = flatten([positive_suffix.get(s, []) for s in pos_suf])
         | 
| 196 | 
            +
                prefix_nps = flatten([negative_prefix.get(s, []) for s in neg_pre])
         | 
| 197 | 
            +
                suffix_nps = flatten([negative_suffix.get(s, []) for s in neg_suf])
         | 
| 198 | 
            +
                prompt = ", ".join(list_uniq(prefix_ps + prompts + suffix_ps) + last_empty_p)
         | 
| 199 | 
            +
                neg_prompt = ", ".join(list_uniq(prefix_nps + neg_prompts + suffix_nps) + last_empty_np)
         | 
| 200 | 
            +
                return prompt, neg_prompt
         | 
| 201 | 
            +
             | 
| 202 | 
            +
             | 
| 203 | 
            +
            recom_prompt_type = {
         | 
| 204 | 
            +
                "None": ([], [], [], []),
         | 
| 205 | 
            +
                "Auto": ([], [], [], []),
         | 
| 206 | 
            +
                "Common": ([], ["Common"], [], ["Common"]),
         | 
| 207 | 
            +
                "Animagine": ([], ["Common", "Anime"], [], ["Common"]),
         | 
| 208 | 
            +
                "Pony": (["Pony"], ["Common"], ["Pony"], ["Common"]),
         | 
| 209 | 
            +
                "Pony Anime": (["Pony", "Pony Anime"], ["Common", "Anime"], ["Pony", "Pony Anime"], ["Common", "Pony Anime"]),
         | 
| 210 | 
            +
                "Pony Real": (["Pony"], ["Common"], ["Pony", "Pony Real"], ["Common", "Pony Real"]),
         | 
| 211 | 
            +
            }
         | 
| 212 | 
            +
             | 
| 213 | 
            +
             | 
| 214 | 
            +
            enable_auto_recom_prompt = False
         | 
| 215 | 
            +
            def insert_recom_prompt(prompt: str = "", neg_prompt: str = "", type: str = "None"):
         | 
| 216 | 
            +
                global enable_auto_recom_prompt
         | 
| 217 | 
            +
                if type == "Auto":  enable_auto_recom_prompt = True
         | 
| 218 | 
            +
                else: enable_auto_recom_prompt = False
         | 
| 219 | 
            +
                pos_pre, pos_suf, neg_pre, neg_suf = recom_prompt_type.get(type, ([], [], [], []))
         | 
| 220 | 
            +
                return recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
         | 
| 221 | 
            +
             | 
| 222 | 
            +
             | 
| 223 | 
            +
            def set_recom_prompt_preset(type: str = "None"):
         | 
| 224 | 
            +
                pos_pre, pos_suf, neg_pre, neg_suf = recom_prompt_type.get(type, ([], [], [], []))
         | 
| 225 | 
            +
                return pos_pre, pos_suf, neg_pre, neg_suf
         | 
| 226 | 
            +
             | 
| 227 | 
            +
             | 
| 228 | 
            +
            def get_recom_prompt_type():
         | 
| 229 | 
            +
                type = list(recom_prompt_type.keys())
         | 
| 230 | 
            +
                type.remove("Auto")
         | 
| 231 | 
            +
                return type
         | 
| 232 | 
            +
             | 
| 233 | 
            +
             | 
| 234 | 
            +
            def get_positive_prefix():
         | 
| 235 | 
            +
                return list(positive_prefix.keys())
         | 
| 236 | 
            +
             | 
| 237 | 
            +
             | 
| 238 | 
            +
            def get_positive_suffix():
         | 
| 239 | 
            +
                return list(positive_suffix.keys())
         | 
| 240 | 
            +
             | 
| 241 | 
            +
             | 
| 242 | 
            +
            def get_negative_prefix():
         | 
| 243 | 
            +
                return list(negative_prefix.keys())
         | 
| 244 | 
            +
             | 
| 245 | 
            +
             | 
| 246 | 
            +
            def get_negative_suffix():
         | 
| 247 | 
            +
                return list(negative_suffix.keys())
         | 
| 248 | 
            +
             | 
| 249 | 
            +
             | 
| 250 | 
            +
            def get_model_info_md(model_name: str):
         | 
| 251 | 
            +
                if model_name in model_info_dict.keys(): return model_info_dict[model_name].get("md", "")
         | 
| 252 | 
            +
             | 
| 253 | 
            +
             | 
| 254 | 
            +
            def change_model(model_name: str):
         | 
| 255 | 
            +
                load_model(model_name)
         | 
| 256 | 
            +
                return get_model_info_md(model_name)
         | 
| 257 | 
            +
             | 
| 258 | 
            +
             | 
| 259 | 
            +
            def infer(prompt: str, neg_prompt: str, model_name: str):
         | 
| 260 | 
            +
                from PIL import Image
         | 
| 261 | 
            +
                import random
         | 
| 262 | 
            +
                seed = ""
         | 
| 263 | 
            +
                rand = random.randint(1, 500)
         | 
| 264 | 
            +
                for i in range(rand):
         | 
| 265 | 
            +
                    seed += " "
         | 
| 266 | 
            +
                caption = model_name.split("/")[-1]
         | 
| 267 | 
            +
                try:
         | 
| 268 | 
            +
                    model = load_model(model_name)
         | 
| 269 | 
            +
                    if not model: return (Image.Image(), None)
         | 
| 270 | 
            +
                    image_path = model(prompt + seed)
         | 
| 271 | 
            +
                    image = Image.open(image_path).convert('RGBA')
         | 
| 272 | 
            +
                except Exception as e:
         | 
| 273 | 
            +
                    print(e)
         | 
| 274 | 
            +
                    return (Image.Image(), None)
         | 
| 275 | 
            +
                return (image, caption)
         | 
| 276 | 
            +
             | 
| 277 | 
            +
             | 
| 278 | 
            +
            async def infer_multi(prompt: str, neg_prompt: str, results: list, image_num: float, model_name: str,
         | 
| 279 | 
            +
                             pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = [], progress=gr.Progress(track_tqdm=True)):
         | 
| 280 | 
            +
                #from tqdm.asyncio import tqdm_asyncio
         | 
| 281 | 
            +
                image_num = int(image_num)
         | 
| 282 | 
            +
                images = results if results else []
         | 
| 283 | 
            +
                prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
         | 
| 284 | 
            +
                tasks = [asyncio.to_thread(infer, prompt, neg_prompt, model_name) for i in range(image_num)]
         | 
| 285 | 
            +
                results = await asyncio.gather(*tasks, return_exceptions=True)
         | 
| 286 | 
            +
                #results = await tqdm_asyncio.gather(*tasks)
         | 
| 287 | 
            +
                if not results: results = []
         | 
| 288 | 
            +
                for result in results:
         | 
| 289 | 
            +
                    with lock:
         | 
| 290 | 
            +
                        if result and result[1]: images.append(result)
         | 
| 291 | 
            +
                    yield images
         | 
| 292 | 
            +
             | 
| 293 | 
            +
             | 
| 294 | 
            +
            async def infer_multi_random(prompt: str, neg_prompt: str, results: list, image_num: float, 
         | 
| 295 | 
            +
                             pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = [], progress=gr.Progress(track_tqdm=True)):
         | 
| 296 | 
            +
                #from tqdm.asyncio import tqdm_asyncio
         | 
| 297 | 
            +
                import random
         | 
| 298 | 
            +
                image_num = int(image_num)
         | 
| 299 | 
            +
                images = results if results else []
         | 
| 300 | 
            +
                random.seed()
         | 
| 301 | 
            +
                model_names = random.choices(list(loaded_models.keys()), k = image_num)
         | 
| 302 | 
            +
                prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
         | 
| 303 | 
            +
                tasks = [asyncio.to_thread(infer, prompt, neg_prompt, model_name) for model_name in model_names]
         | 
| 304 | 
            +
                results = await asyncio.gather(*tasks, return_exceptions=True)
         | 
| 305 | 
            +
                #await tqdm_asyncio.gather(*tasks)
         | 
| 306 | 
            +
                if not results: results = []
         | 
| 307 | 
            +
                for result in results:
         | 
| 308 | 
            +
                    with lock:
         | 
| 309 | 
            +
                        if result and result[1]: images.append(result)
         | 
| 310 | 
            +
                    yield images
         | 
| 311 | 
            +
             | 
    	
        requirements.txt
    ADDED
    
    | @@ -0,0 +1,12 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            huggingface_hub
         | 
| 2 | 
            +
            torch
         | 
| 3 | 
            +
            torchvision
         | 
| 4 | 
            +
            accelerate
         | 
| 5 | 
            +
            transformers
         | 
| 6 | 
            +
            optimum[onnxruntime]
         | 
| 7 | 
            +
            spaces
         | 
| 8 | 
            +
            dartrs
         | 
| 9 | 
            +
            httpx==0.13.3
         | 
| 10 | 
            +
            httpcore
         | 
| 11 | 
            +
            googletrans==4.0.0rc1
         | 
| 12 | 
            +
            timm
         | 
    	
        tagger/character_series_dict.csv
    ADDED
    
    | The diff for this file is too large to render. 
		See raw diff | 
|  | 
    	
        tagger/danbooru_e621.csv
    ADDED
    
    | The diff for this file is too large to render. 
		See raw diff | 
|  | 
    	
        tagger/fl2sd3longcap.py
    ADDED
    
    | @@ -0,0 +1,74 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from transformers import AutoProcessor, AutoModelForCausalLM
         | 
| 2 | 
            +
            import spaces
         | 
| 3 | 
            +
            import re
         | 
| 4 | 
            +
            from PIL import Image 
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import subprocess
         | 
| 7 | 
            +
            subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            fl_model = AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True).eval()
         | 
| 10 | 
            +
            fl_processor = AutoProcessor.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True)
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            def fl_modify_caption(caption: str) -> str:
         | 
| 14 | 
            +
                """
         | 
| 15 | 
            +
                Removes specific prefixes from captions if present, otherwise returns the original caption.
         | 
| 16 | 
            +
                Args:
         | 
| 17 | 
            +
                    caption (str): A string containing a caption.
         | 
| 18 | 
            +
                Returns:
         | 
| 19 | 
            +
                    str: The caption with the prefix removed if it was present, or the original caption.
         | 
| 20 | 
            +
                """
         | 
| 21 | 
            +
                # Define the prefixes to remove
         | 
| 22 | 
            +
                prefix_substrings = [
         | 
| 23 | 
            +
                    ('captured from ', ''),
         | 
| 24 | 
            +
                    ('captured at ', '')
         | 
| 25 | 
            +
                ]
         | 
| 26 | 
            +
                
         | 
| 27 | 
            +
                # Create a regex pattern to match any of the prefixes
         | 
| 28 | 
            +
                pattern = '|'.join([re.escape(opening) for opening, _ in prefix_substrings])
         | 
| 29 | 
            +
                replacers = {opening.lower(): replacer for opening, replacer in prefix_substrings}
         | 
| 30 | 
            +
                
         | 
| 31 | 
            +
                # Function to replace matched prefix with its corresponding replacement
         | 
| 32 | 
            +
                def replace_fn(match):
         | 
| 33 | 
            +
                    return replacers[match.group(0).lower()]
         | 
| 34 | 
            +
                
         | 
| 35 | 
            +
                # Apply the regex to the caption
         | 
| 36 | 
            +
                modified_caption = re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
         | 
| 37 | 
            +
                
         | 
| 38 | 
            +
                # If the caption was modified, return the modified version; otherwise, return the original
         | 
| 39 | 
            +
                return modified_caption if modified_caption != caption else caption
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
            @spaces.GPU
         | 
| 43 | 
            +
            def fl_run_example(image):
         | 
| 44 | 
            +
                task_prompt = "<DESCRIPTION>"
         | 
| 45 | 
            +
                prompt = task_prompt + "Describe this image in great detail."
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                # Ensure the image is in RGB mode
         | 
| 48 | 
            +
                if image.mode != "RGB":
         | 
| 49 | 
            +
                    image = image.convert("RGB")
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                inputs = fl_processor(text=prompt, images=image, return_tensors="pt")
         | 
| 52 | 
            +
                generated_ids = fl_model.generate(
         | 
| 53 | 
            +
                    input_ids=inputs["input_ids"],
         | 
| 54 | 
            +
                    pixel_values=inputs["pixel_values"],
         | 
| 55 | 
            +
                    max_new_tokens=1024,
         | 
| 56 | 
            +
                    num_beams=3
         | 
| 57 | 
            +
                )
         | 
| 58 | 
            +
                generated_text = fl_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
         | 
| 59 | 
            +
                parsed_answer = fl_processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
         | 
| 60 | 
            +
                return fl_modify_caption(parsed_answer["<DESCRIPTION>"])
         | 
| 61 | 
            +
             | 
| 62 | 
            +
             | 
| 63 | 
            +
            def predict_tags_fl2_sd3(image: Image.Image, input_tags: str, algo: list[str]):
         | 
| 64 | 
            +
                def to_list(s):
         | 
| 65 | 
            +
                    return [x.strip() for x in s.split(",") if not s == ""]
         | 
| 66 | 
            +
                
         | 
| 67 | 
            +
                def list_uniq(l):
         | 
| 68 | 
            +
                    return sorted(set(l), key=l.index)
         | 
| 69 | 
            +
                
         | 
| 70 | 
            +
                if not "Use Florence-2-SD3-Long-Captioner" in algo:
         | 
| 71 | 
            +
                    return input_tags
         | 
| 72 | 
            +
                tag_list = list_uniq(to_list(input_tags) + to_list(fl_run_example(image) + ", "))
         | 
| 73 | 
            +
                tag_list.remove("")
         | 
| 74 | 
            +
                return ", ".join(tag_list)
         | 
    	
        tagger/output.py
    ADDED
    
    | @@ -0,0 +1,16 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from dataclasses import dataclass
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            @dataclass
         | 
| 5 | 
            +
            class UpsamplingOutput:
         | 
| 6 | 
            +
                upsampled_tags: str
         | 
| 7 | 
            +
             | 
| 8 | 
            +
                copyright_tags: str
         | 
| 9 | 
            +
                character_tags: str
         | 
| 10 | 
            +
                general_tags: str
         | 
| 11 | 
            +
                rating_tag: str
         | 
| 12 | 
            +
                aspect_ratio_tag: str
         | 
| 13 | 
            +
                length_tag: str
         | 
| 14 | 
            +
                identity_tag: str
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                elapsed_time: float = 0.0
         | 
    	
        tagger/tag_group.csv
    ADDED
    
    | The diff for this file is too large to render. 
		See raw diff | 
|  | 
    	
        tagger/tagger.py
    ADDED
    
    | @@ -0,0 +1,546 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from PIL import Image
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            import gradio as gr
         | 
| 4 | 
            +
            import spaces
         | 
| 5 | 
            +
            from transformers import (
         | 
| 6 | 
            +
                AutoImageProcessor,
         | 
| 7 | 
            +
                AutoModelForImageClassification,
         | 
| 8 | 
            +
            )
         | 
| 9 | 
            +
            from pathlib import Path
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
         | 
| 13 | 
            +
            WD_MODEL_NAME = WD_MODEL_NAMES[0]
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
         | 
| 16 | 
            +
            wd_model.to("cuda" if torch.cuda.is_available() else "cpu")
         | 
| 17 | 
            +
            wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
         | 
| 21 | 
            +
                return (
         | 
| 22 | 
            +
                    [f"1{noun}"]
         | 
| 23 | 
            +
                    + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
         | 
| 24 | 
            +
                    + [f"{maximum+1}+{noun}s"]
         | 
| 25 | 
            +
                )
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            PEOPLE_TAGS = (
         | 
| 29 | 
            +
                _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
         | 
| 30 | 
            +
            )
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            RATING_MAP = {
         | 
| 34 | 
            +
                "general": "safe",
         | 
| 35 | 
            +
                "sensitive": "sensitive",
         | 
| 36 | 
            +
                "questionable": "nsfw",
         | 
| 37 | 
            +
                "explicit": "explicit, nsfw",
         | 
| 38 | 
            +
            }
         | 
| 39 | 
            +
            DANBOORU_TO_E621_RATING_MAP = {
         | 
| 40 | 
            +
                "safe": "rating_safe",
         | 
| 41 | 
            +
                "sensitive": "rating_safe",
         | 
| 42 | 
            +
                "nsfw": "rating_explicit",
         | 
| 43 | 
            +
                "explicit, nsfw": "rating_explicit",
         | 
| 44 | 
            +
                "explicit": "rating_explicit",
         | 
| 45 | 
            +
                "rating:safe": "rating_safe",
         | 
| 46 | 
            +
                "rating:general": "rating_safe",
         | 
| 47 | 
            +
                "rating:sensitive": "rating_safe",
         | 
| 48 | 
            +
                "rating:questionable, nsfw": "rating_explicit",
         | 
| 49 | 
            +
                "rating:explicit, nsfw": "rating_explicit",
         | 
| 50 | 
            +
            }
         | 
| 51 | 
            +
             | 
| 52 | 
            +
             | 
| 53 | 
            +
            # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
         | 
| 54 | 
            +
            kaomojis = [
         | 
| 55 | 
            +
                "0_0",
         | 
| 56 | 
            +
                "(o)_(o)",
         | 
| 57 | 
            +
                "+_+",
         | 
| 58 | 
            +
                "+_-",
         | 
| 59 | 
            +
                "._.",
         | 
| 60 | 
            +
                "<o>_<o>",
         | 
| 61 | 
            +
                "<|>_<|>",
         | 
| 62 | 
            +
                "=_=",
         | 
| 63 | 
            +
                ">_<",
         | 
| 64 | 
            +
                "3_3",
         | 
| 65 | 
            +
                "6_9",
         | 
| 66 | 
            +
                ">_o",
         | 
| 67 | 
            +
                "@_@",
         | 
| 68 | 
            +
                "^_^",
         | 
| 69 | 
            +
                "o_o",
         | 
| 70 | 
            +
                "u_u",
         | 
| 71 | 
            +
                "x_x",
         | 
| 72 | 
            +
                "|_|",
         | 
| 73 | 
            +
                "||_||",
         | 
| 74 | 
            +
            ]
         | 
| 75 | 
            +
             | 
| 76 | 
            +
             | 
| 77 | 
            +
            def replace_underline(x: str):
         | 
| 78 | 
            +
                return x.strip().replace("_", " ") if x not in kaomojis else x.strip()
         | 
| 79 | 
            +
             | 
| 80 | 
            +
             | 
| 81 | 
            +
            def to_list(s):
         | 
| 82 | 
            +
                return [x.strip() for x in s.split(",") if not s == ""]
         | 
| 83 | 
            +
             | 
| 84 | 
            +
             | 
| 85 | 
            +
            def list_sub(a, b):
         | 
| 86 | 
            +
                return [e for e in a if e not in b]
         | 
| 87 | 
            +
             | 
| 88 | 
            +
             | 
| 89 | 
            +
            def list_uniq(l):
         | 
| 90 | 
            +
                return sorted(set(l), key=l.index)
         | 
| 91 | 
            +
             | 
| 92 | 
            +
             | 
| 93 | 
            +
            def load_dict_from_csv(filename):
         | 
| 94 | 
            +
                dict = {}
         | 
| 95 | 
            +
                if not Path(filename).exists():
         | 
| 96 | 
            +
                    if Path('./tagger/', filename).exists(): filename = str(Path('./tagger/', filename))
         | 
| 97 | 
            +
                    else: return dict
         | 
| 98 | 
            +
                try:
         | 
| 99 | 
            +
                    with open(filename, 'r', encoding="utf-8") as f:
         | 
| 100 | 
            +
                        lines = f.readlines()
         | 
| 101 | 
            +
                except Exception:
         | 
| 102 | 
            +
                    print(f"Failed to open dictionary file: {filename}")
         | 
| 103 | 
            +
                    return dict
         | 
| 104 | 
            +
                for line in lines:
         | 
| 105 | 
            +
                    parts = line.strip().split(',')
         | 
| 106 | 
            +
                    dict[parts[0]] = parts[1]
         | 
| 107 | 
            +
                return dict
         | 
| 108 | 
            +
             | 
| 109 | 
            +
             | 
| 110 | 
            +
            anime_series_dict = load_dict_from_csv('character_series_dict.csv')
         | 
| 111 | 
            +
             | 
| 112 | 
            +
             | 
| 113 | 
            +
            def character_list_to_series_list(character_list):
         | 
| 114 | 
            +
                output_series_tag = []
         | 
| 115 | 
            +
                series_tag = ""
         | 
| 116 | 
            +
                series_dict = anime_series_dict
         | 
| 117 | 
            +
                for tag in character_list:
         | 
| 118 | 
            +
                    series_tag = series_dict.get(tag, "")
         | 
| 119 | 
            +
                    if tag.endswith(")"):
         | 
| 120 | 
            +
                        tags = tag.split("(")
         | 
| 121 | 
            +
                        character_tag = "(".join(tags[:-1])
         | 
| 122 | 
            +
                        if character_tag.endswith(" "):
         | 
| 123 | 
            +
                            character_tag = character_tag[:-1]
         | 
| 124 | 
            +
                        series_tag = tags[-1].replace(")", "")
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                if series_tag:
         | 
| 127 | 
            +
                    output_series_tag.append(series_tag)
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                return output_series_tag
         | 
| 130 | 
            +
             | 
| 131 | 
            +
             | 
| 132 | 
            +
            def select_random_character(series: str, character: str):
         | 
| 133 | 
            +
                from random import seed, randrange
         | 
| 134 | 
            +
                seed()
         | 
| 135 | 
            +
                character_list = list(anime_series_dict.keys())
         | 
| 136 | 
            +
                character = character_list[randrange(len(character_list) - 1)]
         | 
| 137 | 
            +
                series = anime_series_dict.get(character.split(",")[0].strip(), "")
         | 
| 138 | 
            +
                return series, character
         | 
| 139 | 
            +
             | 
| 140 | 
            +
             | 
| 141 | 
            +
            def danbooru_to_e621(dtag, e621_dict):
         | 
| 142 | 
            +
                def d_to_e(match, e621_dict):
         | 
| 143 | 
            +
                    dtag = match.group(0)
         | 
| 144 | 
            +
                    etag = e621_dict.get(replace_underline(dtag), "")
         | 
| 145 | 
            +
                    if etag:
         | 
| 146 | 
            +
                        return etag
         | 
| 147 | 
            +
                    else:
         | 
| 148 | 
            +
                        return dtag
         | 
| 149 | 
            +
                
         | 
| 150 | 
            +
                import re
         | 
| 151 | 
            +
                tag = re.sub(r'[\w ]+', lambda wrapper: d_to_e(wrapper, e621_dict), dtag, 2)
         | 
| 152 | 
            +
                return tag
         | 
| 153 | 
            +
             | 
| 154 | 
            +
             | 
| 155 | 
            +
            danbooru_to_e621_dict = load_dict_from_csv('danbooru_e621.csv')
         | 
| 156 | 
            +
             | 
| 157 | 
            +
             | 
| 158 | 
            +
            def convert_danbooru_to_e621_prompt(input_prompt: str = "", prompt_type: str = "danbooru"):
         | 
| 159 | 
            +
                if prompt_type == "danbooru": return input_prompt
         | 
| 160 | 
            +
                tags = input_prompt.split(",") if input_prompt else []
         | 
| 161 | 
            +
                people_tags: list[str] = []
         | 
| 162 | 
            +
                other_tags: list[str] = []
         | 
| 163 | 
            +
                rating_tags: list[str] = []
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                e621_dict = danbooru_to_e621_dict
         | 
| 166 | 
            +
                for tag in tags:
         | 
| 167 | 
            +
                    tag = replace_underline(tag)
         | 
| 168 | 
            +
                    tag = danbooru_to_e621(tag, e621_dict)
         | 
| 169 | 
            +
                    if tag in PEOPLE_TAGS:        
         | 
| 170 | 
            +
                        people_tags.append(tag)
         | 
| 171 | 
            +
                    elif tag in DANBOORU_TO_E621_RATING_MAP.keys():
         | 
| 172 | 
            +
                        rating_tags.append(DANBOORU_TO_E621_RATING_MAP.get(tag.replace(" ",""), ""))            
         | 
| 173 | 
            +
                    else:
         | 
| 174 | 
            +
                        other_tags.append(tag)
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                rating_tags = sorted(set(rating_tags), key=rating_tags.index)
         | 
| 177 | 
            +
                rating_tags = [rating_tags[0]] if rating_tags else []
         | 
| 178 | 
            +
                rating_tags = ["explicit, nsfw"] if rating_tags and rating_tags[0] == "explicit" else rating_tags
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                output_prompt = ", ".join(people_tags + other_tags + rating_tags)
         | 
| 181 | 
            +
                
         | 
| 182 | 
            +
                return output_prompt
         | 
| 183 | 
            +
             | 
| 184 | 
            +
             | 
| 185 | 
            +
            def translate_prompt(prompt: str = ""):
         | 
| 186 | 
            +
                def translate_to_english(prompt):
         | 
| 187 | 
            +
                    import httpcore
         | 
| 188 | 
            +
                    setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')
         | 
| 189 | 
            +
                    from googletrans import Translator
         | 
| 190 | 
            +
                    translator = Translator()
         | 
| 191 | 
            +
                    try:
         | 
| 192 | 
            +
                        translated_prompt = translator.translate(prompt, src='auto', dest='en').text
         | 
| 193 | 
            +
                        return translated_prompt
         | 
| 194 | 
            +
                    except Exception as e:
         | 
| 195 | 
            +
                        print(e)
         | 
| 196 | 
            +
                        return prompt
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                def is_japanese(s):
         | 
| 199 | 
            +
                    import unicodedata
         | 
| 200 | 
            +
                    for ch in s:
         | 
| 201 | 
            +
                        name = unicodedata.name(ch, "") 
         | 
| 202 | 
            +
                        if "CJK UNIFIED" in name or "HIRAGANA" in name or "KATAKANA" in name:
         | 
| 203 | 
            +
                            return True
         | 
| 204 | 
            +
                    return False
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                def to_list(s):
         | 
| 207 | 
            +
                    return [x.strip() for x in s.split(",")]
         | 
| 208 | 
            +
                
         | 
| 209 | 
            +
                prompts = to_list(prompt)
         | 
| 210 | 
            +
                outputs = []
         | 
| 211 | 
            +
                for p in prompts:
         | 
| 212 | 
            +
                    p = translate_to_english(p) if is_japanese(p) else p
         | 
| 213 | 
            +
                    outputs.append(p)
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                return ", ".join(outputs)
         | 
| 216 | 
            +
             | 
| 217 | 
            +
             | 
| 218 | 
            +
            def translate_prompt_to_ja(prompt: str = ""):
         | 
| 219 | 
            +
                def translate_to_japanese(prompt):
         | 
| 220 | 
            +
                    import httpcore
         | 
| 221 | 
            +
                    setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')
         | 
| 222 | 
            +
                    from googletrans import Translator
         | 
| 223 | 
            +
                    translator = Translator()
         | 
| 224 | 
            +
                    try:
         | 
| 225 | 
            +
                        translated_prompt = translator.translate(prompt, src='en', dest='ja').text
         | 
| 226 | 
            +
                        return translated_prompt
         | 
| 227 | 
            +
                    except Exception as e:
         | 
| 228 | 
            +
                        print(e)
         | 
| 229 | 
            +
                        return prompt
         | 
| 230 | 
            +
             | 
| 231 | 
            +
                def is_japanese(s):
         | 
| 232 | 
            +
                    import unicodedata
         | 
| 233 | 
            +
                    for ch in s:
         | 
| 234 | 
            +
                        name = unicodedata.name(ch, "") 
         | 
| 235 | 
            +
                        if "CJK UNIFIED" in name or "HIRAGANA" in name or "KATAKANA" in name:
         | 
| 236 | 
            +
                            return True
         | 
| 237 | 
            +
                    return False
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                def to_list(s):
         | 
| 240 | 
            +
                    return [x.strip() for x in s.split(",")]
         | 
| 241 | 
            +
                
         | 
| 242 | 
            +
                prompts = to_list(prompt)
         | 
| 243 | 
            +
                outputs = []
         | 
| 244 | 
            +
                for p in prompts:
         | 
| 245 | 
            +
                    p = translate_to_japanese(p) if not is_japanese(p) else p
         | 
| 246 | 
            +
                    outputs.append(p)
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                return ", ".join(outputs)
         | 
| 249 | 
            +
             | 
| 250 | 
            +
             | 
| 251 | 
            +
            def tags_to_ja(itag, dict):
         | 
| 252 | 
            +
                def t_to_j(match, dict):
         | 
| 253 | 
            +
                    tag = match.group(0)
         | 
| 254 | 
            +
                    ja = dict.get(replace_underline(tag), "")
         | 
| 255 | 
            +
                    if ja:
         | 
| 256 | 
            +
                        return ja
         | 
| 257 | 
            +
                    else:
         | 
| 258 | 
            +
                        return tag
         | 
| 259 | 
            +
                
         | 
| 260 | 
            +
                import re
         | 
| 261 | 
            +
                tag = re.sub(r'[\w ]+', lambda wrapper: t_to_j(wrapper, dict), itag, 2)
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                return tag
         | 
| 264 | 
            +
             | 
| 265 | 
            +
             | 
| 266 | 
            +
            def convert_tags_to_ja(input_prompt: str = ""):
         | 
| 267 | 
            +
                tags = input_prompt.split(",") if input_prompt else []
         | 
| 268 | 
            +
                out_tags = []
         | 
| 269 | 
            +
             | 
| 270 | 
            +
                tags_to_ja_dict = load_dict_from_csv('all_tags_ja_ext.csv')
         | 
| 271 | 
            +
                dict = tags_to_ja_dict
         | 
| 272 | 
            +
                for tag in tags:
         | 
| 273 | 
            +
                    tag = replace_underline(tag)
         | 
| 274 | 
            +
                    tag = tags_to_ja(tag, dict)
         | 
| 275 | 
            +
                    out_tags.append(tag)
         | 
| 276 | 
            +
                
         | 
| 277 | 
            +
                return ", ".join(out_tags)
         | 
| 278 | 
            +
             | 
| 279 | 
            +
             | 
| 280 | 
            +
            enable_auto_recom_prompt = True
         | 
| 281 | 
            +
             | 
| 282 | 
            +
             | 
| 283 | 
            +
            animagine_ps = to_list("masterpiece, best quality, very aesthetic, absurdres")
         | 
| 284 | 
            +
            animagine_nps = to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
         | 
| 285 | 
            +
            pony_ps = to_list("score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
         | 
| 286 | 
            +
            pony_nps = to_list("source_pony, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
         | 
| 287 | 
            +
            other_ps = to_list("anime artwork, anime style, studio anime, highly detailed, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed")
         | 
| 288 | 
            +
            other_nps = to_list("photo, deformed, black and white, realism, disfigured, low contrast, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly")
         | 
| 289 | 
            +
            default_ps = to_list("highly detailed, masterpiece, best quality, very aesthetic, absurdres")
         | 
| 290 | 
            +
            default_nps = to_list("score_6, score_5, score_4, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
         | 
| 291 | 
            +
            def insert_recom_prompt(prompt: str = "", neg_prompt: str = "", type: str = "None"):
         | 
| 292 | 
            +
                global enable_auto_recom_prompt
         | 
| 293 | 
            +
                prompts = to_list(prompt)
         | 
| 294 | 
            +
                neg_prompts = to_list(neg_prompt)
         | 
| 295 | 
            +
             | 
| 296 | 
            +
                prompts = list_sub(prompts, animagine_ps + pony_ps)
         | 
| 297 | 
            +
                neg_prompts = list_sub(neg_prompts, animagine_nps + pony_nps)
         | 
| 298 | 
            +
             | 
| 299 | 
            +
                last_empty_p = [""] if not prompts and type != "None" else []
         | 
| 300 | 
            +
                last_empty_np = [""] if not neg_prompts and type != "None" else []
         | 
| 301 | 
            +
             | 
| 302 | 
            +
                if type == "Auto":
         | 
| 303 | 
            +
                    enable_auto_recom_prompt = True
         | 
| 304 | 
            +
                else:
         | 
| 305 | 
            +
                    enable_auto_recom_prompt = False
         | 
| 306 | 
            +
                    if type == "Animagine":
         | 
| 307 | 
            +
                        prompts = prompts + animagine_ps
         | 
| 308 | 
            +
                        neg_prompts = neg_prompts + animagine_nps
         | 
| 309 | 
            +
                    elif type == "Pony":
         | 
| 310 | 
            +
                        prompts = prompts + pony_ps
         | 
| 311 | 
            +
                        neg_prompts = neg_prompts + pony_nps
         | 
| 312 | 
            +
             | 
| 313 | 
            +
                prompt = ", ".join(list_uniq(prompts) + last_empty_p)
         | 
| 314 | 
            +
                neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)
         | 
| 315 | 
            +
             | 
| 316 | 
            +
                return prompt, neg_prompt
         | 
| 317 | 
            +
             | 
| 318 | 
            +
             | 
| 319 | 
            +
            def load_model_prompt_dict():
         | 
| 320 | 
            +
                import json
         | 
| 321 | 
            +
                dict = {}
         | 
| 322 | 
            +
                path = 'model_dict.json' if Path('model_dict.json').exists() else './tagger/model_dict.json'
         | 
| 323 | 
            +
                try:
         | 
| 324 | 
            +
                    with open('model_dict.json', encoding='utf-8') as f:
         | 
| 325 | 
            +
                        dict = json.load(f)
         | 
| 326 | 
            +
                except Exception:
         | 
| 327 | 
            +
                    pass
         | 
| 328 | 
            +
                return dict
         | 
| 329 | 
            +
             | 
| 330 | 
            +
             | 
| 331 | 
            +
            model_prompt_dict = load_model_prompt_dict()
         | 
| 332 | 
            +
             | 
| 333 | 
            +
             | 
| 334 | 
            +
            def insert_model_recom_prompt(prompt: str = "", neg_prompt: str = "", model_name: str = "None"):
         | 
| 335 | 
            +
                if not model_name or not enable_auto_recom_prompt: return prompt, neg_prompt
         | 
| 336 | 
            +
                prompts = to_list(prompt)
         | 
| 337 | 
            +
                neg_prompts = to_list(neg_prompt)
         | 
| 338 | 
            +
                prompts = list_sub(prompts, animagine_ps + pony_ps + other_ps)
         | 
| 339 | 
            +
                neg_prompts = list_sub(neg_prompts, animagine_nps + pony_nps + other_nps)
         | 
| 340 | 
            +
                last_empty_p = [""] if not prompts and type != "None" else []
         | 
| 341 | 
            +
                last_empty_np = [""] if not neg_prompts and type != "None" else []
         | 
| 342 | 
            +
                ps = []
         | 
| 343 | 
            +
                nps = []
         | 
| 344 | 
            +
                if model_name in model_prompt_dict.keys(): 
         | 
| 345 | 
            +
                    ps = to_list(model_prompt_dict[model_name]["prompt"])
         | 
| 346 | 
            +
                    nps = to_list(model_prompt_dict[model_name]["negative_prompt"])
         | 
| 347 | 
            +
                else:
         | 
| 348 | 
            +
                    ps = default_ps
         | 
| 349 | 
            +
                    nps = default_nps
         | 
| 350 | 
            +
                prompts = prompts + ps
         | 
| 351 | 
            +
                neg_prompts = neg_prompts + nps
         | 
| 352 | 
            +
                prompt = ", ".join(list_uniq(prompts) + last_empty_p)
         | 
| 353 | 
            +
                neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)
         | 
| 354 | 
            +
                return prompt, neg_prompt
         | 
| 355 | 
            +
             | 
| 356 | 
            +
             | 
| 357 | 
            +
            tag_group_dict = load_dict_from_csv('tag_group.csv')
         | 
| 358 | 
            +
             | 
| 359 | 
            +
             | 
| 360 | 
            +
            def remove_specific_prompt(input_prompt: str = "", keep_tags: str = "all"):
         | 
| 361 | 
            +
                def is_dressed(tag):
         | 
| 362 | 
            +
                    import re
         | 
| 363 | 
            +
                    p = re.compile(r'dress|cloth|uniform|costume|vest|sweater|coat|shirt|jacket|blazer|apron|leotard|hood|sleeve|skirt|shorts|pant|loafer|ribbon|necktie|bow|collar|glove|sock|shoe|boots|wear|emblem')
         | 
| 364 | 
            +
                    return p.search(tag)
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                def is_background(tag):
         | 
| 367 | 
            +
                    import re
         | 
| 368 | 
            +
                    p = re.compile(r'background|outline|light|sky|build|day|screen|tree|city')
         | 
| 369 | 
            +
                    return p.search(tag)
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                un_tags = ['solo']
         | 
| 372 | 
            +
                group_list = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']
         | 
| 373 | 
            +
                keep_group_dict = {
         | 
| 374 | 
            +
                    "body": ['groups', 'body_parts'],
         | 
| 375 | 
            +
                    "dress": ['groups', 'body_parts', 'attire'],
         | 
| 376 | 
            +
                    "all": group_list,
         | 
| 377 | 
            +
                }
         | 
| 378 | 
            +
             | 
| 379 | 
            +
                def is_necessary(tag, keep_tags, group_dict):
         | 
| 380 | 
            +
                    if keep_tags == "all":
         | 
| 381 | 
            +
                        return True
         | 
| 382 | 
            +
                    elif tag in un_tags or group_dict.get(tag, "") in explicit_group:
         | 
| 383 | 
            +
                        return False
         | 
| 384 | 
            +
                    elif keep_tags == "body" and is_dressed(tag):
         | 
| 385 | 
            +
                        return False
         | 
| 386 | 
            +
                    elif is_background(tag):
         | 
| 387 | 
            +
                        return False
         | 
| 388 | 
            +
                    else:
         | 
| 389 | 
            +
                        return True
         | 
| 390 | 
            +
                
         | 
| 391 | 
            +
                if keep_tags == "all": return input_prompt
         | 
| 392 | 
            +
                keep_group = keep_group_dict.get(keep_tags, keep_group_dict["body"])
         | 
| 393 | 
            +
                explicit_group = list(set(group_list) ^ set(keep_group))
         | 
| 394 | 
            +
             | 
| 395 | 
            +
                tags = input_prompt.split(",") if input_prompt else []
         | 
| 396 | 
            +
                people_tags: list[str] = []
         | 
| 397 | 
            +
                other_tags: list[str] = []
         | 
| 398 | 
            +
             | 
| 399 | 
            +
                group_dict = tag_group_dict
         | 
| 400 | 
            +
                for tag in tags:
         | 
| 401 | 
            +
                    tag = replace_underline(tag)
         | 
| 402 | 
            +
                    if tag in PEOPLE_TAGS:
         | 
| 403 | 
            +
                        people_tags.append(tag)
         | 
| 404 | 
            +
                    elif is_necessary(tag, keep_tags, group_dict):
         | 
| 405 | 
            +
                        other_tags.append(tag)
         | 
| 406 | 
            +
             | 
| 407 | 
            +
                output_prompt = ", ".join(people_tags + other_tags)
         | 
| 408 | 
            +
                
         | 
| 409 | 
            +
                return output_prompt
         | 
| 410 | 
            +
             | 
| 411 | 
            +
             | 
| 412 | 
            +
            def sort_taglist(tags: list[str]):
         | 
| 413 | 
            +
                if not tags: return []
         | 
| 414 | 
            +
                character_tags: list[str] = []
         | 
| 415 | 
            +
                series_tags: list[str] = []
         | 
| 416 | 
            +
                people_tags: list[str] = []
         | 
| 417 | 
            +
                group_list = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']
         | 
| 418 | 
            +
                group_tags = {}
         | 
| 419 | 
            +
                other_tags: list[str] = []
         | 
| 420 | 
            +
                rating_tags: list[str] = []
         | 
| 421 | 
            +
             | 
| 422 | 
            +
                group_dict = tag_group_dict
         | 
| 423 | 
            +
                group_set = set(group_dict.keys())
         | 
| 424 | 
            +
                character_set = set(anime_series_dict.keys())
         | 
| 425 | 
            +
                series_set = set(anime_series_dict.values())
         | 
| 426 | 
            +
                rating_set = set(DANBOORU_TO_E621_RATING_MAP.keys()) | set(DANBOORU_TO_E621_RATING_MAP.values())
         | 
| 427 | 
            +
             | 
| 428 | 
            +
                for tag in tags:
         | 
| 429 | 
            +
                    tag = replace_underline(tag)
         | 
| 430 | 
            +
                    if tag in PEOPLE_TAGS:
         | 
| 431 | 
            +
                        people_tags.append(tag)
         | 
| 432 | 
            +
                    elif tag in rating_set:
         | 
| 433 | 
            +
                        rating_tags.append(tag)
         | 
| 434 | 
            +
                    elif tag in group_set:
         | 
| 435 | 
            +
                        elem = group_dict[tag]
         | 
| 436 | 
            +
                        group_tags[elem] = group_tags[elem] + [tag] if elem in group_tags else [tag]
         | 
| 437 | 
            +
                    elif tag in character_set:
         | 
| 438 | 
            +
                        character_tags.append(tag)
         | 
| 439 | 
            +
                    elif tag in series_set:
         | 
| 440 | 
            +
                        series_tags.append(tag)
         | 
| 441 | 
            +
                    else:
         | 
| 442 | 
            +
                        other_tags.append(tag)
         | 
| 443 | 
            +
             | 
| 444 | 
            +
                output_group_tags: list[str] = []
         | 
| 445 | 
            +
                for k in group_list:
         | 
| 446 | 
            +
                    output_group_tags.extend(group_tags.get(k, []))
         | 
| 447 | 
            +
             | 
| 448 | 
            +
                rating_tags = [rating_tags[0]] if rating_tags else []
         | 
| 449 | 
            +
                rating_tags = ["explicit, nsfw"] if rating_tags and rating_tags[0] == "explicit" else rating_tags
         | 
| 450 | 
            +
             | 
| 451 | 
            +
                output_tags = character_tags + series_tags + people_tags + output_group_tags + other_tags + rating_tags
         | 
| 452 | 
            +
                
         | 
| 453 | 
            +
                return output_tags
         | 
| 454 | 
            +
             | 
| 455 | 
            +
             | 
| 456 | 
            +
            def sort_tags(tags: str):
         | 
| 457 | 
            +
                if not tags: return ""
         | 
| 458 | 
            +
                taglist: list[str] = []
         | 
| 459 | 
            +
                for tag in tags.split(","):
         | 
| 460 | 
            +
                    taglist.append(tag.strip())
         | 
| 461 | 
            +
                taglist = list(filter(lambda x: x != "", taglist))
         | 
| 462 | 
            +
                return ", ".join(sort_taglist(taglist))
         | 
| 463 | 
            +
             | 
| 464 | 
            +
             | 
| 465 | 
            +
            def postprocess_results(results: dict[str, float], general_threshold: float, character_threshold: float):
         | 
| 466 | 
            +
                results = {
         | 
| 467 | 
            +
                    k: v for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
         | 
| 468 | 
            +
                }
         | 
| 469 | 
            +
             | 
| 470 | 
            +
                rating = {}
         | 
| 471 | 
            +
                character = {}
         | 
| 472 | 
            +
                general = {}
         | 
| 473 | 
            +
             | 
| 474 | 
            +
                for k, v in results.items():
         | 
| 475 | 
            +
                    if k.startswith("rating:"):
         | 
| 476 | 
            +
                        rating[k.replace("rating:", "")] = v
         | 
| 477 | 
            +
                        continue
         | 
| 478 | 
            +
                    elif k.startswith("character:"):
         | 
| 479 | 
            +
                        character[k.replace("character:", "")] = v
         | 
| 480 | 
            +
                        continue
         | 
| 481 | 
            +
             | 
| 482 | 
            +
                    general[k] = v
         | 
| 483 | 
            +
             | 
| 484 | 
            +
                character = {k: v for k, v in character.items() if v >= character_threshold}
         | 
| 485 | 
            +
                general = {k: v for k, v in general.items() if v >= general_threshold}
         | 
| 486 | 
            +
             | 
| 487 | 
            +
                return rating, character, general
         | 
| 488 | 
            +
             | 
| 489 | 
            +
             | 
| 490 | 
            +
            def gen_prompt(rating: list[str], character: list[str], general: list[str]):
         | 
| 491 | 
            +
                people_tags: list[str] = []
         | 
| 492 | 
            +
                other_tags: list[str] = []
         | 
| 493 | 
            +
                rating_tag = RATING_MAP[rating[0]]
         | 
| 494 | 
            +
             | 
| 495 | 
            +
                for tag in general:
         | 
| 496 | 
            +
                    if tag in PEOPLE_TAGS:
         | 
| 497 | 
            +
                        people_tags.append(tag)
         | 
| 498 | 
            +
                    else:
         | 
| 499 | 
            +
                        other_tags.append(tag)
         | 
| 500 | 
            +
             | 
| 501 | 
            +
                all_tags = people_tags + other_tags
         | 
| 502 | 
            +
             | 
| 503 | 
            +
                return ", ".join(all_tags)
         | 
| 504 | 
            +
             | 
| 505 | 
            +
             | 
| 506 | 
            +
            @spaces.GPU()
         | 
| 507 | 
            +
            def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
         | 
| 508 | 
            +
                inputs = wd_processor.preprocess(image, return_tensors="pt")
         | 
| 509 | 
            +
             | 
| 510 | 
            +
                outputs = wd_model(**inputs.to(wd_model.device, wd_model.dtype))
         | 
| 511 | 
            +
                logits = torch.sigmoid(outputs.logits[0])  # take the first logits
         | 
| 512 | 
            +
             | 
| 513 | 
            +
                # get probabilities
         | 
| 514 | 
            +
                results = {
         | 
| 515 | 
            +
                    wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
         | 
| 516 | 
            +
                }
         | 
| 517 | 
            +
                # rating, character, general
         | 
| 518 | 
            +
                rating, character, general = postprocess_results(
         | 
| 519 | 
            +
                    results, general_threshold, character_threshold
         | 
| 520 | 
            +
                )
         | 
| 521 | 
            +
                prompt = gen_prompt(
         | 
| 522 | 
            +
                    list(rating.keys()), list(character.keys()), list(general.keys())
         | 
| 523 | 
            +
                )
         | 
| 524 | 
            +
                output_series_tag = ""
         | 
| 525 | 
            +
                output_series_list = character_list_to_series_list(character.keys())
         | 
| 526 | 
            +
                if output_series_list:
         | 
| 527 | 
            +
                    output_series_tag = output_series_list[0]
         | 
| 528 | 
            +
                else:
         | 
| 529 | 
            +
                    output_series_tag = ""
         | 
| 530 | 
            +
                return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True)
         | 
| 531 | 
            +
             | 
| 532 | 
            +
             | 
| 533 | 
            +
            def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3,
         | 
| 534 | 
            +
                                 character_threshold: float = 0.8, input_series: str = "", input_character: str = ""):
         | 
| 535 | 
            +
                if not "Use WD Tagger" in algo and len(algo) != 0:
         | 
| 536 | 
            +
                    return input_series, input_character, input_tags, gr.update(interactive=True)
         | 
| 537 | 
            +
                return predict_tags(image, general_threshold, character_threshold)
         | 
| 538 | 
            +
             | 
| 539 | 
            +
             | 
| 540 | 
            +
            def compose_prompt_to_copy(character: str, series: str, general: str):
         | 
| 541 | 
            +
                characters = character.split(",") if character else []
         | 
| 542 | 
            +
                serieses = series.split(",") if series else []
         | 
| 543 | 
            +
                generals = general.split(",") if general else []
         | 
| 544 | 
            +
                tags = characters + serieses + generals
         | 
| 545 | 
            +
                cprompt = ",".join(tags) if tags else ""
         | 
| 546 | 
            +
                return cprompt
         | 
    	
        tagger/utils.py
    ADDED
    
    | @@ -0,0 +1,45 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import gradio as gr
         | 
| 2 | 
            +
            from dartrs.v2 import AspectRatioTag, LengthTag, RatingTag, IdentityTag
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            V2_ASPECT_RATIO_OPTIONS: list[AspectRatioTag] = [
         | 
| 6 | 
            +
                "ultra_wide",
         | 
| 7 | 
            +
                "wide",
         | 
| 8 | 
            +
                "square",
         | 
| 9 | 
            +
                "tall",
         | 
| 10 | 
            +
                "ultra_tall",
         | 
| 11 | 
            +
            ]
         | 
| 12 | 
            +
            V2_RATING_OPTIONS: list[RatingTag] = [
         | 
| 13 | 
            +
                "sfw",
         | 
| 14 | 
            +
                "general",
         | 
| 15 | 
            +
                "sensitive",
         | 
| 16 | 
            +
                "nsfw",
         | 
| 17 | 
            +
                "questionable",
         | 
| 18 | 
            +
                "explicit",
         | 
| 19 | 
            +
            ]
         | 
| 20 | 
            +
            V2_LENGTH_OPTIONS: list[LengthTag] = [
         | 
| 21 | 
            +
                "very_short",
         | 
| 22 | 
            +
                "short",
         | 
| 23 | 
            +
                "medium",
         | 
| 24 | 
            +
                "long",
         | 
| 25 | 
            +
                "very_long",
         | 
| 26 | 
            +
            ]
         | 
| 27 | 
            +
            V2_IDENTITY_OPTIONS: list[IdentityTag] = [
         | 
| 28 | 
            +
                "none",
         | 
| 29 | 
            +
                "lax",
         | 
| 30 | 
            +
                "strict",
         | 
| 31 | 
            +
            ]
         | 
| 32 | 
            +
             | 
| 33 | 
            +
             | 
| 34 | 
            +
            # ref: https://qiita.com/tregu148/items/fccccbbc47d966dd2fc2
         | 
| 35 | 
            +
            def gradio_copy_text(_text: None):
         | 
| 36 | 
            +
                gr.Info("Copied!")
         | 
| 37 | 
            +
             | 
| 38 | 
            +
             | 
| 39 | 
            +
            COPY_ACTION_JS = """\
         | 
| 40 | 
            +
            (inputs, _outputs) => {
         | 
| 41 | 
            +
              // inputs is the string value of the input_text
         | 
| 42 | 
            +
              if (inputs.trim() !== "") {
         | 
| 43 | 
            +
                navigator.clipboard.writeText(inputs);
         | 
| 44 | 
            +
              }
         | 
| 45 | 
            +
            }"""
         | 
    	
        tagger/v2.py
    ADDED
    
    | @@ -0,0 +1,260 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import time
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            from typing import Callable
         | 
| 4 | 
            +
            from pathlib import Path
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from dartrs.v2 import (
         | 
| 7 | 
            +
                V2Model,
         | 
| 8 | 
            +
                MixtralModel,
         | 
| 9 | 
            +
                MistralModel,
         | 
| 10 | 
            +
                compose_prompt,
         | 
| 11 | 
            +
                LengthTag,
         | 
| 12 | 
            +
                AspectRatioTag,
         | 
| 13 | 
            +
                RatingTag,
         | 
| 14 | 
            +
                IdentityTag,
         | 
| 15 | 
            +
            )
         | 
| 16 | 
            +
            from dartrs.dartrs import DartTokenizer
         | 
| 17 | 
            +
            from dartrs.utils import get_generation_config
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            import gradio as gr
         | 
| 21 | 
            +
            from gradio.components import Component
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            try:
         | 
| 25 | 
            +
                from output import UpsamplingOutput
         | 
| 26 | 
            +
            except:
         | 
| 27 | 
            +
                from .output import UpsamplingOutput
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            V2_ALL_MODELS = {
         | 
| 31 | 
            +
                "dart-v2-moe-sft": {
         | 
| 32 | 
            +
                    "repo": "p1atdev/dart-v2-moe-sft",
         | 
| 33 | 
            +
                    "type": "sft",
         | 
| 34 | 
            +
                    "class": MixtralModel,
         | 
| 35 | 
            +
                },
         | 
| 36 | 
            +
                "dart-v2-sft": {
         | 
| 37 | 
            +
                    "repo": "p1atdev/dart-v2-sft",
         | 
| 38 | 
            +
                    "type": "sft",
         | 
| 39 | 
            +
                    "class": MistralModel,
         | 
| 40 | 
            +
                },
         | 
| 41 | 
            +
            }
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            def prepare_models(model_config: dict):
         | 
| 45 | 
            +
                model_name = model_config["repo"]
         | 
| 46 | 
            +
                tokenizer = DartTokenizer.from_pretrained(model_name)
         | 
| 47 | 
            +
                model = model_config["class"].from_pretrained(model_name)
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                return {
         | 
| 50 | 
            +
                    "tokenizer": tokenizer,
         | 
| 51 | 
            +
                    "model": model,
         | 
| 52 | 
            +
                }
         | 
| 53 | 
            +
             | 
| 54 | 
            +
             | 
| 55 | 
            +
            def normalize_tags(tokenizer: DartTokenizer, tags: str):
         | 
| 56 | 
            +
                """Just remove unk tokens."""
         | 
| 57 | 
            +
                return ", ".join([tag for tag in tokenizer.tokenize(tags) if tag != "<|unk|>"])
         | 
| 58 | 
            +
             | 
| 59 | 
            +
             | 
| 60 | 
            +
            @torch.no_grad()
         | 
| 61 | 
            +
            def generate_tags(
         | 
| 62 | 
            +
                model: V2Model,
         | 
| 63 | 
            +
                tokenizer: DartTokenizer,
         | 
| 64 | 
            +
                prompt: str,
         | 
| 65 | 
            +
                ban_token_ids: list[int],
         | 
| 66 | 
            +
            ):
         | 
| 67 | 
            +
                output = model.generate(
         | 
| 68 | 
            +
                    get_generation_config(
         | 
| 69 | 
            +
                        prompt,
         | 
| 70 | 
            +
                        tokenizer=tokenizer,
         | 
| 71 | 
            +
                        temperature=1,
         | 
| 72 | 
            +
                        top_p=0.9,
         | 
| 73 | 
            +
                        top_k=100,
         | 
| 74 | 
            +
                        max_new_tokens=256,
         | 
| 75 | 
            +
                        ban_token_ids=ban_token_ids,
         | 
| 76 | 
            +
                    ),
         | 
| 77 | 
            +
                )
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                return output
         | 
| 80 | 
            +
             | 
| 81 | 
            +
             | 
| 82 | 
            +
            def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
         | 
| 83 | 
            +
                return (
         | 
| 84 | 
            +
                    [f"1{noun}"]
         | 
| 85 | 
            +
                    + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
         | 
| 86 | 
            +
                    + [f"{maximum+1}+{noun}s"]
         | 
| 87 | 
            +
                )
         | 
| 88 | 
            +
             | 
| 89 | 
            +
             | 
| 90 | 
            +
            PEOPLE_TAGS = (
         | 
| 91 | 
            +
                _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
         | 
| 92 | 
            +
            )
         | 
| 93 | 
            +
             | 
| 94 | 
            +
             | 
| 95 | 
            +
            def gen_prompt_text(output: UpsamplingOutput):
         | 
| 96 | 
            +
                # separate people tags (e.g. 1girl)
         | 
| 97 | 
            +
                people_tags = []
         | 
| 98 | 
            +
                other_general_tags = []
         | 
| 99 | 
            +
                
         | 
| 100 | 
            +
                for tag in output.general_tags.split(","):
         | 
| 101 | 
            +
                    tag = tag.strip()
         | 
| 102 | 
            +
                    if tag in PEOPLE_TAGS:
         | 
| 103 | 
            +
                        people_tags.append(tag)
         | 
| 104 | 
            +
                    else:
         | 
| 105 | 
            +
                        other_general_tags.append(tag)
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                return ", ".join(
         | 
| 108 | 
            +
                    [
         | 
| 109 | 
            +
                        part.strip()
         | 
| 110 | 
            +
                        for part in [
         | 
| 111 | 
            +
                            *people_tags,
         | 
| 112 | 
            +
                            output.character_tags,
         | 
| 113 | 
            +
                            output.copyright_tags,
         | 
| 114 | 
            +
                            *other_general_tags,
         | 
| 115 | 
            +
                            output.upsampled_tags,
         | 
| 116 | 
            +
                            output.rating_tag,
         | 
| 117 | 
            +
                        ]
         | 
| 118 | 
            +
                        if part.strip() != ""
         | 
| 119 | 
            +
                    ]
         | 
| 120 | 
            +
                )
         | 
| 121 | 
            +
             | 
| 122 | 
            +
             | 
| 123 | 
            +
            def elapsed_time_format(elapsed_time: float) -> str:
         | 
| 124 | 
            +
                return f"Elapsed: {elapsed_time:.2f} seconds"
         | 
| 125 | 
            +
             | 
| 126 | 
            +
             | 
| 127 | 
            +
            def parse_upsampling_output(
         | 
| 128 | 
            +
                upsampler: Callable[..., UpsamplingOutput],
         | 
| 129 | 
            +
            ):
         | 
| 130 | 
            +
                def _parse_upsampling_output(*args) -> tuple[str, str, dict]:
         | 
| 131 | 
            +
                    output = upsampler(*args)
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                    return (
         | 
| 134 | 
            +
                        gen_prompt_text(output),
         | 
| 135 | 
            +
                        elapsed_time_format(output.elapsed_time),
         | 
| 136 | 
            +
                        gr.update(interactive=True),
         | 
| 137 | 
            +
                        gr.update(interactive=True),
         | 
| 138 | 
            +
                    )
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                return _parse_upsampling_output
         | 
| 141 | 
            +
             | 
| 142 | 
            +
             | 
| 143 | 
            +
            class V2UI:
         | 
| 144 | 
            +
                model_name: str | None = None
         | 
| 145 | 
            +
                model: V2Model
         | 
| 146 | 
            +
                tokenizer: DartTokenizer
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                input_components: list[Component] = []
         | 
| 149 | 
            +
                generate_btn: gr.Button
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                def on_generate(
         | 
| 152 | 
            +
                    self,
         | 
| 153 | 
            +
                    model_name: str,
         | 
| 154 | 
            +
                    copyright_tags: str,
         | 
| 155 | 
            +
                    character_tags: str,
         | 
| 156 | 
            +
                    general_tags: str,
         | 
| 157 | 
            +
                    rating_tag: RatingTag,
         | 
| 158 | 
            +
                    aspect_ratio_tag: AspectRatioTag,
         | 
| 159 | 
            +
                    length_tag: LengthTag,
         | 
| 160 | 
            +
                    identity_tag: IdentityTag,
         | 
| 161 | 
            +
                    ban_tags: str,
         | 
| 162 | 
            +
                    *args,
         | 
| 163 | 
            +
                ) -> UpsamplingOutput:
         | 
| 164 | 
            +
                    if self.model_name is None or self.model_name != model_name:
         | 
| 165 | 
            +
                        models = prepare_models(V2_ALL_MODELS[model_name])
         | 
| 166 | 
            +
                        self.model = models["model"]
         | 
| 167 | 
            +
                        self.tokenizer = models["tokenizer"]
         | 
| 168 | 
            +
                        self.model_name = model_name
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                    # normalize tags
         | 
| 171 | 
            +
                    # copyright_tags = normalize_tags(self.tokenizer, copyright_tags)
         | 
| 172 | 
            +
                    # character_tags = normalize_tags(self.tokenizer, character_tags)
         | 
| 173 | 
            +
                    # general_tags = normalize_tags(self.tokenizer, general_tags)
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                    ban_token_ids = self.tokenizer.encode(ban_tags.strip())
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                    prompt = compose_prompt(
         | 
| 178 | 
            +
                        prompt=general_tags,
         | 
| 179 | 
            +
                        copyright=copyright_tags,
         | 
| 180 | 
            +
                        character=character_tags,
         | 
| 181 | 
            +
                        rating=rating_tag,
         | 
| 182 | 
            +
                        aspect_ratio=aspect_ratio_tag,
         | 
| 183 | 
            +
                        length=length_tag,
         | 
| 184 | 
            +
                        identity=identity_tag,
         | 
| 185 | 
            +
                    )
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                    start = time.time()
         | 
| 188 | 
            +
                    upsampled_tags = generate_tags(
         | 
| 189 | 
            +
                        self.model,
         | 
| 190 | 
            +
                        self.tokenizer,
         | 
| 191 | 
            +
                        prompt,
         | 
| 192 | 
            +
                        ban_token_ids,
         | 
| 193 | 
            +
                    )
         | 
| 194 | 
            +
                    elapsed_time = time.time() - start
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                    return UpsamplingOutput(
         | 
| 197 | 
            +
                        upsampled_tags=upsampled_tags,
         | 
| 198 | 
            +
                        copyright_tags=copyright_tags,
         | 
| 199 | 
            +
                        character_tags=character_tags,
         | 
| 200 | 
            +
                        general_tags=general_tags,
         | 
| 201 | 
            +
                        rating_tag=rating_tag,
         | 
| 202 | 
            +
                        aspect_ratio_tag=aspect_ratio_tag,
         | 
| 203 | 
            +
                        length_tag=length_tag,
         | 
| 204 | 
            +
                        identity_tag=identity_tag,
         | 
| 205 | 
            +
                        elapsed_time=elapsed_time,
         | 
| 206 | 
            +
                    )
         | 
| 207 | 
            +
             | 
| 208 | 
            +
             | 
| 209 | 
            +
            def parse_upsampling_output_simple(upsampler: UpsamplingOutput):
         | 
| 210 | 
            +
                return gen_prompt_text(upsampler)
         | 
| 211 | 
            +
             | 
| 212 | 
            +
             | 
| 213 | 
            +
            v2 = V2UI()
         | 
| 214 | 
            +
             | 
| 215 | 
            +
             | 
| 216 | 
            +
            def v2_upsampling_prompt(model: str = "dart-v2-moe-sft", copyright: str = "", character: str = "",
         | 
| 217 | 
            +
                                      general_tags: str = "", rating: str = "nsfw", aspect_ratio: str = "square",
         | 
| 218 | 
            +
                                        length: str = "very_long", identity: str = "lax", ban_tags: str = "censored"):
         | 
| 219 | 
            +
                raw_prompt = parse_upsampling_output_simple(v2.on_generate(model, copyright, character, general_tags,
         | 
| 220 | 
            +
                                                                            rating, aspect_ratio, length, identity, ban_tags))
         | 
| 221 | 
            +
                return raw_prompt
         | 
| 222 | 
            +
             | 
| 223 | 
            +
             | 
| 224 | 
            +
            def load_dict_from_csv(filename):
         | 
| 225 | 
            +
                dict = {}
         | 
| 226 | 
            +
                if not Path(filename).exists():
         | 
| 227 | 
            +
                    if Path('./tagger/', filename).exists(): filename = str(Path('./tagger/', filename))
         | 
| 228 | 
            +
                    else: return dict
         | 
| 229 | 
            +
                try:
         | 
| 230 | 
            +
                    with open(filename, 'r', encoding="utf-8") as f:
         | 
| 231 | 
            +
                        lines = f.readlines()
         | 
| 232 | 
            +
                except Exception:
         | 
| 233 | 
            +
                    print(f"Failed to open dictionary file: {filename}")
         | 
| 234 | 
            +
                    return dict
         | 
| 235 | 
            +
                for line in lines:
         | 
| 236 | 
            +
                    parts = line.strip().split(',')
         | 
| 237 | 
            +
                    dict[parts[0]] = parts[1]
         | 
| 238 | 
            +
                return dict
         | 
| 239 | 
            +
             | 
| 240 | 
            +
             | 
| 241 | 
            +
            anime_series_dict = load_dict_from_csv('character_series_dict.csv')
         | 
| 242 | 
            +
             | 
| 243 | 
            +
             | 
| 244 | 
            +
            def select_random_character(series: str, character: str):
         | 
| 245 | 
            +
                from random import seed, randrange
         | 
| 246 | 
            +
                seed()
         | 
| 247 | 
            +
                character_list = list(anime_series_dict.keys())
         | 
| 248 | 
            +
                character = character_list[randrange(len(character_list) - 1)]
         | 
| 249 | 
            +
                series = anime_series_dict.get(character.split(",")[0].strip(), "")
         | 
| 250 | 
            +
                return series, character
         | 
| 251 | 
            +
             | 
| 252 | 
            +
             | 
| 253 | 
            +
            def v2_random_prompt(general_tags: str = "", copyright: str = "", character: str = "", rating: str = "nsfw",
         | 
| 254 | 
            +
                                  aspect_ratio: str = "square", length: str = "very_long", identity: str = "lax",
         | 
| 255 | 
            +
                                  ban_tags: str = "censored", model: str = "dart-v2-moe-sft"):
         | 
| 256 | 
            +
                if copyright == "" and character == "":
         | 
| 257 | 
            +
                    copyright, character = select_random_character("", "")
         | 
| 258 | 
            +
                raw_prompt = v2_upsampling_prompt(model, copyright, character, general_tags, rating,
         | 
| 259 | 
            +
                                                   aspect_ratio, length, identity, ban_tags)
         | 
| 260 | 
            +
                return raw_prompt, copyright, character
         | 
