ziwaixian009 committed · verified
Commit c409be8 · 1 Parent(s): bbd0818

Create app.py

Files changed (1)
  1. app.py +162 -0
app.py ADDED
@@ -0,0 +1,162 @@
+ import requests
+ from PIL import Image
+ from transformers import AutoModelForCausalLM, AutoProcessor, MarianMTModel, MarianTokenizer
+ from diffusers import StableDiffusionPipeline
+ import torch
+ import gradio as gr
+
+ # Verify that SentencePiece is installed (the Marian tokenizer needs it)
+ try:
+     import sentencepiece
+     print("SentencePiece is installed successfully!")
+ except ImportError:
+     print("SentencePiece is NOT installed!")
+
+ # Select the device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")
+
+ # Load the Florence-2 model and processor
+ print("Loading Florence-2 model...")
+ model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v1.5", trust_remote_code=True).to(device)
+ processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v1.5", trust_remote_code=True)
+ print("Florence-2 model loaded successfully.")
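+ # Note: trust_remote_code=True downloads and runs the custom Florence-2 modeling code from the model repository; it is required for this checkpoint.
+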
+ # Load the Helsinki-NLP translation model (English to Chinese)
+ print("Loading translation model...")
+ translation_model_name = "Helsinki-NLP/opus-mt-en-zh"
+ translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
+ translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)
+ print("Translation model loaded successfully.")
+
+ # Load the Stable Diffusion model (float16 on GPU; float32 on CPU, where half precision is not supported)
+ print("Loading Stable Diffusion model...")
+ sd_dtype = torch.float16 if device == "cuda" else torch.float32
+ pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=sd_dtype).to(device)
+ print("Stable Diffusion model loaded successfully.")
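+ # Note: the image-generation step below asks for 4 images per prompt, which raises peak memory use; on CPU it will be very slow.
+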
+ # Translation function: English text to Chinese
+ def translate_to_chinese(text):
+     try:
+         # Make sure the input is a string
+         if not isinstance(text, str):
+             print(f"Input is not a string: {text} (type: {type(text)})")
+             text = str(text)  # force-convert to a string
+
+         print("Input text for translation:", text)
+         tokenized_text = translation_tokenizer(text, return_tensors="pt", max_length=512, truncation=True).to(device)
+         translated_tokens = translation_model.generate(**tokenized_text)
+         translated_text = translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
+         print("Translated text:", translated_text)
+         return translated_text
+     except Exception as e:
+         print("Translation error:", str(e))
+         return f"Translation error: {str(e)}"
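+
+ # Illustrative usage (a sketch only; the actual wording depends on the opus-mt-en-zh checkpoint):
+ #   translate_to_chinese("A dog running on the beach")  ->  e.g. "一只狗在海滩上奔跑"
+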
+ # Generate a caption for the image and translate it
+ def generate_caption(image):
+     try:
+         # If the input is a URL, download the image
+         if isinstance(image, str) and (image.startswith("http://") or image.startswith("https://")):
+             print("Downloading image from URL...")
+             try:
+                 response = requests.get(image, stream=True, timeout=10)
+                 response.raise_for_status()  # check that the request succeeded
+                 image = Image.open(response.raw)
+                 print("Image downloaded successfully.")
+             except requests.exceptions.RequestException as e:
+                 return f"Failed to download image: {str(e)}", None
+         # Otherwise treat the input as a file path and open the image directly
+         else:
+             print("Loading image from file...")
+             image = Image.open(image)
+
+         # Prepare the model inputs
+         prompt = "<MORE_DETAILED_CAPTION>"
+         inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
+
+         # Generate the caption tokens
+         print("Generating caption...")
+         generated_ids = model.generate(
+             input_ids=inputs["input_ids"],
+             pixel_values=inputs["pixel_values"],
+             max_new_tokens=1024,
+             do_sample=False,
+             num_beams=3
+         )
+         generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+         print("Generated text:", generated_text)
+         print("Type of generated text:", type(generated_text))
+
+         # Parse the generated text; post_process_generation returns a dict keyed by the task prompt
+         parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
+         print("Parsed answer:", parsed_answer)
+         print("Type of parsed answer:", type(parsed_answer))
+         # Pull the plain caption string out of the parsed result
+         caption = parsed_answer.get(prompt, generated_text) if isinstance(parsed_answer, dict) else parsed_answer
+
+         # Translate the caption to Chinese
+         print("Translating to Chinese...")
+         translated_answer = translate_to_chinese(caption)
+         print("Translation completed.")
+
+         return translated_answer, caption
+     except Exception as e:
+         print("Error:", str(e))
+         return f"Error: {str(e)}", None
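+
+ # generate_caption returns a (translated_caption, english_caption) pair; on any failure the second element is None.
+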
+ # Generate images from a text prompt
+ def generate_images_from_prompt(prompt):
+     try:
+         # Generate 4 images
+         images = pipe(prompt, num_images_per_prompt=4).images
+         return images
+     except Exception as e:
+         print("Image generation error:", str(e))
+         return None
+
+ # Gradio handler: caption the image and translate the caption
+ def gradio_interface(image):
+     # Generate the caption and its Chinese translation
+     translated_answer, caption = generate_caption(image)
+     if caption is None:
+         return translated_answer, None
+
+     # Return the translated caption, plus the English caption for the image-generation prompt box
+     return translated_answer, caption
+
+ # Gradio handler: generate images
+ def generate_images_interface(prompt):
+     # Generate the images
+     images = generate_images_from_prompt(prompt)
+     if images is None:
+         return None
+
+     # Return the 4 generated images
+     return images
+
+ # Build the Gradio app
+ with gr.Blocks() as demo:
+     gr.Markdown("# Florence-2 Prompt Generation and Image Generation")
+
+     with gr.Row():
+         with gr.Column():
+             # Input: upload an image or enter an image URL
+             image_input = gr.Image(label="Upload Image or Enter Image URL", type="filepath")
+             # Output: the generated caption (translated to Chinese)
+             caption_output = gr.Textbox(label="Generated Caption (Translated to Chinese)")
+             # Button: generate the caption
+             generate_caption_button = gr.Button("Generate Caption")
+
+         with gr.Column():
+             # Input: the generated caption (used as the image-generation prompt)
+             prompt_input = gr.Textbox(label="Generated Caption (for Image Generation)")
+             # Output: the generated images
+             image_output = gr.Gallery(label="Generated Images")
+             # Button: generate the images
+             generate_images_button = gr.Button("Generate Images")
+
+     # Wire up the button events
+     generate_caption_button.click(gradio_interface, inputs=image_input, outputs=[caption_output, prompt_input])
+     generate_images_button.click(generate_images_interface, inputs=prompt_input, outputs=image_output)
+
+ # Launch the Gradio app
+ print("Launching Gradio app...")
+ demo.launch()