Commit 6444ed9: init

Files changed:

- .gitattributes +62 -0
- .gitignore +178 -0
- README.md +11 -0
- app/gpt4_o/brushedit_all_in_one_pipeline.py +80 -0
- app/gpt4_o/brushedit_app.py +914 -0
- app/gpt4_o/instructions.py +106 -0
- app/gpt4_o/requirements.txt +18 -0
- app/gpt4_o/run_app.sh +5 -0
- app/gpt4_o/vlm_pipeline.py +138 -0
- app/utils/utils.py +197 -0
- assets/hedgehog_rm_fg/hedgehog.png +3 -0
- assets/hedgehog_rm_fg/image_edit_82314e18-c64c-4003-9ef9-52cebf254b2f_2.png +3 -0
- assets/hedgehog_rm_fg/mask_82314e18-c64c-4003-9ef9-52cebf254b2f.png +3 -0
- assets/hedgehog_rm_fg/masked_image_82314e18-c64c-4003-9ef9-52cebf254b2f.png +3 -0
- assets/hedgehog_rm_fg/prompt.txt +1 -0
- assets/hedgehog_rp_bg/hedgehog.png +3 -0
- assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png +3 -0
- assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png +3 -0
- assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png +3 -0
- assets/hedgehog_rp_bg/prompt.txt +1 -0
- assets/hedgehog_rp_fg/hedgehog.png +3 -0
- assets/hedgehog_rp_fg/image_edit_5cab3448-5a3a-459c-9144-35cca3d34273_0.png +3 -0
- assets/hedgehog_rp_fg/mask_5cab3448-5a3a-459c-9144-35cca3d34273.png +3 -0
- assets/hedgehog_rp_fg/masked_image_5cab3448-5a3a-459c-9144-35cca3d34273.png +3 -0
- assets/hedgehog_rp_fg/prompt.txt +1 -0
- assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png +3 -0
- assets/mona_lisa/mask_aae09152-4495-4332-b691-a0c7bff524be.png +3 -0
- assets/mona_lisa/masked_image_aae09152-4495-4332-b691-a0c7bff524be.png +3 -0
- assets/mona_lisa/mona_lisa.png +3 -0
- assets/mona_lisa/prompt.txt +1 -0
- assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png +3 -0
- assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png +3 -0
- assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png +3 -0
- assets/sunflower_girl/prompt.txt +1 -0
- assets/sunflower_girl/sunflower_girl.png +3 -0
- requirements.txt +20 -0
    	
.gitattributes ADDED
@@ -0,0 +1,62 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.webp filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.bmp filter=lfs diff=lfs merge=lfs -text
+*.tiff filter=lfs diff=lfs merge=lfs -text
+assets/hedgehog_rm_fg/hedgehog.png filter=lfs diff=lfs merge=lfs -text
+assets/hedgehog_rm_fg/image_edit_82314e18-c64c-4003-9ef9-52cebf254b2f_2.png filter=lfs diff=lfs merge=lfs -text
+assets/hedgehog_rm_fg/mask_82314e18-c64c-4003-9ef9-52cebf254b2f.png filter=lfs diff=lfs merge=lfs -text
+assets/hedgehog_rm_fg/masked_image_82314e18-c64c-4003-9ef9-52cebf254b2f.png filter=lfs diff=lfs merge=lfs -text
+assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png filter=lfs diff=lfs merge=lfs -text
+assets/hedgehog_rp_bg/hedgehog.png filter=lfs diff=lfs merge=lfs -text
+assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png filter=lfs diff=lfs merge=lfs -text
+assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png filter=lfs diff=lfs merge=lfs -text
+assets/hedgehog_rp_fg/hedgehog.png filter=lfs diff=lfs merge=lfs -text
+assets/hedgehog_rp_fg/image_edit_5cab3448-5a3a-459c-9144-35cca3d34273_0.png filter=lfs diff=lfs merge=lfs -text
+assets/hedgehog_rp_fg/mask_5cab3448-5a3a-459c-9144-35cca3d34273.png filter=lfs diff=lfs merge=lfs -text
+assets/hedgehog_rp_fg/masked_image_5cab3448-5a3a-459c-9144-35cca3d34273.png filter=lfs diff=lfs merge=lfs -text
+assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png filter=lfs diff=lfs merge=lfs -text
+assets/mona_lisa/mask_aae09152-4495-4332-b691-a0c7bff524be.png filter=lfs diff=lfs merge=lfs -text
+assets/mona_lisa/masked_image_aae09152-4495-4332-b691-a0c7bff524be.png filter=lfs diff=lfs merge=lfs -text
+assets/mona_lisa/mona_lisa.png filter=lfs diff=lfs merge=lfs -text
+assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png filter=lfs diff=lfs merge=lfs -text
+assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png filter=lfs diff=lfs merge=lfs -text
+assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png filter=lfs diff=lfs merge=lfs -text
+assets/sunflower_girl/sunflower_girl.png filter=lfs diff=lfs merge=lfs -text
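All binary and image assets above are routed through Git LFS. As a rough illustration (not part of the commit), a minimal Python sketch that checks a path against a few of these patterns; real gitattributes matching has extra rules (slash handling, ordering), so basename fnmatch is only an approximation:

    import fnmatch

    # A handful of the LFS patterns declared above (subset, for illustration).
    LFS_PATTERNS = ["*.png", "*.jpg", "*.safetensors", "*.pth", "*.ckpt"]

    def is_lfs_tracked(path: str) -> bool:
        """Approximate check: does the basename match any declared LFS pattern?"""
        basename = path.rsplit("/", 1)[-1]
        return any(fnmatch.fnmatch(basename, pat) for pat in LFS_PATTERNS)

    print(is_lfs_tracked("assets/mona_lisa/mona_lisa.png"))  # True
    print(is_lfs_tracked("app/gpt4_o/brushedit_app.py"))     # False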
    	
.gitignore ADDED
@@ -0,0 +1,178 @@
+# Initially taken from GitHub's Python gitignore file
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# tests and logs
+tests/fixtures/cached_*_text.txt
+logs/
+lightning_logs/
+lang_code_data/
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a Python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# vscode
+.vs
+.vscode
+
+# Pycharm
+.idea
+
+# TF code
+tensorflow_code
+
+# Models
+proc_data
+
+# examples
+runs
+/runs_old
+/wandb
+/examples/runs
+/examples/**/*.args
+/examples/rag/sweep
+
+# data
+/data
+serialization_dir
+
+# emacs
+*.*~
+debug.env
+
+# vim
+.*.swp
+
+# ctags
+tags
+
+# pre-commit
+.pre-commit*
+
+# .lock
+*.lock
+
+# DS_Store (MacOS)
+.DS_Store
+
+# RL pipelines may produce mp4 outputs
+*.mp4
+
+# dependencies
+/transformers
+
+# ruff
+.ruff_cache
+
+# wandb
+wandb
    	
README.md ADDED
@@ -0,0 +1,11 @@
+---
+title: BrushEdit
+emoji: 🤠
+colorFrom: indigo
+colorTo: gray
+sdk: gradio
+sdk_version: 4.38.1
+app_file: app/gpt4_o/brushedit_app.py
+pinned: false
+python_version: 3.10
+---
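This README front matter is the Hugging Face Spaces configuration: the Space builds with the Gradio SDK (4.38.1) and boots app/gpt4_o/brushedit_app.py as its entry point. As a hedged sketch of that entry-point contract (illustrative only, not code from this commit), the app_file simply has to construct and launch a Gradio app:

    import gradio as gr

    # Minimal stand-in for what an app_file must do: build a UI and launch it.
    with gr.Blocks() as demo:
        gr.Markdown("BrushEdit placeholder UI")

    demo.launch()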
    	
app/gpt4_o/brushedit_all_in_one_pipeline.py ADDED
@@ -0,0 +1,80 @@
+from PIL import Image, ImageEnhance
+from diffusers.image_processor import VaeImageProcessor
+
+import numpy as np
+import cv2
+
+
+def BrushEdit_Pipeline(pipe,
+                       prompts,
+                       mask_np,
+                       original_image,
+                       generator,
+                       num_inference_steps,
+                       guidance_scale,
+                       control_strength,
+                       negative_prompt,
+                       num_samples,
+                       blending):
+    if mask_np.ndim != 3:
+        mask_np = mask_np[:, :, np.newaxis]
+
+    mask_np = mask_np / 255
+    height, width = mask_np.shape[0], mask_np.shape[1]
+    # back/foreground
+    # if mask_np[94:547,94:546].sum() < mask_np.sum() - mask_np[94:547,94:546].sum() and mask_np[0,:].sum()>0 and mask_np[-1,:].sum()>0 and mask_np[:,0].sum()>0 and mask_np[:,-1].sum()>0 and mask_np[1,:].sum()>0 and mask_np[-2,:].sum()>0 and mask_np[:,1].sum()>0 and mask_np[:,-2].sum()>0:
+    #     mask_np = 1 - mask_np
+
+    ## resize the mask and original image to the same size, which must be divisible by vae_scale_factor
+    image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
+    height_new, width_new = image_processor.get_default_height_width(original_image, height, width)
+    mask_np = cv2.resize(mask_np, (width_new, height_new))[:, :, np.newaxis]
+    mask_blurred = cv2.GaussianBlur(mask_np * 255, (21, 21), 0) / 255
+    mask_blurred = mask_blurred[:, :, np.newaxis]
+
+    original_image = cv2.resize(original_image, (width_new, height_new))
+
+    init_image = original_image * (1 - mask_np)
+    init_image = Image.fromarray(init_image.astype(np.uint8)).convert("RGB")
+    mask_image = Image.fromarray((mask_np.repeat(3, -1) * 255).astype(np.uint8)).convert("RGB")
+
+    brushnet_conditioning_scale = float(control_strength)
+
+    images = pipe(
+        [prompts] * num_samples,
+        init_image,
+        mask_image,
+        num_inference_steps=num_inference_steps,
+        guidance_scale=guidance_scale,
+        generator=generator,
+        brushnet_conditioning_scale=brushnet_conditioning_scale,
+        negative_prompt=[negative_prompt] * num_samples,
+        height=height_new,
+        width=width_new,
+    ).images
+
+    if blending:
+        mask_blurred = mask_blurred * 0.5 + 0.5
+
+        ## convert to vae shape format, must be divisible by 8
+        original_image_pil = Image.fromarray(original_image).convert("RGB")
+        init_image_np = np.array(image_processor.preprocess(original_image_pil, height=height_new, width=width_new).squeeze())
+        init_image_np = ((init_image_np.transpose(1, 2, 0) + 1.) / 2.) * 255
+        init_image_np = init_image_np.astype(np.uint8)
+        image_all = []
+        for image_i in images:
+            image_np = np.array(image_i)
+            ## blending
+            image_pasted = init_image_np * (1 - mask_blurred) + mask_blurred * image_np
+            image_pasted = image_pasted.astype(np.uint8)
+            image = Image.fromarray(image_pasted)
+            image_all.append(image)
+    else:
+        image_all = images
+
+    return image_all, mask_image
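The blending branch above composites each diffusion sample back onto the source image through a Gaussian-blurred mask, so the edit fades in instead of leaving a hard seam. A self-contained sketch of just that compositing step, using dummy arrays in place of real images (illustrative, not from the commit):

    import cv2
    import numpy as np

    # Dummy source image, dummy diffusion output, and a hard rectangular mask.
    src = np.full((256, 256, 3), 40, dtype=np.uint8)
    gen = np.full((256, 256, 3), 220, dtype=np.uint8)
    mask = np.zeros((256, 256, 1), dtype=np.float32)
    mask[64:192, 64:192] = 1.0

    # Same recipe as the pipeline: blur the mask, then bias weights toward the edit.
    blurred = cv2.GaussianBlur(mask * 255, (21, 21), 0)[:, :, np.newaxis] / 255
    blurred = blurred * 0.5 + 0.5

    # Weighted paste: blurred weights pull pixels from the generated image.
    blended = (src * (1 - blurred) + blurred * gen).astype(np.uint8)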
    	
app/gpt4_o/brushedit_app.py ADDED
@@ -0,0 +1,914 @@
+##!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import os, random
+import numpy as np
+import torch
+
+import gradio as gr
+import spaces
+
+from PIL import Image
+
+
+from huggingface_hub import hf_hub_download
+
+from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
+from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
+from scipy.ndimage import binary_dilation, binary_erosion
+
+from app.gpt4_o.vlm_pipeline import (
+    vlm_response_editing_type,
+    vlm_response_object_wait_for_edit,
+    vlm_response_mask,
+    vlm_response_prompt_after_apply_instruction
+)
+from app.gpt4_o.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
+from app.utils.utils import load_grounding_dino_model
+
+
+#### Description ####
+head = r"""
+<div style="text-align: center;">
+    <h1> BrushEdit: All-In-One Image Inpainting and Editing</h1>
+    <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
+        <a href='https://tencentarc.github.io/BrushNet/'><img src='https://img.shields.io/badge/Project_Page-BrushNet-green' alt='Project Page'></a>
+        <a href='https://github.com/TencentARC/BrushNet/blob/main/InstructionGuidedEditing/CVPR2024workshop_technique_report.pdf'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
+        <a href='https://github.com/TencentARC/BrushNet'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
+    </div>
+    </br>
+</div>
+"""
+descriptions = r"""
+Official Gradio Demo for <a href='https://tencentarc.github.io/BrushNet/'><b>BrushEdit: All-In-One Image Inpainting and Editing</b></a><br>
+🧙 BrushEdit enables precise, user-friendly instruction-based image editing via an inpainting model.<br>
+"""
+
+instructions = r"""
+Currently, we support two modes: <b>fully automated instruction-based editing</b> and <b>interactive instruction-based editing</b>.
+
+🛠️ <b>Fully automated instruction-based editing</b>:
+<ul>
+    <li> ⭐️ <b>step1:</b> Upload or select one image from the examples. </li>
+    <li> ⭐️ <b>step2:</b> Input the instructions (supports addition, deletion, and modification), e.g., remove xxx. </li>
+    <li> ⭐️ <b>step3:</b> Click the <b>Run</b> button to automatically edit the image. </li>
+</ul>
+
+🛠️ <b>Interactive instruction-based editing</b>:
+<ul>
+    <li> ⭐️ <b>step1:</b> Upload or select one image from the examples. </li>
+    <li> ⭐️ <b>step2:</b> Use a brush to outline the area you want to edit. </li>
+    <li> ⭐️ <b>step3:</b> Input the instructions. </li>
+    <li> ⭐️ <b>step4:</b> Click the <b>Run</b> button to automatically edit the image. </li>
+</ul>
+
+💡 <b>Some tips</b>:
+<ul>
+    <li> 🤠 After inputting the instructions, you can click the <b>Generate Mask</b> button. The mask generated by the VLM will be displayed in the preview panel on the right side. </li>
+    <li> 🤠 After generating the mask, or when you use the brush to draw one, you can perform operations such as <b>randomization</b>, <b>dilation</b>, <b>erosion</b>, and <b>movement</b>. </li>
+    <li> 🤠 After inputting the instructions, you can click the <b>Generate Target Prompt</b> button. The target prompt will be displayed in the text box, and you can modify it according to your ideas. </li>
+</ul>
+
+☕️ Have fun!
+"""
+
+
+# - - - - - examples  - - - - -  #
+EXAMPLES = [
+    # [
+    # {"background": Image.open("assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png").convert("RGBA"),
+    #  "layers": [Image.new("RGBA", (Image.open("assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png").width, Image.open("assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png").height), (0, 0, 0, 0))],
+    #  "composite": Image.open("assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png").convert("RGBA")},
+    # #  Image.open("assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png").convert("RGBA"),
+    #  "add a shining necklace",
+    # #  [Image.open("assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.jpg")],
+    # #  [Image.open("assets/mona_lisa/mask_aae09152-4495-4332-b691-a0c7bff524be.png")],
+    # #  [Image.open("assets/mona_lisa/masked_image_aae09152-4495-4332-b691-a0c7bff524be.png")]
+    # ],
+
+    [
+    # load_image_from_url("https://github.com/liyaowei-stu/BrushEdit/blob/main/assets/mona_lisa/mona_lisa.png"),
+    Image.open("assets/mona_lisa/mona_lisa.png").convert("RGBA"),
+    "add a shining necklace",
+    #  [Image.open("assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.jpg")],
+    #  [Image.open("assets/mona_lisa/mask_aae09152-4495-4332-b691-a0c7bff524be.png")],
+    #  [Image.open("assets/mona_lisa/masked_image_aae09152-4495-4332-b691-a0c7bff524be.png")]
+    ],
+]
+
+
+## init VLM
+from openai import OpenAI
+
+OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
+os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY
+vlm = OpenAI(base_url="http://v2.open.venus.oa.com/llmproxy")
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+
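The vlm handle above is a standard openai-python client pointed at a proxy endpoint; the vlm_response_* helpers imported earlier presumably drive it. A hypothetical call shape (the model name and the proxy's OpenAI compatibility are assumptions, not confirmed by this commit):

    # Hypothetical usage of the client defined above; "gpt-4o" is an assumed model name.
    response = vlm.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Classify this edit: 'remove the hat'."}],
    )
    print(response.choices[0].message.content)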
+# download hf models
+base_model_path = hf_hub_download(
+    repo_id="Yw22/BrushEdit",
+    subfolder="base_model/realisticVisionV60B1_v51VAE",
+    token=os.getenv("HF_TOKEN"),
+)
+
+
+brushnet_path = hf_hub_download(
+    repo_id="Yw22/BrushEdit",
+    subfolder="brushnetX",
+    token=os.getenv("HF_TOKEN"),
+)
+
+sam_path = hf_hub_download(
+    repo_id="Yw22/BrushEdit",
+    subfolder="sam",
+    filename="sam_vit_h_4b8939.pth",
+    token=os.getenv("HF_TOKEN"),
+)
+
+groundingdino_path = hf_hub_download(
+    repo_id="Yw22/BrushEdit",
+    subfolder="grounding_dino",
+    filename="groundingdino_swint_ogc.pth",
+    token=os.getenv("HF_TOKEN"),
+)
+
+
+# input brushnetX ckpt path
+brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16)
+pipe = StableDiffusionBrushNetPipeline.from_pretrained(
+    base_model_path, brushnet=brushnet, torch_dtype=torch.float16, low_cpu_mem_usage=False
+)
+# speed up diffusion process with faster scheduler and memory optimization
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+# remove following line if xformers is not installed or when using Torch 2.0.
+# pipe.enable_xformers_memory_efficient_attention()
+pipe.enable_model_cpu_offload()
+
+
+## init SAM
+sam = build_sam(checkpoint=sam_path)
+sam.to(device=device)
+sam_predictor = SamPredictor(sam)
+sam_automask_generator = SamAutomaticMaskGenerator(sam)
+
+## init groundingdino_model
+config_file = 'third_party/Grounded-Segment-Anything/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
+groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
+
+## Ordinary function
+def crop_and_resize(image: Image.Image,
+                    target_width: int,
+                    target_height: int) -> Image.Image:
+    """
+    Crops and resizes an image while preserving the aspect ratio.
+
+    Args:
+        image (Image.Image): Input PIL image to be cropped and resized.
+        target_width (int): Target width of the output image.
+        target_height (int): Target height of the output image.
+
+    Returns:
+        Image.Image: Cropped and resized image.
+    """
+    # Original dimensions
+    original_width, original_height = image.size
+    original_aspect = original_width / original_height
+    target_aspect = target_width / target_height
+
+    # Calculate crop box to maintain aspect ratio
+    if original_aspect > target_aspect:
+        # Crop horizontally
+        new_width = int(original_height * target_aspect)
+        new_height = original_height
+        left = (original_width - new_width) / 2
+        top = 0
+        right = left + new_width
+        bottom = original_height
+    else:
+        # Crop vertically
+        new_width = original_width
+        new_height = int(original_width / target_aspect)
+        left = 0
+        top = (original_height - new_height) / 2
+        right = original_width
+        bottom = top + new_height
+
+    # Crop and resize
+    cropped_image = image.crop((left, top, right, bottom))
+    resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
+
+    return resized_image
+
+
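A quick usage sketch for crop_and_resize (illustrative values, assuming the function above is in scope): a 4:3 input headed for a 1:1 target is wider than the target, so the sides are cropped before resizing.

    from PIL import Image

    src = Image.new("RGB", (640, 480), (128, 128, 128))
    out = crop_and_resize(src, target_width=512, target_height=512)
    print(out.size)  # (512, 512): center 480x480 crop, then resize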
+def move_mask_func(mask, direction, units):
+    binary_mask = mask.squeeze() > 0
+    rows, cols = binary_mask.shape
+
+    moved_mask = np.zeros_like(binary_mask, dtype=bool)
+
+    if direction == 'down':
+        # shift mask content down
+        moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
+
+    elif direction == 'up':
+        # shift mask content up
+        moved_mask[:rows - units, :] = binary_mask[units:, :]
+
+    elif direction == 'right':
+        # shift mask content right
+        moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
+
+    elif direction == 'left':
+        # shift mask content left
+        moved_mask[:, :cols - units] = binary_mask[:, units:]
+
+    return moved_mask
+
+
| 235 | 
            +
            def random_mask_func(mask, dilation_type='square'):
         | 
| 236 | 
            +
                # Randomly select the size of dilation
         | 
| 237 | 
            +
                dilation_size = np.random.randint(20, 40)  # Randomly select the size of dilation
         | 
| 238 | 
            +
                binary_mask = mask.squeeze()>0
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                if dilation_type == 'square_dilation':
         | 
| 241 | 
            +
                    structure = np.ones((dilation_size, dilation_size), dtype=bool)
         | 
| 242 | 
            +
                    dilated_mask = binary_dilation(binary_mask, structure=structure)
         | 
| 243 | 
            +
                elif dilation_type == 'square_erosion':
         | 
| 244 | 
            +
                    structure = np.ones((dilation_size, dilation_size), dtype=bool)
         | 
| 245 | 
            +
                    dilated_mask = binary_erosion(binary_mask, structure=structure)
         | 
| 246 | 
            +
                elif dilation_type == 'bounding_box':
         | 
| 247 | 
            +
                    # find the most left top and left bottom point
         | 
| 248 | 
            +
                    rows, cols = np.where(binary_mask)
         | 
| 249 | 
            +
                    if len(rows) == 0 or len(cols) == 0:
         | 
| 250 | 
            +
                        return mask  # return original mask if no valid points
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                    min_row = np.min(rows)
         | 
| 253 | 
            +
                    max_row = np.max(rows)
         | 
| 254 | 
            +
                    min_col = np.min(cols)
         | 
| 255 | 
            +
                    max_col = np.max(cols)
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                    # create a bounding box
         | 
| 258 | 
            +
                    dilated_mask = np.zeros_like(binary_mask, dtype=bool)
         | 
| 259 | 
            +
                    dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                elif dilation_type == 'bounding_ellipse':
         | 
| 262 | 
            +
                    # find the most left top and left bottom point
         | 
| 263 | 
            +
                    rows, cols = np.where(binary_mask)
         | 
| 264 | 
            +
                    if len(rows) == 0 or len(cols) == 0:
         | 
| 265 | 
            +
                        return mask  # return original mask if no valid points
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                    min_row = np.min(rows)
         | 
| 268 | 
            +
                    max_row = np.max(rows)
         | 
| 269 | 
            +
                    min_col = np.min(cols)
         | 
| 270 | 
            +
                    max_col = np.max(cols)
         | 
| 271 | 
            +
             | 
| 272 | 
            +
                    # calculate the center and axis length of the ellipse
         | 
| 273 | 
            +
                    center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
         | 
| 274 | 
            +
                    a = (max_col - min_col) // 2  # half long axis
         | 
| 275 | 
            +
                    b = (max_row - min_row) // 2  # half short axis
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                    # create a bounding ellipse
         | 
| 278 | 
            +
                    y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
         | 
| 279 | 
            +
                    ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
         | 
| 280 | 
            +
                    dilated_mask = np.zeros_like(binary_mask, dtype=bool)
         | 
| 281 | 
            +
                    dilated_mask[ellipse_mask] = True
         | 
| 282 | 
            +
                else:
         | 
| 283 | 
            +
                    raise ValueError("dilation_type must be 'square' or 'ellipse'")
         | 
| 284 | 
            +
             | 
| 285 | 
            +
                # use binary dilation
         | 
| 286 | 
            +
                dilated_mask =  np.uint8(dilated_mask[:,:,np.newaxis]) * 255
         | 
| 287 | 
            +
                return dilated_mask
         | 
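
# Illustrative usage sketch (assumes the `toy` mask from the sketch above):
#
#   boxed = random_mask_func(toy, dilation_type='bounding_box')     # uint8 HxWx1 in {0, 255}
#   grown = random_mask_func(toy, dilation_type='square_dilation')  # dilated by a random 20-40 px kernel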


## Gradio component function
@spaces.GPU(duration=180)
def process(input_image,
            original_image,
            original_mask,
            prompt,
            negative_prompt,
            control_strength,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
            num_samples,
            blending,
            category,
            target_prompt,
            resize_and_crop):

    if original_image is None:
        raise gr.Error('Please upload the input image')
    if prompt is None or prompt == '':
        raise gr.Error("Please input your instructions, e.g., remove the xxx")

    # the user's brush strokes live in the alpha channel of the editor's first layer
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    if resize_and_crop:
        original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
        input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
        original_image = np.array(original_image)
        input_mask = np.array(input_mask)

    # a freshly drawn mask takes precedence over any previously generated one
    if input_mask.max() != 0:
        original_mask = input_mask[:, :, None]

    # load example image
    # if isinstance(original_image, str):
    #     original_image = input_image
    #     num_samples = 1
    #     blending = True

    if category is None:
        category = vlm_response_editing_type(vlm, original_image, prompt)

    if original_mask is not None:
        original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
    else:
        object_wait_for_edit = vlm_response_object_wait_for_edit(vlm,
                                                                 category,
                                                                 prompt)
        original_mask = vlm_response_mask(vlm,
                                          category,
                                          original_image,
                                          prompt,
                                          object_wait_for_edit,
                                          sam,
                                          sam_predictor,
                                          sam_automask_generator,
                                          groundingdino_model,
                                          )[:, :, None]

    if len(target_prompt) <= 1:
        prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(vlm,
                                                                                     original_image,
                                                                                     prompt)
    else:
        prompt_after_apply_instruction = target_prompt

    generator = torch.Generator("cuda").manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)

    image, mask_image = BrushEdit_Pipeline(pipe,
                                           prompt_after_apply_instruction,
                                           original_mask,
                                           original_image,
                                           generator,
                                           num_inference_steps,
                                           guidance_scale,
                                           control_strength,
                                           negative_prompt,
                                           num_samples,
                                           blending)

    masked_image = original_image * (1 - (original_mask > 0))
    masked_image = masked_image.astype(np.uint8)
    masked_image = Image.fromarray(masked_image)

    import uuid
    edit_id = str(uuid.uuid4())  # named edit_id so it does not shadow the uuid module
    os.makedirs("outputs", exist_ok=True)  # assumes `os` is imported at the top of this file
    for idx, img in enumerate(image):  # save exactly num_samples results instead of a hard-coded four
        img.save(f"outputs/image_edit_{edit_id}_{idx}.png")
    mask_image.save(f"outputs/mask_{edit_id}.png")
    masked_image.save(f"outputs/masked_image_{edit_id}.png")
    return image, [mask_image], [masked_image], ''

def generate_target_prompt(input_image,
                           original_image,
                           prompt):
    # load example image
    if isinstance(original_image, str):
        original_image = input_image

    prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(vlm,
                                                                                 original_image,
                                                                                 prompt)
    return prompt_after_apply_instruction

def process_mask(input_image,
                 original_image,
                 prompt,
                 resize_and_crop):
    if original_image is None:
        raise gr.Error('Please upload the input image')
    if prompt is None or prompt == '':
        raise gr.Error("Please input your instructions, e.g., remove the xxx")

    ## load the user-drawn mask from the editor's first layer (alpha channel)
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.array(alpha_mask)

    # load example image
    if isinstance(original_image, str):
        original_image = input_image["background"]

    if resize_and_crop:
        original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
        input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
        original_image = np.array(original_image)
        input_mask = np.array(input_mask)

    if input_mask.max() == 0:
        # no brush strokes: let the VLM infer the editing category and target object,
        # then segment it
        category = vlm_response_editing_type(vlm, original_image, prompt)

        object_wait_for_edit = vlm_response_object_wait_for_edit(vlm,
                                                                 category,
                                                                 prompt)
        # original_mask: h, w, 1 in [0, 255]
        original_mask = vlm_response_mask(
            vlm,
            category,
            original_image,
            prompt,
            object_wait_for_edit,
            sam,
            sam_predictor,
            sam_automask_generator,
            groundingdino_model,
            )[:, :, None]
    else:
        original_mask = input_mask[:, :, None]
        category = None

    mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")

    masked_image = original_image * (1 - (original_mask > 0))
    masked_image = masked_image.astype(np.uint8)
    masked_image = Image.fromarray(masked_image)

    ## does not work for the image editor component
    # background = input_image["background"]
    # mask_array = original_mask.squeeze()
    # layer_rgba = np.array(input_image['layers'][0])
    # layer_rgba[mask_array > 0] = [0, 0, 0, 255]
    # layer_rgba = Image.fromarray(layer_rgba, 'RGBA')
    # black_image = Image.new("RGBA", layer_rgba.size, (0, 0, 0, 255))
    # composite = Image.composite(black_image, background, layer_rgba)
    # output_base = {"layers": [layer_rgba], "background": background, "composite": composite}

    return [masked_image], [mask_image], original_mask.astype(np.uint8), category

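# Descriptive note (inferred from the handlers above): `input_image` is the dict
# produced by gr.ImageEditor, roughly {"background": PIL.Image,
# "layers": [RGBA PIL.Image, ...], "composite": PIL.Image}; every handler below
# recovers the user's brush strokes from the alpha channel of layers[0].
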
def process_random_mask(input_image, original_image, original_mask, resize_and_crop):

    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    if resize_and_crop:
        original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
        input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
        original_image = np.array(original_image)
        input_mask = np.array(input_mask)

    if input_mask.max() == 0:
        if original_mask is None:
            raise gr.Error('Please generate mask first')
    else:
        original_mask = input_mask[:, :, None]

    dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
    random_mask = random_mask_func(original_mask, dilation_type).squeeze()

    mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")

    masked_image = original_image * (1 - (random_mask[:, :, None] > 0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    return [masked_image], [mask_image], random_mask[:, :, None].astype(np.uint8)

def process_dilation_mask(input_image, original_image, original_mask, resize_and_crop):

    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    if resize_and_crop:
        original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
        input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
        original_image = np.array(original_image)
        input_mask = np.array(input_mask)

    if input_mask.max() == 0:
        if original_mask is None:
            raise gr.Error('Please generate mask first')
    else:
        original_mask = input_mask[:, :, None]

    dilation_type = 'square_dilation'
    random_mask = random_mask_func(original_mask, dilation_type).squeeze()

    mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")

    masked_image = original_image * (1 - (random_mask[:, :, None] > 0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    return [masked_image], [mask_image], random_mask[:, :, None].astype(np.uint8)

| 539 | 
            +
             | 
| 540 | 
            +
            def process_erosion_mask(input_image, original_image, original_mask, resize_and_crop):
         | 
| 541 | 
            +
                alpha_mask = input_image["layers"][0].split()[3]
         | 
| 542 | 
            +
                input_mask = np.asarray(alpha_mask)
         | 
| 543 | 
            +
                if resize_and_crop:
         | 
| 544 | 
            +
                    original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
         | 
| 545 | 
            +
                    input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
         | 
| 546 | 
            +
                    original_image = np.array(original_image)
         | 
| 547 | 
            +
                    input_mask = np.array(input_mask)
         | 
| 548 | 
            +
             | 
| 549 | 
            +
                if input_mask.max() == 0:
         | 
| 550 | 
            +
                    if original_mask is None:
         | 
| 551 | 
            +
                        raise gr.Error('Please generate mask first')
         | 
| 552 | 
            +
                    original_mask = original_mask
         | 
| 553 | 
            +
                else:
         | 
| 554 | 
            +
                    original_mask = input_mask[:,:,None]
         | 
| 555 | 
            +
             | 
| 556 | 
            +
                dilation_type = np.random.choice(['square_erosion'])
         | 
| 557 | 
            +
                random_mask = random_mask_func(original_mask, dilation_type).squeeze()
         | 
| 558 | 
            +
             | 
| 559 | 
            +
                mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
         | 
| 560 | 
            +
             | 
| 561 | 
            +
                masked_image = original_image * (1 - (random_mask[:,:,None]>0))
         | 
| 562 | 
            +
                masked_image = masked_image.astype(original_image.dtype)
         | 
| 563 | 
            +
                masked_image = Image.fromarray(masked_image)
         | 
| 564 | 
            +
             | 
| 565 | 
            +
             | 
| 566 | 
            +
                return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
         | 
| 567 | 
            +
             | 
| 568 | 
            +
             | 
def move_mask_left(input_image, original_image, original_mask, moving_pixels, resize_and_crop):

    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    if resize_and_crop:
        original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
        input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
        original_image = np.array(original_image)
        input_mask = np.array(input_mask)

    if input_mask.max() == 0:
        if original_mask is None:
            raise gr.Error('Please generate mask first')
    else:
        original_mask = input_mask[:, :, None]

    moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
    mask_image = Image.fromarray(((moved_mask > 0).astype(np.uint8) * 255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:, :, None] > 0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    if moved_mask.max() <= 1:
        # move_mask_func returns a boolean mask; scale it back to [0, 255]
        moved_mask = ((moved_mask * 255)[:, :, None]).astype(np.uint8)
        original_mask = moved_mask
    return [masked_image], [mask_image], original_mask.astype(np.uint8)


def move_mask_right(input_image, original_image, original_mask, moving_pixels, resize_and_crop):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    if resize_and_crop:
        original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
        input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
        original_image = np.array(original_image)
        input_mask = np.array(input_mask)

    if input_mask.max() == 0:
        if original_mask is None:
            raise gr.Error('Please generate mask first')
    else:
        original_mask = input_mask[:, :, None]

    moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()

    mask_image = Image.fromarray(((moved_mask > 0).astype(np.uint8) * 255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:, :, None] > 0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:, :, None]).astype(np.uint8)
        original_mask = moved_mask

    return [masked_image], [mask_image], original_mask.astype(np.uint8)


def move_mask_up(input_image, original_image, original_mask, moving_pixels, resize_and_crop):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    if resize_and_crop:
        original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
        input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
        original_image = np.array(original_image)
        input_mask = np.array(input_mask)

    if input_mask.max() == 0:
        if original_mask is None:
            raise gr.Error('Please generate mask first')
    else:
        original_mask = input_mask[:, :, None]

    moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
    mask_image = Image.fromarray(((moved_mask > 0).astype(np.uint8) * 255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:, :, None] > 0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:, :, None]).astype(np.uint8)
        original_mask = moved_mask

    return [masked_image], [mask_image], original_mask.astype(np.uint8)


def move_mask_down(input_image, original_image, original_mask, moving_pixels, resize_and_crop):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    if resize_and_crop:
        original_image = crop_and_resize(Image.fromarray(original_image), target_width=640, target_height=640)
        input_mask = crop_and_resize(Image.fromarray(input_mask), target_width=640, target_height=640)
        original_image = np.array(original_image)
        input_mask = np.array(input_mask)

    if input_mask.max() == 0:
        if original_mask is None:
            raise gr.Error('Please generate mask first')
    else:
        original_mask = input_mask[:, :, None]

    moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
    mask_image = Image.fromarray(((moved_mask > 0).astype(np.uint8) * 255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:, :, None] > 0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:, :, None]).astype(np.uint8)
        original_mask = moved_mask

    return [masked_image], [mask_image], original_mask.astype(np.uint8)

| 690 | 
            +
             | 
| 691 | 
            +
            def store_img(base):
         | 
| 692 | 
            +
                import ipdb; ipdb.set_trace()
         | 
| 693 | 
            +
                image_pil = base["background"].convert("RGB")
         | 
| 694 | 
            +
                original_image = np.array(image_pil)
         | 
| 695 | 
            +
                # import ipdb; ipdb.set_trace()
         | 
| 696 | 
            +
                if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
         | 
| 697 | 
            +
                    raise gr.Error('image aspect ratio cannot be larger than 2.0')
         | 
| 698 | 
            +
                return base, original_image, None, "", None, None, None, None, None  
         | 
| 699 | 
            +
             | 
| 700 | 
            +
             | 
| 701 | 
            +
            def reset_func(input_image, original_image, original_mask, prompt, target_prompt):
         | 
| 702 | 
            +
                input_image = None
         | 
| 703 | 
            +
                original_image = None
         | 
| 704 | 
            +
                original_mask = None
         | 
| 705 | 
            +
                prompt = ''
         | 
| 706 | 
            +
                mask_gallery = []
         | 
| 707 | 
            +
                masked_gallery = []
         | 
| 708 | 
            +
                result_gallery = []
         | 
| 709 | 
            +
                target_prompt = ''
         | 
| 710 | 
            +
                return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt
         | 
| 711 | 
            +
             | 
| 712 | 
            +
             | 
block = gr.Blocks(
        theme=gr.themes.Soft(
            radius_size=gr.themes.sizes.radius_none,
            text_size=gr.themes.sizes.text_md
        )
        ).queue()
with block as demo:
    with gr.Row():
        with gr.Column():
            gr.HTML(head)

    gr.Markdown(descriptions)

    with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
        with gr.Row(equal_height=True):
            gr.Markdown(instructions)

    original_image = gr.State(value=None)
    original_mask = gr.State(value=None)
    category = gr.State(value=None)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_image = gr.ImageEditor(
                    label="Input Image",
                    type="pil",
                    brush=gr.Brush(colors=["#000000"], default_size=30, color_mode="fixed"),
                    layers=False,
                    interactive=True,
                    height=800,
                    # transforms=("crop"),
                    # crop_size=(640, 640),
                    )

            prompt = gr.Textbox(label="Prompt", placeholder="Please input your instruction.", value='', lines=1)

            with gr.Row():
                mask_button = gr.Button("Generate Mask")
                random_mask_button = gr.Button("Random Generated Mask")
            with gr.Row():
                dilation_mask_button = gr.Button("Dilation Generated Mask")
                erosion_mask_button = gr.Button("Erosion Generated Mask")

            with gr.Row():
                generate_target_prompt_button = gr.Button("Generate Target Prompt")
                run_button = gr.Button("Run")

            target_prompt = gr.Text(
                        label="Target prompt",
                        max_lines=5,
                        placeholder="VLM-generated target prompt; you can generate it first and then modify it (optional)",
                        value='',
                        lines=2
                    )

            resize_and_crop = gr.Checkbox(label="Resize and Crop (640 x 640)", value=False)

            with gr.Accordion("More input params (highly recommended)", open=False, elem_id="accordion1"):
                negative_prompt = gr.Text(
                        label="Negative Prompt",
                        max_lines=5,
                        placeholder="Please input your negative prompt",
                        value='ugly, low quality', lines=1
                    )

                control_strength = gr.Slider(
                    label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
                    )
                with gr.Group():
                    seed = gr.Slider(
                        label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)

                blending = gr.Checkbox(label="Blending mode", value=True)

                num_samples = gr.Slider(
                    label="Num samples", minimum=0, maximum=4, step=1, value=4
                )

                with gr.Group():
                    with gr.Row():
                        guidance_scale = gr.Slider(
                            label="Guidance scale",
                            minimum=1,
                            maximum=12,
                            step=0.1,
                            value=7.5,
                        )
                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=50,
                        )

        with gr.Column():
            with gr.Row():
                with gr.Tabs(elem_classes=["feedback"]):
                    with gr.TabItem("Mask"):
                        mask_gallery = gr.Gallery(label='Mask', show_label=False, elem_id="gallery", preview=True, height=360)
                with gr.Tabs(elem_classes=["feedback"]):
                    with gr.TabItem("Masked Image"):
                        masked_gallery = gr.Gallery(label='Masked Image', show_label=False, elem_id="gallery", preview=True, height=360)

            moving_pixels = gr.Slider(
                    label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1
                    )
            with gr.Row():
                move_left_button = gr.Button("Move Left")
                move_right_button = gr.Button("Move Right")
            with gr.Row():
                move_up_button = gr.Button("Move Up")
                move_down_button = gr.Button("Move Down")

            with gr.Tabs(elem_classes=["feedback"]):
                with gr.TabItem("Outputs"):
                    result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", preview=True, height=360)

            reset_button = gr.Button("Reset")


    with gr.Row():
        # example = gr.Examples(
        #     label="Quick Example",
        #     examples=EXAMPLES,
        #     inputs=[prompt, seed, result_gallery, mask_gallery, masked_gallery],
        #     examples_per_page=10,
        #     cache_examples=False,
        # )
        example = gr.Examples(
            label="Quick Example",
            examples=EXAMPLES,
            inputs=[input_image, prompt],
            examples_per_page=10,
            cache_examples=False,
        )
        # def process_example(prompt, seed, eg_output):
        #     eg_output_path = os.path.join("assets/", eg_output)
        #     return prompt, seed, [Image.open(eg_output_path)]
        # example = gr.Examples(
        #     label="Quick Example",
        #     examples=EXAMPLES,
        #     inputs=[prompt, seed, eg_output],
        #     outputs=[prompt, seed, result_gallery],
        #     fn=process_example,
        #     examples_per_page=10,
        #     run_on_click=True,
        #     cache_examples=False,
        # )

    input_image.upload(
        store_img,
        [input_image],
        [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt]
    )

    ips = [input_image,
           original_image,
           original_mask,
           prompt,
           negative_prompt,
           control_strength,
           seed,
           randomize_seed,
           guidance_scale,
           num_inference_steps,
           num_samples,
           blending,
           category,
           target_prompt,
           resize_and_crop]

    ## run brushedit
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, target_prompt])

    ## mask funcs
    mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_and_crop], outputs=[masked_gallery, mask_gallery, original_mask, category])
    random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_and_crop], outputs=[masked_gallery, mask_gallery, original_mask])
    dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_and_crop], outputs=[masked_gallery, mask_gallery, original_mask])
    erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_and_crop], outputs=[masked_gallery, mask_gallery, original_mask])

    ## move mask funcs
    move_left_button.click(fn=move_mask_left, inputs=[input_image, original_image, original_mask, moving_pixels, resize_and_crop], outputs=[masked_gallery, mask_gallery, original_mask])
    move_right_button.click(fn=move_mask_right, inputs=[input_image, original_image, original_mask, moving_pixels, resize_and_crop], outputs=[masked_gallery, mask_gallery, original_mask])
    move_up_button.click(fn=move_mask_up, inputs=[input_image, original_image, original_mask, moving_pixels, resize_and_crop], outputs=[masked_gallery, mask_gallery, original_mask])
    move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_and_crop], outputs=[masked_gallery, mask_gallery, original_mask])

    ## prompt func
    generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])

    ## reset func
    reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt])

demo.launch(server_name="0.0.0.0")
    	
app/gpt4_o/instructions.py ADDED

| 1 | 
            +
            def create_editing_category_messages(editing_prompt):
         | 
| 2 | 
            +
                messages = [{
         | 
| 3 | 
            +
                        "role": "system",
         | 
| 4 | 
            +
                        "content": [
         | 
| 5 | 
            +
                            {
         | 
| 6 | 
            +
                            "type": "text",
         | 
| 7 | 
            +
                            "text": "I will give you an image and an editing instruction of the image. Please output which type of editing category it is in. You can choose from the following categories: \n\
         | 
| 8 | 
            +
                1. Addition: Adding new objects within the images, e.g., add a bird to the image \n\
         | 
| 9 | 
            +
                2. Remove: Removing objects, e.g., remove the mask \n\
         | 
| 10 | 
            +
                3. Local: Replace local parts of an object and later the object's attributes (e.g., make it smile) or alter an object's visual appearance without affecting its structure (e.g., change the cat to a dog) \n\
         | 
| 11 | 
            +
                4. Global: Edit the entire image, e.g., let's see it in winter \n\
         | 
| 12 | 
            +
                5. Background: Change the scene's background, e.g., have her walk on water, change the background to a beach, make the hedgehog in France, etc.",
         | 
| 13 | 
            +
                            },]
         | 
| 14 | 
            +
                        },
         | 
| 15 | 
            +
                        {
         | 
| 16 | 
            +
                        "role": "user",
         | 
| 17 | 
            +
                        "content": [
         | 
| 18 | 
            +
                            {
         | 
| 19 | 
            +
                            "type": "text",
         | 
| 20 | 
            +
                            "text": editing_prompt
         | 
| 21 | 
            +
                            },
         | 
| 22 | 
            +
                        ]
         | 
| 23 | 
            +
                        }]
         | 
| 24 | 
            +
                return messages
         | 
| 25 | 
            +
                
         | 
| 26 | 
            +
             | 
def create_ori_object_messages(editing_prompt):

    messages = [
                {
                "role": "system",
                "content": [
                    {
                    "type": "text",
                    "text": "I will give you an editing instruction of the image. Please output the object needed to be edited. You only need to output the basic description of the object in no more than 5 words. The output should only contain one noun. \n \
                    For example, the editing instruction is 'Change the white cat to a black dog'. Then you need to output: 'white cat'. Only output the new content. Do not output anything else."
                    },]
                },
                {
                "role": "user",
                "content": [
                    {
                    "type": "text",
                    "text": editing_prompt
                    }
                ]
                }
            ]
    return messages

def create_add_object_messages(editing_prompt, base64_image, height=640, width=640):

    # coordinates are [x, y], so the bottom-right corner is [width, height]
    size_str = f"The image size is height {height}px and width {width}px. The top-left corner is coordinate [0, 0]. The bottom-right corner is coordinate [{width}, {height}]. "

    messages = [
                {
                "role": "user",
                "content": [
                    {
                    "type": "text",
                    "text": "I need to add an object to the image following the instruction: " + editing_prompt + ". " + size_str + " \n \
                    Can you give me a possible bounding box of the location for the added object? Please output with the format of [top-left x coordinate, top-left y coordinate, box width, box height]. You should only output the bounding box position and nothing else. Please refer to the example below for the desired format.\n\
                    [Examples]\n \
                    [19, 101, 32, 153]\n \
                    [54, 12, 242, 96]"
                    },
                    {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_image}"
                        },
                    }
                ]
                }
            ]
    return messages

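The reply to this prompt is expected to be a bare 4-number list. A minimal sketch (hypothetical reply string; the app's real parsing lives in app/gpt4_o/vlm_pipeline.py) of turning that reply into a binary mask:

import re
import numpy as np

response_str = "[54, 12, 242, 96]"  # [top-left x, top-left y, box width, box height]
match = re.findall(r'\[\d{1,3}(?:,\s*\d{1,3}){3}\]', response_str)[0]
x, y, w, h = (int(v) for v in match[1:-1].split(","))

mask = np.zeros((640, 640), dtype=np.uint8)  # assumes a 640x640 image
mask[y:y + h, x:x + w] = 255                 # rows index y, columns index x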
def create_apply_editing_messages(editing_prompt, base64_image):
    messages = [
            {
            "role": "system",
            "content": [
                {
                "type": "text",
                "text": "I will provide an image along with an editing instruction. Please describe the new content that should be present in the image after applying the instruction. \n \
                    For example, if the original image content shows a grandmother wearing a mask and the instruction is 'remove the mask', your output should be: 'a grandmother'. The output should only include elements that remain in the image after the edit and should not mention elements that have been changed or removed, such as 'mask' in this example. Do not output 'sorry, xxx', even if it's a guess, directly output the answer you think is correct."
                },]
            },
            {
            "role": "user",
            "content": [
                {
                "type": "text",
                "text": editing_prompt
                },
                {"type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{base64_image}"
                    },
                },
            ]
            }
        ]
    return messages
    	
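How these builders are consumed (a minimal sketch, assuming OPENAI_API_KEY is set in the environment; in the app this call goes through run_gpt4o_vl_inference in app/gpt4_o/vlm_pipeline.py):

from openai import OpenAI
from app.gpt4_o.instructions import create_editing_category_messages

client = OpenAI()
messages = create_editing_category_messages("remove the hedgehog")
response = client.chat.completions.create(model="gpt-4o-2024-08-06", messages=messages)
print(response.choices[0].message.content)  # e.g. "Remove"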
app/gpt4_o/requirements.txt ADDED
@@ -0,0 +1,18 @@
torchvision
transformers>=4.25.1
ftfy
tensorboard
datasets
Pillow==9.5.0
opencv-python
imgaug
accelerate==0.20.3
image-reward
hpsv2
torchmetrics
open-clip-torch
clip
# gradio==4.44.1
gradio==4.38.1
segment_anything
openai
    	
app/gpt4_o/run_app.sh ADDED
@@ -0,0 +1,5 @@
export PYTHONPATH=.:$PYTHONPATH

export CUDA_VISIBLE_DEVICES=0

python app/gpt4_o/brushedit_app.py
    	
app/gpt4_o/vlm_pipeline.py ADDED
@@ -0,0 +1,138 @@

import base64
import re
import torch

from PIL import Image
from io import BytesIO
import numpy as np
import gradio as gr


from app.gpt4_o.instructions import (
    create_editing_category_messages,
    create_ori_object_messages,
    create_add_object_messages,
    create_apply_editing_messages)

from app.utils.utils import run_grounded_sam


def encode_image(img):
    img = Image.fromarray(img.astype('uint8'))
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    img_bytes = buffered.getvalue()
    return base64.b64encode(img_bytes).decode('utf-8')


def run_gpt4o_vl_inference(vlm,
                           messages):
    response = vlm.chat.completions.create(
        model="gpt-4o-2024-08-06",
        messages=messages
    )
    response_str = response.choices[0].message.content
    return response_str

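encode_image accepts any HxWx3 uint8 array and returns a base64 string for the data-URL payloads built in app/gpt4_o/instructions.py. Note that it always PNG-encodes even though those data URLs declare image/jpeg; the mismatch is typically harmless. A quick sketch:

import numpy as np
from app.gpt4_o.vlm_pipeline import encode_image

dummy = np.zeros((4, 4, 3), dtype=np.uint8)  # tiny black image
b64 = encode_image(dummy)
print(b64[:16], "...")                       # base64 text for "data:image/...;base64,..."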
def vlm_response_editing_type(vlm,
                              image,
                              editing_prompt):

    base64_image = encode_image(image)

    messages = create_editing_category_messages(editing_prompt)

    response_str = run_gpt4o_vl_inference(vlm, messages)

    for category_name in ["Addition", "Remove", "Local", "Global", "Background"]:
        if category_name.lower() in response_str.lower():
            return category_name
    raise ValueError("Could not recognize the editing category; please input a valid add, remove, or modify instruction.")


def vlm_response_object_wait_for_edit(vlm,
                                      category,
                                      editing_prompt):
    if category in ["Background", "Global", "Addition"]:
        edit_object = "nan"
        return edit_object

    messages = create_ori_object_messages(editing_prompt)

    response_str = run_gpt4o_vl_inference(vlm, messages)
    return response_str

def vlm_response_mask(vlm,
                      category,
                      image,
                      editing_prompt,
                      object_wait_for_edit,
                      sam=None,
                      sam_predictor=None,
                      sam_automask_generator=None,
                      groundingdino_model=None,
                      ):
    mask = None
    if editing_prompt is None or len(editing_prompt) == 0:
        raise gr.Error("Please input the editing instruction!")
    height, width = image.shape[:2]
    if category == "Addition":
        base64_image = encode_image(image)
        messages = create_add_object_messages(editing_prompt, base64_image, height=height, width=width)
        try:
            response_str = run_gpt4o_vl_inference(vlm, messages)
            pattern = r'\[\d{1,3}(?:,\s*\d{1,3}){3}\]'
            box = re.findall(pattern, response_str)
            box = box[0][1:-1].split(",")
            for i in range(len(box)):
                box[i] = int(box[i])
            cus_mask = np.zeros((height, width))
            cus_mask[box[1]: box[1] + box[3], box[0]: box[0] + box[2]] = 255
            mask = cus_mask
        except Exception:
            raise gr.Error("Please set the mask manually, MLLM cannot output the mask!")

    elif category == "Background":
        labels = "background"
    elif category == "Global":
        mask = np.zeros((height, width))  # empty mask: global edits apply to the whole image
    else:
        labels = object_wait_for_edit

    if mask is None:
        # retry grounding with progressively lower box thresholds until a mask is found
        for thresh in [0.3, 0.25, 0.2, 0.15, 0.1, 0.05, 0]:
            try:
                device = "cuda" if torch.cuda.is_available() else "cpu"
                detections = run_grounded_sam(
                    input_image={"image": Image.fromarray(image.astype('uint8')),
                                 "mask": None},
                    text_prompt=labels,
                    task_type="seg",
                    box_threshold=thresh,
                    text_threshold=0.25,
                    iou_threshold=0.5,
                    scribble_mode="split",
                    sam=sam,
                    sam_predictor=sam_predictor,
                    sam_automask_generator=sam_automask_generator,
                    groundingdino_model=groundingdino_model,
                    device=device,
                )
                mask = np.array(detections[0, 0, ...].cpu()) * 255
                break
            except Exception:
                print(f"grounding failed at box threshold {thresh}, trying a lower one")
                continue
    return mask

def vlm_response_prompt_after_apply_instruction(vlm,
                                                image,
                                                editing_prompt):
    base64_image = encode_image(image)
    messages = create_apply_editing_messages(editing_prompt, base64_image)

    response_str = run_gpt4o_vl_inference(vlm, messages)
    return response_str
    	
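Taken together, the four helpers above form the VLM planning stage: classify the edit, name the object, derive a mask, and describe the post-edit content. A minimal sketch (the plan_edit wrapper is hypothetical; sam and groundingdino_model are assumed to be checkpoints already loaded via app/utils/utils.py):

from openai import OpenAI
import numpy as np

from app.gpt4_o.vlm_pipeline import (
    vlm_response_editing_type,
    vlm_response_object_wait_for_edit,
    vlm_response_mask,
    vlm_response_prompt_after_apply_instruction)

def plan_edit(image: np.ndarray, editing_prompt: str, sam, groundingdino_model):
    vlm = OpenAI()  # assumes OPENAI_API_KEY is set
    category = vlm_response_editing_type(vlm, image, editing_prompt)
    edit_object = vlm_response_object_wait_for_edit(vlm, category, editing_prompt)
    mask = vlm_response_mask(vlm, category, image, editing_prompt, edit_object,
                             sam=sam, groundingdino_model=groundingdino_model)
    target_prompt = vlm_response_prompt_after_apply_instruction(vlm, image, editing_prompt)
    return category, edit_object, mask, target_prompt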
app/utils/utils.py ADDED
@@ -0,0 +1,197 @@

import numpy as np
import torch
import torchvision

from scipy import ndimage

# BLIP
from transformers import BlipProcessor, BlipForConditionalGeneration

# SAM
from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator

# GroundingDINO
from groundingdino.datasets import transforms as T
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap

# lazily-initialized globals referenced by run_grounded_sam below
# (without these definitions the `global` lookup would raise NameError)
blip_processor = None
blip_model = None
inpaint_pipeline = None


def load_grounding_dino_model(model_config_path, model_checkpoint_path, device):
    args = SLConfig.fromfile(model_config_path)
    args.device = device
    model = build_model(args)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    print(load_res)
    _ = model.eval()
    return model

def generate_caption(processor, blip_model, raw_image, device):
    # unconditional image captioning
    inputs = processor(raw_image, return_tensors="pt").to(device, torch.float16)
    out = blip_model.generate(**inputs)
    caption = processor.decode(out[0], skip_special_tokens=True)
    return caption

def transform_image(image_pil):

    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image, _ = transform(image_pil, None)  # 3, h, w
    return image

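Despite the name, T.RandomResize([800], max_size=1333) is deterministic here: with a single candidate size, every image is resized so its short side is 800 px (long side capped at 1333), then normalized with ImageNet statistics. A quick usage sketch against one of the bundled assets:

from PIL import Image
from app.utils.utils import transform_image

tensor = transform_image(Image.open("assets/mona_lisa/mona_lisa.png").convert("RGB"))
print(tensor.shape)  # torch.Size([3, H, W]), short side resized to 800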
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True):
    caption = caption.lower()
    caption = caption.strip()
    if not caption.endswith("."):
        caption = caption + "."

    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

    # filter output by box confidence
    logits_filt = logits.clone()
    boxes_filt = boxes.clone()
    filt_mask = logits_filt.max(dim=1)[0] > box_threshold
    logits_filt = logits_filt[filt_mask]  # num_filt, 256
    boxes_filt = boxes_filt[filt_mask]  # num_filt, 4

    # get phrase
    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)
    # build pred
    pred_phrases = []
    scores = []
    for logit, box in zip(logits_filt, boxes_filt):
        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
        if with_logits:
            pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
        else:
            pred_phrases.append(pred_phrase)
        scores.append(logit.max().item())

    return boxes_filt, torch.Tensor(scores), pred_phrases

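get_grounding_output returns boxes in GroundingDINO's normalized (cx, cy, w, h) convention; run_grounded_sam below rescales them to pixel (x0, y0, x1, y1) corners before handing them to SAM. The same conversion as a standalone sketch (hypothetical helper name):

import torch

def cxcywh_to_xyxy_pixels(boxes: torch.Tensor, W: int, H: int) -> torch.Tensor:
    boxes = boxes * torch.tensor([W, H, W, H], dtype=boxes.dtype)
    xy0 = boxes[:, :2] - boxes[:, 2:] / 2  # top-left corner
    xy1 = xy0 + boxes[:, 2:]               # bottom-right corner
    return torch.cat([xy0, xy1], dim=1)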
def run_grounded_sam(input_image,
                     text_prompt,
                     task_type,
                     box_threshold,
                     text_threshold,
                     iou_threshold,
                     scribble_mode,
                     sam,
                     groundingdino_model,
                     sam_predictor=None,
                     sam_automask_generator=None,
                     device="cuda"):

    global blip_processor, blip_model, inpaint_pipeline

    # load image
    image = input_image["image"]
    scribble = input_image["mask"]
    size = image.size  # w, h

    if sam_predictor is None:
        sam_predictor = SamPredictor(sam)
        sam_automask_generator = SamAutomaticMaskGenerator(sam)

    image_pil = image.convert("RGB")
    image = np.array(image_pil)

    if task_type == 'scribble':
        sam_predictor.set_image(image)
        scribble = scribble.convert("RGB")
        scribble = np.array(scribble)
        scribble = scribble.transpose(2, 1, 0)[0]

        # label the connected components of the scribble
        labeled_array, num_features = ndimage.label(scribble >= 255)

        # compute the centroid of each connected component
        centers = ndimage.center_of_mass(scribble, labeled_array, range(1, num_features + 1))
        centers = np.array(centers)

        point_coords = torch.from_numpy(centers)
        point_coords = sam_predictor.transform.apply_coords_torch(point_coords, image.shape[:2])
        point_coords = point_coords.unsqueeze(0).to(device)
        point_labels = torch.from_numpy(np.array([1] * len(centers))).unsqueeze(0).to(device)
        if scribble_mode == 'split':
            point_coords = point_coords.permute(1, 0, 2)
            point_labels = point_labels.permute(1, 0)
        masks, _, _ = sam_predictor.predict_torch(
            point_coords=point_coords if len(point_coords) > 0 else None,
            point_labels=point_labels if len(point_coords) > 0 else None,
            mask_input=None,
            boxes=None,
            multimask_output=False,
        )
        return masks
    elif task_type == 'automask':
        masks = sam_automask_generator.generate(image)
        return masks
    else:
        transformed_image = transform_image(image_pil)

        if task_type == 'automatic':
            # generate caption and tags
            # Tag2Text (https://huggingface.co/spaces/xinyu1205/Tag2Text) can generate
            # better captions, but it still has some bugs
            blip_processor = blip_processor or BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
            blip_model = blip_model or BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(device)
            text_prompt = generate_caption(blip_processor, blip_model, image_pil, device)
            print(f"Caption: {text_prompt}")

        # run the GroundingDINO model
        boxes_filt, scores, pred_phrases = get_grounding_output(
            groundingdino_model, transformed_image, text_prompt, box_threshold, text_threshold
        )

        # rescale normalized (cx, cy, w, h) boxes to pixel (x0, y0, x1, y1)
        H, W = size[1], size[0]
        for i in range(boxes_filt.size(0)):
            boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
            boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
            boxes_filt[i][2:] += boxes_filt[i][:2]

        boxes_filt = boxes_filt.cpu()

        if task_type in ('seg', 'inpainting', 'automatic'):
            sam_predictor.set_image(image)

            if task_type == 'automatic':
                # use NMS to handle overlapped boxes
                print(f"Before NMS: {boxes_filt.shape[0]} boxes")
                nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
                boxes_filt = boxes_filt[nms_idx]
                pred_phrases = [pred_phrases[idx] for idx in nms_idx]
                print(f"After NMS: {boxes_filt.shape[0]} boxes")
                print(f"Revise caption with number: {text_prompt}")

            transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)

            masks, _, _ = sam_predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_boxes,
                multimask_output=False,
            )
            return masks
        else:
            print(f"task_type: {task_type} error!")
    	
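A usage sketch for the text-prompted segmentation path (the checkpoint paths are assumptions; the ViT-H SAM and Swin-T GroundingDINO weights are the usual choices):

from PIL import Image
from segment_anything import build_sam

from app.utils.utils import load_grounding_dino_model, run_grounded_sam

sam = build_sam(checkpoint="checkpoints/sam_vit_h_4b8939.pth").to("cuda")  # assumed path
groundingdino_model = load_grounding_dino_model(
    "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",       # assumed path
    "checkpoints/groundingdino_swint_ogc.pth",                             # assumed path
    device="cuda")

masks = run_grounded_sam(
    input_image={"image": Image.open("assets/hedgehog_rm_fg/hedgehog.png"), "mask": None},
    text_prompt="hedgehog",
    task_type="seg",
    box_threshold=0.3,
    text_threshold=0.25,
    iou_threshold=0.5,
    scribble_mode="split",
    sam=sam,
    groundingdino_model=groundingdino_model,
    device="cuda")  # masks: (num_boxes, 1, H, W) boolean tensor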
assets/hedgehog_rm_fg/hedgehog.png ADDED (binary, stored with Git LFS)
assets/hedgehog_rm_fg/image_edit_82314e18-c64c-4003-9ef9-52cebf254b2f_2.png ADDED (binary, stored with Git LFS)
assets/hedgehog_rm_fg/mask_82314e18-c64c-4003-9ef9-52cebf254b2f.png ADDED (binary, stored with Git LFS)
assets/hedgehog_rm_fg/masked_image_82314e18-c64c-4003-9ef9-52cebf254b2f.png ADDED (binary, stored with Git LFS)

assets/hedgehog_rm_fg/prompt.txt ADDED
@@ -0,0 +1 @@
648464818: remove the hedgehog.

assets/hedgehog_rp_bg/hedgehog.png ADDED (binary, stored with Git LFS)
assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png ADDED (binary, stored with Git LFS)
assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png ADDED (binary, stored with Git LFS)
assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png ADDED (binary, stored with Git LFS)

assets/hedgehog_rp_bg/prompt.txt ADDED
@@ -0,0 +1 @@
648464818: make the hedgehog in Italy.

assets/hedgehog_rp_fg/hedgehog.png ADDED (binary, stored with Git LFS)
assets/hedgehog_rp_fg/image_edit_5cab3448-5a3a-459c-9144-35cca3d34273_0.png ADDED (binary, stored with Git LFS)
assets/hedgehog_rp_fg/mask_5cab3448-5a3a-459c-9144-35cca3d34273.png ADDED (binary, stored with Git LFS)
assets/hedgehog_rp_fg/masked_image_5cab3448-5a3a-459c-9144-35cca3d34273.png ADDED (binary, stored with Git LFS)

assets/hedgehog_rp_fg/prompt.txt ADDED
@@ -0,0 +1 @@
648464818: replace the hedgehog to flamingo.

assets/mona_lisa/image_edit_aae09152-4495-4332-b691-a0c7bff524be_2.png ADDED (binary, stored with Git LFS)
assets/mona_lisa/mask_aae09152-4495-4332-b691-a0c7bff524be.png ADDED (binary, stored with Git LFS)
assets/mona_lisa/masked_image_aae09152-4495-4332-b691-a0c7bff524be.png ADDED (binary, stored with Git LFS)
assets/mona_lisa/mona_lisa.png ADDED (binary, stored with Git LFS)

assets/mona_lisa/prompt.txt ADDED
@@ -0,0 +1 @@
648464818: add a shining necklace.

assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png ADDED (binary, stored with Git LFS)
assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png ADDED (binary, stored with Git LFS)
assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png ADDED (binary, stored with Git LFS)

assets/sunflower_girl/prompt.txt ADDED
@@ -0,0 +1 @@
648464818: add a wreath on head..

assets/sunflower_girl/sunflower_girl.png ADDED (binary, stored with Git LFS)
requirements.txt ADDED
@@ -0,0 +1,20 @@
torch
torchvision
torchaudio
transformers>=4.25.1
gradio==4.38.1
ftfy
tensorboard
datasets
Pillow==9.5.0
opencv-python
imgaug
accelerate==0.20.3
image-reward
hpsv2
torchmetrics
open-clip-torch
clip
segment_anything
git+https://github.com/liyaowei-stu/BrushEdit.git
git+https://github.com/IDEA-Research/Grounded-Segment-Anything.git#subdirectory=GroundingDINO