Update app.py
app.py CHANGED
@@ -25,8 +25,7 @@ import torch
 import transformers
 from PIL import Image
 from transformers import AutoModel, AutoTokenizer
-import
-from ace_inference import ACEInference
+from ace_flux_inference import FluxACEInference
 from scepter.modules.utils.config import Config
 from scepter.modules.utils.directory import get_md5
 from scepter.modules.utils.file_system import FS
@@ -49,6 +48,9 @@ chat_sty = '\U0001F4AC' # 💬
 video_sty = '\U0001f3a5' # 🎥
 
 lock = threading.Lock()
+inference_dict = {
+    "ACE_FLUX": FluxACEInference,
+}
 
 
 class ChatBotUI(object):
@@ -94,9 +96,10 @@ class ChatBotUI(object):
         assert len(self.model_choices) > 0
         if self.default_model_name == "": self.default_model_name = list(self.model_choices.keys())[0]
         self.model_name = self.default_model_name
-
-
-
+        pipe_cfg = self.model_choices[self.default_model_name]
+        infer_name = pipe_cfg.get("INFERENCE_TYPE", "ACE")
+        self.pipe = inference_dict[infer_name]()
+        self.pipe.init_from_cfg(pipe_cfg)
         self.max_msgs = 20
         self.enable_i2v = cfg.get('ENABLE_I2V', False)
         self.gradio_version = version('gradio')
@@ -540,8 +543,11 @@ class ChatBotUI(object):
         lock.acquire()
         del self.pipe
         torch.cuda.empty_cache()
-
-        self.
+        torch.cuda.ipc_collect()
+        pipe_cfg = self.model_choices[model_name]
+        infer_name = pipe_cfg.get("INFERENCE_TYPE", "ACE")
+        self.pipe = inference_dict[infer_name]()
+        self.pipe.init_from_cfg(pipe_cfg)
         self.model_name = model_name
         lock.release()
 
@@ -829,7 +835,8 @@ class ChatBotUI(object):
         edit_image = None
         edit_image_mask = None
         edit_task = ''
-
+        if new_message == "":
+            new_message = "a beautiful girl wear a skirt."
         print(new_message)
         imgs = self.pipe(
             image=edit_image,
@@ -896,9 +903,9 @@ class ChatBotUI(object):
         }
 
         buffered = io.BytesIO()
-        img.convert('RGB').save(buffered, format='
+        img.convert('RGB').save(buffered, format='JPEG')
         img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
-        img_str = f'<img src="data:image/
+        img_str = f'<img src="data:image/jpg;base64,{img_b64}" style="pointer-events: none;">'
 
         history.append(
             (message,
@@ -1048,17 +1055,17 @@ class ChatBotUI(object):
 
         img = imgs[0]
         buffered = io.BytesIO()
-        img.convert('RGB').save(buffered, format='
+        img.convert('RGB').save(buffered, format='JPEG')
         img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
         img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
         history = [(prompt,
                     f'{pre_info} The generated image is:\n {img_str}')]
 
         img_id = get_md5(img_b64)[:12]
-        save_path = os.path.join(self.cache_dir, f'{img_id}.
+        save_path = os.path.join(self.cache_dir, f'{img_id}.jpg')
         img.convert('RGB').save(save_path)
 
-        return self.get_history(history), gr.update(value=
+        return self.get_history(history), gr.update(value=prompt), gr.update(
             visible=False), gr.update(value=save_path), gr.update(value=-1)
 
         with self.eg:
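
For context, the core of this commit is a small registry-dispatch pattern: inference backends are registered in `inference_dict`, keyed by the `INFERENCE_TYPE` field of the selected model config, then instantiated and initialized with `init_from_cfg`. The following is a minimal, self-contained sketch of that pattern under stated assumptions; `DummyACEInference`, `build_pipe`, and the plain-dict config are illustrative stand-ins for `FluxACEInference` and the scepter `Config` object that app.py actually uses.

# Minimal sketch of the registry-dispatch pattern introduced by this commit.
# DummyACEInference and the plain-dict config are hypothetical stand-ins;
# the real app registers FluxACEInference and passes a scepter Config.

class DummyACEInference:
    """Stand-in backend exposing the same init_from_cfg/__call__ shape."""

    def init_from_cfg(self, cfg):
        # Store the config; a real backend would load model weights here.
        self.cfg = cfg

    def __call__(self, **kwargs):
        return f"would run inference with {self.cfg.get('MODEL', '?')}"


# Registry: map the config's INFERENCE_TYPE string to a backend class.
inference_dict = {
    "ACE": DummyACEInference,
    "ACE_FLUX": DummyACEInference,
}


def build_pipe(pipe_cfg):
    # Same lookup as in ChatBotUI.__init__ and the model-switch callback:
    # fall back to "ACE" when the config does not declare an INFERENCE_TYPE.
    infer_name = pipe_cfg.get("INFERENCE_TYPE", "ACE")
    pipe = inference_dict[infer_name]()
    pipe.init_from_cfg(pipe_cfg)
    return pipe


if __name__ == "__main__":
    cfg = {"INFERENCE_TYPE": "ACE_FLUX", "MODEL": "<path-to-flux-ace-model>"}
    pipe = build_pipe(cfg)
    print(pipe())

Because the default `"ACE"` key is not registered in app.py after this change, any model config that omits `INFERENCE_TYPE` (or sets it to a value other than `"ACE_FLUX"`) would raise a `KeyError` at lookup time; the sketch above registers both keys only to keep the example runnable.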