"""Gradio demo: caption an image with Florence-2 PromptGen, translate the
caption to Chinese with a MarianMT model, and generate images from a prompt
with Stable Diffusion.

All models are loaded once at import time and shared by the handlers below.
"""
import io

import gradio as gr
import requests
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, MarianMTModel, MarianTokenizer

# Check up front whether SentencePiece is importable, so a missing install is
# visible in the log before the models are loaded.
try:
    import sentencepiece
    print("SentencePiece is installed successfully!")
except ImportError:
    print("SentencePiece is NOT installed!")

# Select the compute device.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

FLORENCE_MODEL_ID = "MiaoshouAI/Florence-2-base-PromptGen-v1.5"

# Task token sent to Florence-2.
# NOTE(review): the original literal was an empty string, which looks like an
# angle-bracket task token lost in extraction; confirm the intended token
# against the model card.
CAPTION_TASK_PROMPT = "<MORE_DETAILED_CAPTION>"

# Load the Florence-2 model and its processor.
print("Loading Florence-2 model...")
model = AutoModelForCausalLM.from_pretrained(FLORENCE_MODEL_ID, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained(FLORENCE_MODEL_ID, trust_remote_code=True)
print("Florence-2 model loaded successfully.")

# Load the Helsinki-NLP translation model (English to Chinese).
print("Loading translation model...")
translation_model_name = "Helsinki-NLP/opus-mt-en-zh"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)
print("Translation model loaded successfully.")

# Load the Stable Diffusion pipeline. Half precision is only requested on
# CUDA; on CPU the pipeline is loaded in float32 so inference can run.
print("Loading Stable Diffusion model...")
sd_dtype = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=sd_dtype
).to(device)
print("Stable Diffusion model loaded successfully.")


def translate_to_chinese(text):
    """Translate English *text* to Chinese with the MarianMT model.

    Non-string input is converted with ``str()`` first. Blank input yields an
    empty string without invoking the model.

    Returns:
        The translated text, or a string starting with ``"Translation error:"``
        if tokenization, generation, or decoding raises.
    """
    try:
        # Make sure the input is a string.
        if not isinstance(text, str):
            print(f"Input is not a string: {text} (type: {type(text)})")
            text = str(text)
        if not text.strip():
            return ""
        print("Input text for translation:", text)
        tokenized_text = translation_tokenizer(
            text, return_tensors="pt", max_length=512, truncation=True
        ).to(device)
        translated_tokens = translation_model.generate(**tokenized_text)
        translated_text = translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
        print("Translated text:", translated_text)
        return translated_text
    except Exception as e:
        # Top-level boundary for the UI: report the failure as text.
        print("Translation error:", str(e))
        return f"Translation error: {str(e)}"


def _load_image(source):
    """Open *source* (an http(s) URL or a local file path) as an RGB image.

    Raises:
        requests.exceptions.RequestException: if the download fails.
    """
    if isinstance(source, str) and source.startswith(("http://", "https://")):
        print("Downloading image from URL...")
        response = requests.get(source, timeout=10)
        response.raise_for_status()  # Fail on HTTP error status codes.
        # Read the fully decoded body rather than the raw socket stream.
        with Image.open(io.BytesIO(response.content)) as img:
            rgb_image = img.convert("RGB")
        print("Image downloaded successfully.")
        return rgb_image

    print("Loading image from file...")
    # convert() loads the pixel data, so the file handle can be closed here.
    with Image.open(source) as img:
        return img.convert("RGB")


def _extract_caption(parsed_answer, task_prompt):
    """Return the caption text from the post-processed Florence-2 output.

    Accepts either a plain string or a mapping keyed by the task prompt; any
    other value is converted with ``str()``.
    """
    if isinstance(parsed_answer, dict):
        value = parsed_answer.get(task_prompt)
        if value is None and len(parsed_answer) == 1:
            value = next(iter(parsed_answer.values()))
        if value is not None:
            parsed_answer = value
    if not isinstance(parsed_answer, str):
        parsed_answer = str(parsed_answer)
    return parsed_answer.strip()


def generate_caption(image, task_prompt=CAPTION_TASK_PROMPT):
    """Caption *image* with Florence-2 and translate the caption to Chinese.

    Args:
        image: An http(s) URL or a local file path.
        task_prompt: Task token passed to the Florence-2 processor.

    Returns:
        ``(translated_caption, english_caption)`` on success, or
        ``(error_message, None)`` if loading or captioning fails.
    """
    if image is None:
        # Happens when the Gradio image input is empty.
        return "Error: no image provided", None
    try:
        try:
            image = _load_image(image)
        except requests.exceptions.RequestException as e:
            return f"Failed to download image: {str(e)}", None

        # Prepare the model inputs.
        inputs = processor(text=task_prompt, images=image, return_tensors="pt").to(device)

        # Generate the caption tokens.
        print("Generating caption...")
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        print("Generated text:", generated_text)

        # Parse the generated text.
        parsed_answer = processor.post_process_generation(
            generated_text, task=task_prompt, image_size=(image.width, image.height)
        )
        print("Parsed answer:", parsed_answer)
        caption = _extract_caption(parsed_answer, task_prompt)

        # Translate to Chinese.
        print("Translating to Chinese...")
        translated_answer = translate_to_chinese(caption)
        print("Translation completed.")
        return translated_answer, caption
    except Exception as e:
        # Top-level boundary for the UI: report the failure as text.
        print("Error:", str(e))
        return f"Error: {str(e)}", None


def generate_images_from_prompt(prompt, num_images=1):
    """Generate images for *prompt* with Stable Diffusion.

    Args:
        prompt: Text prompt for the pipeline.
        num_images: Number of images to request per prompt (default 1).

    Returns:
        A list of PIL images, or ``None`` if generation raises.
    """
    try:
        return pipe(prompt, num_images_per_prompt=num_images).images
    except Exception as e:
        print("Image generation error:", str(e))
        return None


def gradio_interface(image):
    """Gradio handler: return ``(chinese_caption, english_prompt)`` for *image*.

    On failure the first element carries the error message and the second is
    ``None``, which leaves the prompt textbox empty.
    """
    translated_answer, parsed_answer = generate_caption(image)
    if parsed_answer is None:
        return translated_answer, None
    # The English caption feeds the image-generation prompt box.
    return translated_answer, parsed_answer


def generate_images_interface(prompt):
    """Gradio handler: return the generated images for *prompt*, or ``None``."""
    return generate_images_from_prompt(prompt)


# Build the Gradio app.
with gr.Blocks() as demo:
    gr.Markdown("# Florence-2 Prompt Generation and Image Generation")

    with gr.Row():
        with gr.Column():
            # Input: the image to caption.
            image_input = gr.Image(label="Upload Image or Enter Image URL", type="filepath")
            # Output: the generated caption (translated to Chinese).
            caption_output = gr.Textbox(label="Generated Caption (Translated to Chinese)")
            # Button: generate the caption.
            generate_caption_button = gr.Button("Generate Caption")

        with gr.Column():
            # Input: the caption used as the image-generation prompt.
            prompt_input = gr.Textbox(label="Generated Caption (for Image Generation)")
            # Output: the generated images.
            image_output = gr.Gallery(label="Generated Images")
            # Button: generate the images.
            generate_images_button = gr.Button("Generate Images")

    # Wire up the events.
    generate_caption_button.click(
        gradio_interface, inputs=image_input, outputs=[caption_output, prompt_input]
    )
    generate_images_button.click(
        generate_images_interface, inputs=prompt_input, outputs=image_output
    )

# Launch the Gradio app only when run as a script.
if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch()