File size: 6,454 Bytes
c409be8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import requests
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, MarianMTModel, MarianTokenizer
from diffusers import StableDiffusionPipeline
import torch
import gradio as gr

# Report whether SentencePiece is importable; MarianTokenizer (loaded below)
# is SentencePiece-based, so a missing install surfaces there later.
try:
    import sentencepiece
except ImportError:
    print("SentencePiece is NOT installed!")
else:
    print("SentencePiece is installed successfully!")

# Select the compute device: CUDA when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the Florence-2 captioning model and its processor.
print("Loading Florence-2 model...")
florence_model_name = "MiaoshouAI/Florence-2-base-PromptGen-v1.5"
model = AutoModelForCausalLM.from_pretrained(florence_model_name, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained(florence_model_name, trust_remote_code=True)
print("Florence-2 model loaded successfully.")

# Load the Helsinki-NLP English -> Chinese translation model.
print("Loading translation model...")
translation_model_name = "Helsinki-NLP/opus-mt-en-zh"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)
print("Translation model loaded successfully.")

# Load the Stable Diffusion pipeline. float16 saves GPU memory, but many
# PyTorch CPU kernels have no half-precision implementation, so use float32
# when falling back to CPU (otherwise inference fails on CPU-only machines).
print("Loading Stable Diffusion model...")
sd_dtype = torch.float16 if device == "cuda" else torch.float32
# NOTE(review): this Hub repo id has reportedly been taken down; if loading
# fails, try the mirror "stable-diffusion-v1-5/stable-diffusion-v1-5".
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=sd_dtype).to(device)
print("Stable Diffusion model loaded successfully.")

# Translation helper (English -> Chinese).
def translate_to_chinese(text):
    """Translate English *text* to Chinese with the Marian en->zh model.

    Non-string input is logged and coerced with ``str()`` first. Any
    exception is caught and returned as ``"Translation error: <message>"``
    instead of being raised.
    """
    try:
        if isinstance(text, str):
            source = text
        else:
            # Coerce anything else (e.g. a dict or a number) to its string form.
            print(f"Input is not a string: {text} (type: {type(text)})")
            source = str(text)

        print("Input text for translation:", source)
        encoded = translation_tokenizer(
            source, return_tensors="pt", max_length=512, truncation=True
        ).to(device)
        output_ids = translation_model.generate(**encoded)
        result = translation_tokenizer.decode(output_ids[0], skip_special_tokens=True)
        print("Translated text:", result)
        return result
    except Exception as err:
        print("Translation error:", str(err))
        return f"Translation error: {str(err)}"

# Caption an image with Florence-2 and translate the caption to Chinese.
def _load_caption_image(image):
    """Return *image* as an RGB ``PIL.Image``.

    Accepts an ``http(s)://`` URL (downloaded with a 10 s timeout), a local
    file path, or an already-open ``PIL.Image``. The result is converted to
    RGB so RGBA, greyscale and palette images reach the processor as
    3-channel input.

    Raises:
        requests.exceptions.RequestException: if the URL download fails.
    """
    import io  # only needed to wrap downloaded bytes

    if isinstance(image, Image.Image):
        return image.convert("RGB")

    if isinstance(image, str) and image.startswith(("http://", "https://")):
        print("Downloading image from URL...")
        response = requests.get(image, timeout=10)
        response.raise_for_status()  # raise for HTTP 4xx/5xx responses
        # Decode from the fully-read body: ``response.content`` undoes any
        # gzip/deflate content encoding, whereas ``response.raw`` does not.
        with Image.open(io.BytesIO(response.content)) as downloaded:
            rgb = downloaded.convert("RGB")
        print("Image downloaded successfully.")
        return rgb

    print("Loading image from file...")
    # convert() reads the pixels eagerly, so the file can be closed right away.
    with Image.open(image) as opened:
        return opened.convert("RGB")


def _extract_caption(parsed_answer, task):
    """Return the caption text from ``post_process_generation`` output.

    The parsed result is expected to be a ``{task: caption}`` dict; the value
    for *task* (or else the first value) is used. Non-dict results are
    converted with ``str()``. Surrounding whitespace is stripped.
    """
    if isinstance(parsed_answer, dict):
        parsed_answer = parsed_answer.get(task, next(iter(parsed_answer.values()), ""))
    return str(parsed_answer).strip()


def generate_caption(image, task="<MORE_DETAILED_CAPTION>"):
    """Caption *image* with Florence-2 and translate the caption to Chinese.

    Args:
        image: An ``http(s)://`` URL, a local file path, or a ``PIL.Image``.
        task: Florence-2 task prompt to run (default ``"<MORE_DETAILED_CAPTION>"``).

    Returns:
        ``(translated_caption, english_caption)`` on success. If no image is
        given, the download fails, or captioning raises, the first element is
        an error message and the second is ``None``.
    """
    # e.g. the "Generate Caption" button clicked before anything was uploaded.
    if image is None or (isinstance(image, str) and not image.strip()):
        return "Error: No image provided.", None

    try:
        try:
            image = _load_caption_image(image)
        except requests.exceptions.RequestException as e:
            return f"Failed to download image: {str(e)}", None

        inputs = processor(text=task, images=image, return_tensors="pt").to(device)

        print("Generating caption...")
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )
        # Special tokens are kept: post_process_generation consumes the raw
        # decoded text including the task markup.
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        print("Generated text:", generated_text)

        parsed_answer = processor.post_process_generation(
            generated_text, task=task, image_size=(image.width, image.height)
        )
        print("Parsed answer:", parsed_answer)
        caption = _extract_caption(parsed_answer, task)

        # Translate only the caption text, not the repr of the whole result dict.
        print("Translating to Chinese...")
        translated_answer = translate_to_chinese(caption)
        print("Translation completed.")

        return translated_answer, caption
    except Exception as e:
        # UI boundary: report the failure to the user instead of raising.
        print("Error:", str(e))
        return f"Error: {str(e)}", None

# Generate images from a text prompt with Stable Diffusion.
def generate_images_from_prompt(prompt, num_images=4):
    """Generate images for *prompt* with the Stable Diffusion pipeline.

    Args:
        prompt: Text prompt (presumably anything ``pipe`` accepts as a prompt).
        num_images: Number of images to generate per prompt (default 4).

    Returns:
        The pipeline's ``.images`` on success, or ``None`` if the prompt is
        missing/blank or generation raised (the error is printed).
    """
    # Nothing to generate from (e.g. "Generate Images" clicked before a caption
    # exists): skip the expensive pipeline call instead of running it on "".
    if prompt is None or (isinstance(prompt, str) and not prompt.strip()):
        return None
    try:
        return pipe(prompt, num_images_per_prompt=num_images).images
    except Exception as e:
        print("Image generation error:", str(e))
        return None

# Gradio handler for the "Generate Caption" button.
def gradio_interface(image):
    """Caption *image* and fill both caption text boxes.

    Args:
        image: Value of the ``gr.Image`` input (configured with
            ``type="filepath"``).

    Returns:
        ``(translated_caption, english_caption)``. These feed ``caption_output``
        and ``prompt_input``, so the English caption is ready to be used as a
        Stable Diffusion prompt. On failure the first element is the error
        message and the second is ``None``, which clears the prompt box.
    """
    translated_answer, parsed_answer = generate_caption(image)
    # generate_caption may hand back Florence-2's parsed ``{task: caption}``
    # dict; unwrap it so the prompt box shows plain text.
    if isinstance(parsed_answer, dict):
        parsed_answer = next(iter(parsed_answer.values()), None)
    return translated_answer, parsed_answer

# Gradio handler for the "Generate Images" button.
def generate_images_interface(prompt):
    """Generate images for the gallery from *prompt*.

    Forwards to ``generate_images_from_prompt`` and returns its result
    unchanged: the generated images, or ``None`` when generation failed.
    """
    return generate_images_from_prompt(prompt)

# Build the Gradio app: the left column captions an image, the right column
# generates images from a text prompt.
with gr.Blocks() as demo:
    gr.Markdown("# Florence-2 Prompt Generation and Image Generation")

    # Component creation order inside these ``with`` blocks defines the layout.
    with gr.Row():
        with gr.Column():
            # Input: the image to caption. With type="filepath" the handler
            # receives a file path rather than image data.
            # NOTE(review): the label mentions entering a URL, but this is an
            # image component - confirm a URL can actually be supplied here.
            image_input = gr.Image(label="Upload Image or Enter Image URL", type="filepath")
            # Output: the generated caption, translated to Chinese.
            caption_output = gr.Textbox(label="Generated Caption (Translated to Chinese)")
            # Button: generate the caption.
            generate_caption_button = gr.Button("Generate Caption")

        with gr.Column():
            # Prompt for image generation; it is also the second output of the
            # caption button (see the event bindings below).
            prompt_input = gr.Textbox(label="Generated Caption (for Image Generation)")
            # Output: gallery of generated images.
            image_output = gr.Gallery(label="Generated Images")
            # Button: generate images.
            generate_images_button = gr.Button("Generate Images")

    # Event bindings. gradio_interface returns two values, which fill
    # caption_output and prompt_input in that order.
    generate_caption_button.click(gradio_interface, inputs=image_input, outputs=[caption_output, prompt_input])
    generate_images_button.click(generate_images_interface, inputs=prompt_input, outputs=image_output)

# Launch the Gradio app.
print("Launching Gradio app...")
demo.launch()