Junjie96 committed
Commit 9c18e52 · verified · 1 Parent(s): 303d3b2

Upload 46 files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/storyboard_en.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
+ **/__pycache__
+ workdir
+ **/.vscode
+ *.sh
+ .idea/
+ .DS_Store
README.md CHANGED
@@ -11,4 +11,4 @@ license: apache-2.0
  short_description: Synthesize images from text prompts and visual references
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,387 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ # @Time : 2025-03-12
+ # @Author : Junjie He
+ import os
+ import time
+ import uuid
+
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+
+ from src.agent import extend_prompt, translate_prompt
+ from src.anystory import call_anystory
+ from src.matting import ImageUniversalMatting
+ from src.util import upload_pil_2_oss
+
+ if not os.path.exists("models/tf_matting.pb"):
+     os.makedirs("models", exist_ok=True)
+     os.system(f"wget -O models/tf_matting.pb {os.getenv('MATTING_PATH')}")
+ universal_matting = ImageUniversalMatting("models/tf_matting.pb")
+
+
+ def image_matting(pil_image):
+     if pil_image.mode == "RGBA":
+         mask = np.array(pil_image)[..., -1] > 200
+         if np.all(mask):
+             # a fully opaque alpha channel carries no mask, so run the matting model
+             mask = ((universal_matting(pil_image)[..., -1] > 200) * 255).astype(np.uint8)
+         else:
+             mask = ((np.array(pil_image)[..., -1] > 200) * 255).astype(np.uint8)
+     else:
+         mask = ((universal_matting(pil_image.convert("RGB"))[..., -1] > 200) * 255).astype(np.uint8)
+     pil_mask = Image.fromarray(mask)
+
+     np_image = np.array(pil_image.convert("RGB"))
+     np_mask = np.array(pil_mask)[..., None] / 255.
+     pil_masked_image = Image.fromarray((np_mask * np_image + (1 - np_mask) * 255.).astype(np.uint8))
+     return pil_masked_image, pil_mask
+
+
+ def process(
+         pil_subject_A_image=None,
+         pil_subject_A_mask=None,
+         pil_subject_B_image=None,
+         pil_subject_B_mask=None,
+         prompt="",
+ ):
+     request_id = time.strftime('%Y%m%d-', time.localtime(time.time())) + str(uuid.uuid4())
+
+     if prompt == "":
+         raise gr.Error("Please enter your prompt")
+
+     if pil_subject_A_image is None and pil_subject_B_image is None:
+         raise gr.Error("Please upload your reference image(s)")
+
+     image_urls = []
+     if pil_subject_A_image is not None:
+         if pil_subject_A_mask is not None:
+             if pil_subject_A_mask.size != pil_subject_A_image.size:
+                 raise gr.Error("Subject [A] image & mask size mismatch")
+             pil_subject_A_image = pil_subject_A_image.convert("RGB")
+             pil_subject_A_mask = pil_subject_A_mask.convert("L")
+             pil_subject_A_image = Image.merge("RGBA", (*pil_subject_A_image.split(), pil_subject_A_mask))
+         image_urls.append(upload_pil_2_oss(pil_subject_A_image, name=request_id + "_A.png"))
+     if pil_subject_B_image is not None:
+         if pil_subject_B_mask is not None:
+             if pil_subject_B_mask.size != pil_subject_B_image.size:
+                 raise gr.Error("Subject [B] image & mask size mismatch")
+             pil_subject_B_image = pil_subject_B_image.convert("RGB")
+             pil_subject_B_mask = pil_subject_B_mask.convert("L")
+             pil_subject_B_image = Image.merge("RGBA", (*pil_subject_B_image.split(), pil_subject_B_mask))
+         image_urls.append(upload_pil_2_oss(pil_subject_B_image, name=request_id + "_B.png"))
+
+     res = call_anystory(image_urls, prompt)[0]
+
+     return res
+
+
+ def interface():
+     with gr.Row(variant="panel"):
+         gr.HTML(description + "<br>" + tips)
+     with gr.Row(variant="panel"):
+         with gr.Column(scale=2, min_width=100):
+             with gr.Row(equal_height=False):
+                 with gr.Column(scale=1, min_width=100):
+                     with gr.Tab(label="Subject [A]"):
+                         with gr.Group():
+                             with gr.Row(equal_height=True):
+                                 with gr.Column(min_width=100):
+                                     pil_subject_A_image = gr.Image(type="pil", label="Subject [A] Reference Image",
+                                                                    format="png", show_label=True, image_mode="RGBA")
+                                 with gr.Column(min_width=100):
+                                     with gr.Group():
+                                         pil_subject_A_mask = gr.Image(type="pil",
+                                                                       label="Subject [A] Mask (upload supported)",
+                                                                       format="png", show_label=True, image_mode="L")
+                                         seg_subject_A = gr.Button(value="Segment Subject")
+
+                 with gr.Column(scale=1, min_width=100):
+                     with gr.Tab(label="Subject [B]"):
+                         with gr.Group():
+                             with gr.Row(equal_height=True):
+                                 with gr.Column(min_width=100):
+                                     pil_subject_B_image = gr.Image(type="pil", label="Subject [B] Reference Image",
+                                                                    format="png", show_label=True, image_mode="RGBA")
+                                 with gr.Column(min_width=100):
+                                     with gr.Group():
+                                         pil_subject_B_mask = gr.Image(type="pil",
+                                                                       label="Subject [B] Mask (upload supported)",
+                                                                       format="png", show_label=True, image_mode="L")
+                                         seg_subject_B = gr.Button(value="Segment Subject")
+
+             with gr.Group():
+                 prompt = gr.Textbox(value="", label='Prompt', lines=6, show_label=True)
+                 en_prompt = gr.Textbox(value="", label='prompt', lines=6, show_label=True, visible=False)
+                 # prompt_extend_button = gr.Button(value="Extend Prompt")
+
+         with gr.Column(scale=1, min_width=100):
+             result_gallery = gr.Image(type="pil", label="Generated Image", visible=True, height=450)
+             # result_gallery = gr.Gallery(label='Generated Image', show_label=True, elem_id="gallery", preview=True,
+             #                             format="png", height=450)
+             run_button = gr.Button(value="🧑‍🎨 RUN")
+
+     generated_information = gr.Markdown(label="Generation Details", value="", visible=False)
+
+     seg_subject_A.click(
+         fn=set_image_seg_unfinished, outputs=generated_information
+     ).then(
+         fn=image_matting, inputs=[pil_subject_A_image], outputs=[pil_subject_A_image, pil_subject_A_mask]
+     ).then(
+         fn=set_image_seg_finished, outputs=generated_information
+     )
+
+     seg_subject_B.click(
+         fn=set_image_seg_unfinished, outputs=generated_information
+     ).then(
+         fn=image_matting, inputs=[pil_subject_B_image], outputs=[pil_subject_B_image, pil_subject_B_mask]
+     ).then(
+         fn=set_image_seg_finished, outputs=generated_information
+     )
+
+     # prompt_extend_button.click(
+     #     fn=set_prompt_extend_unfinished, outputs=generated_information
+     # ).then(
+     #     fn=extend_prompt, inputs=[prompt], outputs=[prompt]
+     # ).then(
+     #     fn=set_prompt_extend_finished, outputs=generated_information
+     # )
+
+     run_button.click(
+         fn=set_prompt_translate_unfinished, outputs=generated_information
+     ).then(
+         fn=translate_prompt, inputs=[prompt], outputs=[en_prompt]
+     ).then(
+         fn=set_image_generate_unfinished, outputs=generated_information
+     ).then(
+         fn=process,
+         inputs=[pil_subject_A_image, pil_subject_A_mask, pil_subject_B_image, pil_subject_B_mask, en_prompt],
+         outputs=[result_gallery]
+     ).then(
+         fn=set_image_generate_finished, outputs=generated_information
+     )
+
+     with gr.Row():
+         examples = [
+             [
+                 "assets/examples/1.webp",
+                 "assets/examples/1_mask.webp",
+                 None,
+                 None,
+                 "Cartoon style. A sheep is riding through the city, holding a wooden sign that says \"TongYi\".",
+                 "assets/examples/1_output.webp",
+             ],
+             [
+                 "assets/examples/2.webp",
+                 "assets/examples/2_mask.webp",
+                 None,
+                 None,
+                 "Cartoon style. Sun Wukong stands on a tank, holding up an ancient wooden sign high in the air. The sign reads 'AnyStory'. The background is a cyberpunk-style city sky filled with towering buildings.",
+                 "assets/examples/2_output.webp",
+             ],
+             [
+                 "assets/examples/3.webp",
+                 "assets/examples/3_mask.webp",
+                 None,
+                 None,
+                 "A modern and stylish Nezha playing an electric guitar, dynamic pose, vibrant colors, fantasy atmosphere, mythical Chinese character with a rock-and-roll twist, red scarf flowing in the wind, traditional elements mixed with contemporary design, cinematic lighting, 4k resolution",
+                 "assets/examples/3_output.webp",
+             ],
+             [
+                 "assets/examples/4.webp",
+                 "assets/examples/4_mask.webp",
+                 None,
+                 None,
+                 "a man riding a bike on the road",
+                 "assets/examples/4_output.webp",
+             ],
+             [
+                 "assets/examples/7.webp",
+                 "assets/examples/7_mask.webp",
+                 None,
+                 None,
+                 "Nezha is surrounded by a mysterious purple glow, with a pair of eyes glowing with an eerie red light. Broken talismans and debris float around him, highlighting his demonic nature and authority.",
+                 "assets/examples/7_output.webp",
+             ],
+             [
+                 "assets/examples/8.webp",
+                 "assets/examples/8_mask.webp",
+                 None,
+                 None,
+                 "The car is driving through a cyberpunk city at night in the middle of a heavy downpour.",
+                 "assets/examples/8_output.webp",
+             ],
+             [
+                 "assets/examples/9.webp",
+                 "assets/examples/9_mask.webp",
+                 None,
+                 None,
+                 "This cosmetic is placed on a table covered with roses.",
+                 "assets/examples/9_output.webp",
+             ],
+             [
+                 "assets/examples/10.webp",
+                 "assets/examples/10_mask.webp",
+                 None,
+                 None,
+                 "A little boy model is posing for a photo.",
+                 "assets/examples/10_output.webp",
+             ],
+             # [
+             #     "assets/examples/5_1.webp",
+             #     "assets/examples/5_1_mask.webp",
+             #     "assets/examples/5_2.webp",
+             #     "assets/examples/5_2_mask.webp",
+             #     # translated from Chinese:
+             #     "Two children ride a cool tandem electric bike, weaving through a bustling food market. Around them are dazzling vegetable stalls, fruit baskets, and busy vendors. The children look focused yet playful, with a few carrots and a bunch of greens in the bike basket; traditional and modern elements blend perfectly in the lively street atmosphere.",
+             #     "assets/examples/5_output.webp",
+             # ],
+             [
+                 "assets/examples/6_1.webp",
+                 "assets/examples/6_1_mask.webp",
+                 "assets/examples/6_2.webp",
+                 "assets/examples/6_2_mask.webp",
+                 "Two men are sitting by a wooden table, which is laden with delicious food and a pot of wine. One of the men holds a wine glass, drinking heartily with a bold expression; the other smiles as he pours wine for his companion, both of them engaged in cheerful conversation. In the background is an ancient pavilion surrounded by emerald bamboo groves, with sunlight filtering through the leaves to cast dappled shadows.",
+                 "assets/examples/6_output.webp",
+             ],
+         ]
+         gr.Examples(
+             label="Examples",
+             examples=examples,
+             inputs=[pil_subject_A_image, pil_subject_A_mask, pil_subject_B_image, pil_subject_B_mask, prompt, result_gallery],
+         )
+
+
+ def set_image_seg_unfinished():
+     return gr.update(
+         visible=True,
+         value="<h3>(In progress) Extracting subject mask...</h3>",
+     )
+
+
+ def set_image_seg_finished():
+     return gr.update(visible=True, value="<h3>Subject mask ready!</h3>")
+
+
+ def set_prompt_extend_unfinished():
+     return gr.update(
+         visible=True,
+         value="<h3>(In progress) Rewriting your prompt... ✍️</h3>",
+     )
+
+
+ def set_prompt_extend_finished():
+     return gr.update(visible=True, value="<h3>Prompt expanded successfully!</h3>")
+
+
+ def set_prompt_translate_unfinished():
+     return gr.update(
+         visible=True,
+         value="<h3>(In progress) Preprocessing...</h3>",
+     )
+
+
+ def set_image_generate_unfinished():
+     return gr.update(
+         visible=True,
+         value="<h3>(In progress) Generating images...</h3>",
+     )
+
+
+ def set_image_generate_finished():
+     return gr.update(visible=True, value="<h3>Image generation completed!</h3>")
+
+
+ if __name__ == "__main__":
+     title = r"""
+ <div style="text-align: center;">
+     <h1> AnyStory: Towards Unified Single and Multiple Subject Personalization in Text-to-Image Generation </h1>
+     <h1> V2.0.0 </h1>
+     <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
+         <a href="https://arxiv.org/pdf/2501.09503"><img src="https://img.shields.io/badge/arXiv-2501.09503-red"></a>
+         &nbsp;
+         <a href='https://aigcdesigngroup.github.io/AnyStory/'><img src='https://img.shields.io/badge/Project_Page-AnyStory-green' alt='Project Page'></a>
+         &nbsp;
+         <a href='https://modelscope.cn/studios/damo/studio_anytext'><img src='https://img.shields.io/badge/Demo-ModelScope-blue'></a>
+     </div>
+     <br>
+ </div>
+ """
+
+     title_description = r"""
+ Official demo of <b>AnyStory 2</b> 🤗. We will continuously update this demo.
+ For technical details, please refer to our tech report: <a href='https://arxiv.org/pdf/2501.09503' target='_blank'><b>AnyStory: Towards Unified Single and Multiple Subject Personalization in Text-to-Image Generation</b></a>. 😊
+ """
+
+     description = r"""🚀🚀🚀 Quick Start:<br>
+ 1. Upload subject reference images (clean backgrounds preferred; real human IDs are not supported yet), add a prompt (Chinese and English are both supported), and click "<b>RUN</b>".<br>
+ 2. (Recommended) Click "<b>Segment Subject</b>" to create a mask for each subject (or upload your own black-and-white masks). This helps the model reference exactly the subject you specify; otherwise, subjects are detected automatically. 🤗<br>
+ """
+
+     tips = r"""💡💡💡 Tips:<br>
+ If the subject doesn't appear, try adding a detailed description of the subject that matches the reference image to the prompt, and avoid conflicting details (e.g., significantly altering the subject's appearance). Multi-subject referencing in AnyStory2 is still being optimized. 🤗<br>
+ """
+
+     citation = r"""
+ ---
+ 📝 **Citation**
+ <br>
+ If our work is helpful for your research or applications, please cite us via:
+ ```bibtex
+ @article{he2025anystory,
+     title={AnyStory: Towards Unified Single and Multiple Subject Personalization in Text-to-Image Generation},
+     author={He, Junjie and Tuo, Yuxiang and Chen, Binghui and Zhong, Chongyang and Geng, Yifeng and Bo, Liefeng},
+     journal={arXiv preprint arXiv:2501.09503},
+     year={2025}
+ }
+ ```
+ If you have any questions, feel free to open an issue or contact us directly at <b>[email protected]</b>.
+ """
+
+     js = """
+     function createGradioAnimation() {
+         var container = document.createElement('div');
+         container.id = 'gradio-animation';
+         container.style.fontSize = '2em';
+         container.style.fontWeight = 'bold';
+         container.style.textAlign = 'center';
+         container.style.marginBottom = '20px';
+
+         var text = 'Welcome to AnyStory!';
+         for (var i = 0; i < text.length; i++) {
+             (function(i){
+                 setTimeout(function(){
+                     var letter = document.createElement('span');
+                     letter.style.opacity = '0';
+                     letter.style.transition = 'opacity 0.5s';
+                     letter.innerText = text[i];
+
+                     container.appendChild(letter);
+
+                     setTimeout(function() {
+                         letter.style.opacity = '1';
+                     }, 50);
+                 }, i * 250);
+             })(i);
+         }
+
+         var gradioContainer = document.querySelector('.gradio-container');
+         gradioContainer.insertBefore(container, gradioContainer.firstChild);
+
+         return 'Animation created';
+     }
+     """
+     block = gr.Blocks(title="AnyStory2", js=js, theme=gr.themes.Ocean()).queue()
+     with block:
+         gr.HTML(title)
+         gr.HTML(title_description)
+
+         interface()
+
+         gr.HTML("<br>More examples: intelligent creation of AI story pictures, integrated with Qwen Agent")
+         gr.Gallery(value=["assets/storyboard_en.png"], columns=1, object_fit="contain", show_label=False)
+
+         gr.Markdown(citation)
+
+     block.launch(share=True, max_threads=10)
+     # block.launch(server_name='0.0.0.0', share=False, server_port=9999, max_threads=3)
+     # block.launch(server_name='127.0.0.1', share=False, server_port=9999, allowed_paths=["/"])
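
The app reads all of its service configuration from environment variables (Space secrets). As a quick pre-launch sanity check, a minimal sketch along these lines can confirm they are set; the variable names below are exactly the ones read via `os.getenv()` in `app.py` and `src/`:

```python
# A minimal pre-launch check (a sketch, not part of the app): every variable
# listed here is read via os.getenv() somewhere in app.py or src/.
import os

required = [
    "MATTING_PATH",  # URL of the tf_matting.pb weights (app.py)
    "QWEN_DS_API_KEY",  # DashScope key for the Qwen assistants (src/agent.py)
    "ANYSTORY_URL", "ANYSTORY_DS_API_KEY", "ANYSTORY_MODEL",  # generation service (src/anystory.py)
    "OSS_AK", "OSS_SK", "OSS_BUCKET", "OSS_ENDPOINT", "OSS_PATH",  # OSS upload (src/util.py)
]
missing = [v for v in required if not os.getenv(v)]
if missing:
    raise SystemExit(f"missing environment variables: {missing}")
```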
assets/examples/1.webp ADDED
assets/examples/10.webp ADDED
assets/examples/10_mask.webp ADDED
assets/examples/10_output.webp ADDED
assets/examples/1_mask.webp ADDED
assets/examples/1_output.webp ADDED
assets/examples/2.webp ADDED
assets/examples/2_mask.webp ADDED
assets/examples/2_output.webp ADDED
assets/examples/3.webp ADDED
assets/examples/3_mask.webp ADDED
assets/examples/3_output.webp ADDED
assets/examples/4.webp ADDED
assets/examples/4_mask.webp ADDED
assets/examples/4_output.webp ADDED
assets/examples/5_1.webp ADDED
assets/examples/5_1_mask.webp ADDED
assets/examples/5_2.webp ADDED
assets/examples/5_2_mask.webp ADDED
assets/examples/5_output.webp ADDED
assets/examples/6_1.webp ADDED
assets/examples/6_1_mask.webp ADDED
assets/examples/6_2.webp ADDED
assets/examples/6_2_mask.webp ADDED
assets/examples/6_output.webp ADDED
assets/examples/7.webp ADDED
assets/examples/7_mask.webp ADDED
assets/examples/7_output.webp ADDED
assets/examples/8.webp ADDED
assets/examples/8_mask.webp ADDED
assets/examples/8_output.webp ADDED
assets/examples/9.webp ADDED
assets/examples/9_mask.webp ADDED
assets/examples/9_output.webp ADDED
assets/storyboard_en.png ADDED

Git LFS Details

  • SHA256: 9da86495acdcdc791308b44f71fa4999d6925236416bc43fc86fb8bd5d1e556c
  • Pointer size: 132 Bytes
  • Size of remote file: 5.02 MB
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ oss2
+ phidata
+ dashscope
+ tensorflow==2.15
src/__init__.py ADDED
File without changes
src/agent.py ADDED
@@ -0,0 +1,111 @@
+ import os
+ import time
+
+ import gradio as gr
+ from phi.assistant import Assistant
+ from phi.llm.openai import OpenAIChat
+ from pydantic import BaseModel
+
+ from .log import logger
+
+ qwen_model = 'qwen-plus'
+ qwen_api_key = os.getenv("QWEN_DS_API_KEY")
+ qwen_url = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
+
+
+ class TranslateFormat(BaseModel):
+     en: str
+
+
+ class PromptExtendFormat(BaseModel):
+     prompt: str
+
+
+ # Translator assistant. The system prompt (in Chinese) reads: "You are a professional
+ # English translator; translate the user's Chinese input into English."
+ trans_assist = Assistant(
+     llm=OpenAIChat(model=qwen_model, max_tokens=3000, temperature=0.3, api_key=qwen_api_key, base_url=qwen_url),
+     description='你是专业的英语翻译,请将用户输入的一句中文翻译成英文',
+     output_model=TranslateFormat
+ )
+
+ # Prompt-rewriting guidelines, kept in Chinese because the assistant is prompted in
+ # Chinese. In brief: (1) enrich overly short inputs with concrete detail without
+ # changing the intent or the subject's appearance; (2) for portraits, fill in
+ # appearance, expression, ethnicity, pose, framing, clothing, lighting, and texture;
+ # (3) add a precise style description matching the user's intent, defaulting to
+ # documentary photography; (4) answer in the input language and keep quoted text
+ # verbatim; (5) keep the rewritten prompt under 200 characters; (6) worked examples.
+ requirements = [
+     "1.对于过于简短的用户输入,在不改变原意及主体外观的前提下,合理补充具象化细节,避免使用过于抽象描述。",
+     "2.如果涉及人像,完善与用户输入相符的人物外貌、表情、种族、姿态、景别、穿着、影调、质感等方面描述。",
+     "3.匹配符合用户意图且精准详细的风格描述。如果用户未指定,则根据画面选择最恰当的风格,或使用纪实摄影照片。",
+     "4.若输入为中文,则整体中文输出;若输入为全英文,则整体英文输出;保留引号中原文以及重要的输入信息,不要改写。",
+     "5.在语义完整前提下,改写后prompt字数小于200字,避免prompt冗长。",
+     """
+     6.改写后 prompt 示例:
+     (1) 日系小清新写真照片,扎着双麻花辫的波西米亚小女孩坐在船边。女孩穿着白色方领泡泡袖连衣裙,裙子上有褶皱和纽扣装饰。她皮肤白皙,五官精致,眼泪汪汪直视镜头。她双手扶船,长发刘海遮住部分额头。背景是清澈明亮的户外场景,可见蓝天、山峦和一些干枯植物。高清写实摄影,近景中心对称构图。\n
+     (2) 二次元厚涂动漫插画,一个猫耳东亚萌妹手持文件夹,怒气冲冲走向电脑。她深紫色爆炸头,红色眼睛,头顶有一个粉色光圈。少女身穿深灰色短裙和浅灰色上衣,腰间系着白色系带,胸前佩戴名牌,居中写着黑体中文"紫阳"。室内背景中摆放了很多办公桌。粗线条的日系赛璐璐风格。近景半身略仰视视角。\n
+     (3) 美剧艺术海报风格,身穿黄色防护服的老年Walter White坐在金属折叠椅上,头顶无衬线英文写着"Breaking Bad",周围是成堆的美元和蓝色塑料储物箱。他戴着眼镜目光直视镜头,左手拿着一支雪茄,右手放在膝盖上。背景是废弃阴暗的厂房,阳光透过窗户照射进来。画面带有明显颗粒质感纹理。长焦人物平视特写。
+     (4) CG game concept digital art featuring three giant mutant crocodiles with wide-open mouths, revealing pink tongues and sharp teeth. Their rough skin resembles grayish-white stone. On the back of the crocodile to the left, lush trees, shrubs, and some thorn-like protrusions grow. The background of the scene shows a dusk sky and a shimmering pond. The overall atmosphere is dark and cold, with a tilted composition that creates a strong sense of depth and layers. \n
+     """
+ ]
+
+ # Prompt-rewriter assistant. The system prompt (in Chinese) reads: "You are a prompt
+ # optimizer; rewrite the user's input into a high-quality prompt that yields better
+ # images without changing the original meaning."
+ prompt_writer = Assistant(
+     llm=OpenAIChat(model=qwen_model, max_tokens=8000, temperature=0.3, api_key=qwen_api_key, base_url=qwen_url),
+     description="你是一位prompt优化师,旨在将用户输入改写为优质prompt,在不影响原意前提下更高质量生图。",
+     instructions=requirements,
+     debug_mode=False,
+     output_model=PromptExtendFormat,
+ )
+
+
+ def contains_chinese(text):
+     """
+     Check whether the text contains Chinese characters (based on Unicode ranges).
+     """
+     # Chinese Unicode ranges (basic CJK block plus extension blocks)
+     cjk_ranges = [
+         (0x4e00, 0x9fff),  # CJK Unified Ideographs (common characters)
+         (0x3400, 0x4dbf),  # CJK Extension A
+         (0x20000, 0x2a6df),  # CJK Extension B
+         (0x2a700, 0x2b73f),  # CJK Extension C
+         (0x2b740, 0x2b81f),  # CJK Extension D
+         (0x2b820, 0x2ceaf),  # CJK Extension E
+     ]
+
+     for c in text:
+         code = ord(c)
+         for start, end in cjk_ranges:
+             if start <= code <= end:
+                 return True
+     return False
+
+
+ def translate_prompt(prompt, max_attempts=5):
+     attempts = 0
+     while attempts < max_attempts:
+         try:
+             if contains_chinese(prompt):
+                 res = trans_assist.run(prompt)
+                 logger.info(f"translate Chinese prompt into English: {prompt} -> {res.en}")
+                 return res.en
+             else:
+                 return prompt
+         except Exception as e:
+             print(f"Attempt {attempts + 1} failed: {e}")
+             attempts += 1
+             if attempts < max_attempts:
+                 time.sleep(1)  # back off before retrying
+     raise gr.Error(f"Prompt translation failed after {max_attempts} attempts.")
+
+
+ def extend_prompt(prompt, max_attempts=5):
+     attempts = 0
+     while attempts < max_attempts:
+         try:
+             res = prompt_writer.run(prompt, stream=False)
+             logger.info(f"extend prompt finished: {prompt} -> {res.prompt}.")
+             return res.prompt
+         except Exception as e:
+             print(f"Attempt {attempts + 1} failed: {e}")
+             attempts += 1
+             if attempts < max_attempts:
+                 time.sleep(1)  # back off before retrying
+     raise gr.Error(f"Prompt extension failed after {max_attempts} attempts.")
+
+
+ if __name__ == '__main__':
+     # test_trans()
+     pass
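
For reference, a minimal sketch of exercising these helpers outside the Gradio app (assumes `QWEN_DS_API_KEY` is set; the sample prompts are hypothetical):

```python
from src.agent import contains_chinese, translate_prompt

text = "一只小羊在城市里骑车"  # hypothetical Chinese prompt ("a lamb riding a bike in the city")
assert contains_chinese(text)
print(translate_prompt(text))  # Chinese input is translated to English via Qwen
print(translate_prompt("a lamb riding a bike"))  # English input is returned unchanged
```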
src/anystory.py ADDED
@@ -0,0 +1,72 @@
+ import json
+ import os
+ import time
+
+ import gradio as gr
+ import requests
+
+ from src.log import logger
+ from src.util import download_images
+
+ anystory_url = os.getenv("ANYSTORY_URL")
+ anystory_api_key = os.getenv("ANYSTORY_DS_API_KEY")
+ anystory_model = os.getenv("ANYSTORY_MODEL")
+
+
+ def call_anystory(image_urls, prompt):
+     headers = {
+         "Content-Type": "application/json",
+         "Accept": "application/json",
+         "Authorization": f"Bearer {anystory_api_key}",
+         "X-DashScope-Async": "enable",
+         "X-DashScope-DataInspection": "enable",
+     }
+     data = {
+         "model": anystory_model,
+         "input": {
+             "image_urls": image_urls,
+             "prompt": prompt
+         },
+         "parameters": {
+         },
+     }
+
+     res = requests.post(anystory_url, data=json.dumps(data), headers=headers)
+
+     response_code = res.status_code
+     if response_code == 200:
+         res = json.loads(res.content.decode())
+         task_id = res['output']['task_id']
+         logger.info(f"task_id: {task_id}: Create request success. Params: {data}")
+
+         # Poll the async task until it succeeds or fails
+         while True:
+             res = requests.post(f'https://poc-dashscope.aliyuncs.com/api/v1/tasks/{task_id}', headers=headers)
+             response_code = res.status_code
+             if response_code == 200:
+                 res = json.loads(res.content.decode())
+                 if res['output']['task_status'] == "SUCCEEDED":
+                     logger.info(f"task_id: {task_id}: Generation task query success.")
+                     results = res['output']['results']
+                     img_urls = [x['url'] for x in results]
+                     logger.info(f"task_id: {task_id}: {res}")
+                     break
+                 elif res['output']['task_status'] != "FAILED":
+                     logger.debug(f"task_id: {task_id}: query result...")
+                     time.sleep(1)
+                 else:
+                     raise gr.Error("Failed to get results from the generation task.")
+             else:
+                 logger.error(f'task_id: {task_id}: Failed to query task result: {res.content}')
+                 raise gr.Error("Failed to query task result.")
+
+         logger.info(f"task_id: {task_id}: download generated images.")
+         img_data = download_images(img_urls)
+         logger.info(f"task_id: {task_id}: Generate done.")
+     else:
+         logger.error(f'Failed to create generation task: {res.content}')
+         raise gr.Error("Failed to create generation task.")
+
+     return img_data
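
`call_anystory` blocks until the asynchronous task finishes and returns the generated images as PIL objects (via `download_images`). A usage sketch, assuming the `ANYSTORY_*` variables are configured and the reference URL is publicly readable (e.g., a signed URL from `upload_pil_2_oss`):

```python
from src.anystory import call_anystory

refs = ["https://example.com/subject_A.png"]  # hypothetical reference image URL
images = call_anystory(refs, "a man riding a bike on the road")
images[0].save("result.png")  # entries are PIL.Image objects
```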
src/log.py ADDED
@@ -0,0 +1,18 @@
+ import logging
+ import os
+ from logging.handlers import RotatingFileHandler
+
+ log_file_name = "workdir/AnyStory.log"
+ os.makedirs(os.path.dirname(log_file_name), exist_ok=True)
+
+ log_format = '[%(levelname)s] %(asctime)s "%(filename)s", line %(lineno)d, %(message)s'
+ logging.basicConfig(
+     format=log_format,
+     datefmt="%Y-%m-%d %H:%M:%S",
+     level=logging.INFO)
+ logger = logging.getLogger(name="AnyStory-Studio")
+
+ fh = RotatingFileHandler(log_file_name, maxBytes=20000000, backupCount=3)
+ formatter = logging.Formatter(log_format, datefmt="%Y-%m-%d %H:%M:%S")
+ fh.setFormatter(formatter)
+ logger.addHandler(fh)
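
Importing `logger` from this module is enough to get output both on stdout (via `basicConfig`) and in the rotating file `workdir/AnyStory.log`:

```python
from src.log import logger

logger.info("hello")  # goes to stdout and to workdir/AnyStory.log
```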
src/matting.py ADDED
@@ -0,0 +1,94 @@
+ import cv2
+ import numpy as np
+ import tensorflow as tf
+ from PIL import Image
+
+ if tf.__version__ >= '2.0':
+     tf = tf.compat.v1
+
+
+ class ImageUniversalMatting:
+
+     def __init__(self, weight_path):
+         super().__init__()
+         config = tf.ConfigProto(allow_soft_placement=True, device_count={'GPU': 0})
+         config.gpu_options.allow_growth = True
+         self._session = tf.Session(config=config)
+         with self._session.as_default():
+             print(f'loading model from {weight_path}')
+             with tf.gfile.FastGFile(weight_path, 'rb') as f:
+                 graph_def = tf.GraphDef()
+                 graph_def.ParseFromString(f.read())
+                 tf.import_graph_def(graph_def, name='')
+                 self.output = self._session.graph.get_tensor_by_name('output_png:0')
+                 self.input_name = 'input_image:0'
+                 print('load model done')
+         self._session.graph.finalize()
+
+     def __call__(self, image):
+         output = self.preprocess(image)
+         output = self.forward(output)
+         output = self.postprocess(output)
+         return output
+
+     def resize_image(self, img, limit_side_len):
+         """
+         Resize the image to a size that is a multiple of 32, as required by the network.
+
+         Args:
+             img (np.ndarray): array with shape [h, w, c]
+         Returns:
+             np.ndarray: the resized image
+         """
+         h, w, _ = img.shape
+
+         # limit the max side
+         if max(h, w) > limit_side_len:
+             if h > w:
+                 ratio = float(limit_side_len) / h
+             else:
+                 ratio = float(limit_side_len) / w
+         else:
+             ratio = 1.
+         resize_h = int(h * ratio)
+         resize_w = int(w * ratio)
+
+         resize_h = int(round(resize_h / 32) * 32)
+         resize_w = int(round(resize_w / 32) * 32)
+
+         img = cv2.resize(img, (int(resize_w), int(resize_h)))
+
+         return img
+
+     @staticmethod
+     def convert_to_ndarray(img):
+         if isinstance(img, Image.Image):
+             img = np.array(img.convert('RGB'))
+         elif isinstance(img, np.ndarray):
+             if len(img.shape) == 2:
+                 img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+             img = img[:, :, ::-1]  # convert BGR to RGB
+         else:
+             raise TypeError(f'input should be either PIL.Image or np.ndarray, but got {type(img)}')
+         return img
+
+     def preprocess(self, input, limit_side_len=800):
+         img = self.convert_to_ndarray(input)  # RGB input
+         img = img.astype(float)
+         orig_h, orig_w, _ = img.shape
+         img = self.resize_image(img, limit_side_len)
+         result = {'img': img, 'orig_h': orig_h, 'orig_w': orig_w}
+         return result
+
+     def forward(self, input):
+         orig_h, orig_w = input['orig_h'], input['orig_w']
+         with self._session.as_default():
+             feed_dict = {self.input_name: input['img']}
+             output_img = self._session.run(self.output, feed_dict=feed_dict)  # RGBA
+             # output_img = cv2.cvtColor(output_img, cv2.COLOR_RGBA2BGRA)
+             output_img = cv2.resize(output_img, (int(orig_w), int(orig_h)))
+         return {"output_img": output_img}
+
+     def postprocess(self, inputs):
+         return inputs["output_img"]
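
A minimal sketch of running the matting model directly (assumes the weights were already downloaded to `models/tf_matting.pb`, as `app.py` does on first launch; `subject.png` is a hypothetical input):

```python
from PIL import Image
from src.matting import ImageUniversalMatting

matting = ImageUniversalMatting("models/tf_matting.pb")
rgba = matting(Image.open("subject.png"))  # H x W x 4 array; the last channel is the alpha matte
Image.fromarray(rgba[..., -1].astype("uint8")).save("subject_matte.png")
```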
src/util.py ADDED
@@ -0,0 +1,71 @@
+ import concurrent.futures
+ import io
+ import os
+ import time
+
+ import oss2
+ import requests
+ from PIL import Image
+
+ from .log import logger
+
+ # OSS credentials and target location
+ oss_ak = os.getenv("OSS_AK")
+ oss_sk = os.getenv("OSS_SK")
+ oss_bucket = os.getenv("OSS_BUCKET")
+ oss_endpoint = os.getenv("OSS_ENDPOINT")
+ oss_path = os.getenv("OSS_PATH")
+
+ bucket = oss2.Bucket(oss2.Auth(oss_ak, oss_sk), oss_endpoint, oss_bucket)
+
+
+ def download_pil_img(index, img_url):
+     r = requests.get(img_url, stream=True)
+     if r.status_code == 200:
+         img = Image.open(io.BytesIO(r.content))
+         return (index, img)
+     else:
+         logger.error(f"Failed to download: {img_url}")
+         return (index, None)  # keep the slot so the caller can still unpack the result
+
+
+ def download_images(img_urls):
+     imgs_pil = [None] * len(img_urls)
+     with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+         to_do = []
+         for i, url in enumerate(img_urls):
+             future = executor.submit(download_pil_img, i, url)
+             to_do.append(future)
+
+         for future in concurrent.futures.as_completed(to_do):
+             index, img_pil = future.result()
+             imgs_pil[index] = img_pil
+
+     return imgs_pil
+
+
+ def resize_img_with_max_size(image, max_side_length=512):
+     # note: images smaller than max_side_length are upscaled as well
+     width, height = image.size
+     ratio = max_side_length / max(width, height)
+     new_width = int(width * ratio)
+     new_height = int(height * ratio)
+     resized_image = image.resize((new_width, new_height))
+     return resized_image
+
+
+ def upload_pil_2_oss(pil_image, name="cache.jpg"):
+     image = resize_img_with_max_size(pil_image, max_side_length=512)
+
+     img_byte_arr = io.BytesIO()
+     if name.lower().endswith(".png"):
+         image.save(img_byte_arr, format="PNG")
+     else:
+         image.save(img_byte_arr, format="JPEG", quality=95)
+     img_bytes = img_byte_arr.getvalue()
+
+     start_time = time.perf_counter()
+     bucket.put_object(oss_path + "/" + name, img_bytes)
+     ret = bucket.sign_url('GET', oss_path + "/" + name, 60 * 60 * 24)
+     logger.info(f"upload cost: {time.perf_counter() - start_time} s.")
+     return ret
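
Finally, a sketch of the upload helper, assuming the `OSS_*` variables are configured; the return value is a signed GET URL valid for 24 hours, suitable as an entry of `image_urls` in `call_anystory`:

```python
from PIL import Image
from src.util import upload_pil_2_oss

url = upload_pil_2_oss(Image.open("subject.png"), name="demo.png")  # hypothetical input file
print(url)  # signed OSS URL, valid for 24 hours
```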