Spaces:
Sleeping
Sleeping
| import requests | |
| from PIL import Image | |
| from transformers import AutoModelForCausalLM, AutoProcessor, MarianMTModel, MarianTokenizer | |
| from diffusers import StableDiffusionPipeline | |
| import torch | |
| import gradio as gr | |
# Verify that SentencePiece (needed by the Marian tokenizer) is importable and
# report the result on stdout.
try:
    import sentencepiece  # noqa: F401  (imported only to test availability)
except ImportError:
    print("SentencePiece is NOT installed!")
else:
    print("SentencePiece is installed successfully!")
# Choose the compute device: CUDA when torch can see a GPU, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")
# Load the Florence-2 PromptGen captioning model and its processor.
# trust_remote_code=True lets transformers run the modeling/processing code
# shipped inside the model repository.
print("Loading Florence-2 model...")
_florence_repo = "MiaoshouAI/Florence-2-base-PromptGen-v1.5"
model = AutoModelForCausalLM.from_pretrained(_florence_repo, trust_remote_code=True)
model = model.to(device)
processor = AutoProcessor.from_pretrained(_florence_repo, trust_remote_code=True)
print("Florence-2 model loaded successfully.")
# Load the Helsinki-NLP MarianMT English -> Chinese translation model.
print("Loading translation model...")
translation_model_name = "Helsinki-NLP/opus-mt-en-zh"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
_marian = MarianMTModel.from_pretrained(translation_model_name)
translation_model = _marian.to(device)
print("Translation model loaded successfully.")
# Load the Stable Diffusion v1.5 text-to-image pipeline onto the chosen device.
print("Loading Stable Diffusion model...")
# float16 only makes sense on CUDA. Several CPU kernels have no (or very slow)
# Half implementation, so the CPU fallback must use float32.
sd_dtype = torch.float16 if device == "cuda" else torch.float32
# NOTE: the former "runwayml/stable-diffusion-v1-5" repository was removed from
# the Hugging Face Hub; this repo hosts the same v1.5 weights.
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=sd_dtype,
).to(device)
print("Stable Diffusion model loaded successfully.")
# Translation helper (English -> Chinese).
def translate_to_chinese(text):
    """Translate *text* to Chinese with the MarianMT model.

    Non-string input is coerced with ``str()`` (and a note is printed) before
    tokenization; input longer than 512 tokens is truncated.

    Returns:
        The translated string, or ``"Translation error: <details>"`` if any
        step fails. Exceptions are never raised to the caller.
    """
    try:
        if isinstance(text, str):
            source = text
        else:
            print(f"Input is not a string: {text} (type: {type(text)})")
            source = str(text)
        print("Input text for translation:", source)

        encoded = translation_tokenizer(
            source, return_tensors="pt", max_length=512, truncation=True
        )
        encoded = encoded.to(device)
        output_ids = translation_model.generate(**encoded)
        result = translation_tokenizer.decode(output_ids[0], skip_special_tokens=True)

        print("Translated text:", result)
        return result
    except Exception as err:
        print("Translation error:", str(err))
        return f"Translation error: {str(err)}"
# Caption an image with Florence-2 and translate the caption to Chinese.
def _load_caption_image(image):
    """Return *image* as a fully loaded RGB ``PIL.Image``.

    Accepts an http(s) URL string, an already-open ``PIL.Image``, or anything
    ``Image.open`` accepts (a local path or file object). Download failures
    propagate as ``requests.exceptions.RequestException``.
    """
    from io import BytesIO

    if isinstance(image, Image.Image):
        return image.convert("RGB")
    if isinstance(image, str) and image.startswith(("http://", "https://")):
        print("Downloading image from URL...")
        response = requests.get(image, timeout=10)
        response.raise_for_status()  # HTTP 4xx/5xx -> requests.HTTPError
        # Decode from the fully-read body instead of response.raw, which
        # bypasses Content-Encoding (gzip/deflate) decoding.
        with Image.open(BytesIO(response.content)) as downloaded:
            rgb = downloaded.convert("RGB")
        print("Image downloaded successfully.")
        return rgb
    print("Loading image from file...")
    # convert() loads the pixels, so the file handle can be closed right away.
    # RGB also normalises palette/RGBA/greyscale inputs for the processor.
    with Image.open(image) as opened:
        return opened.convert("RGB")


def generate_caption(image):
    """Caption *image* with Florence-2 and translate the caption to Chinese.

    Args:
        image: An http(s) URL string, a ``PIL.Image``, or a local path/file
            object (the Gradio UI passes a file path).

    Returns:
        ``(translated_caption, english_caption)`` on success. On failure the
        first element is an error message and the second is ``None``.
    """
    if image is None:
        return "Error: no image provided", None
    try:
        try:
            pil_image = _load_caption_image(image)
        except requests.exceptions.RequestException as e:
            return f"Failed to download image: {str(e)}", None

        prompt = "<MORE_DETAILED_CAPTION>"
        inputs = processor(text=prompt, images=pil_image, return_tensors="pt").to(device)

        print("Generating caption...")
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        print("Generated text:", generated_text)

        parsed = processor.post_process_generation(
            generated_text,
            task=prompt,
            image_size=(pil_image.width, pil_image.height),
        )
        print("Parsed answer:", parsed)
        # Florence-2 post-processing returns {task_token: result}. Translate only
        # the caption text, never the dict's repr.
        caption = parsed.get(prompt, "") if isinstance(parsed, dict) else parsed
        caption = str(caption).strip()
        if not caption:
            return "Error: model returned an empty caption", None

        print("Translating to Chinese...")
        translated_answer = translate_to_chinese(caption)
        print("Translation completed.")
        return translated_answer, caption
    except Exception as e:
        print("Error:", str(e))
        return f"Error: {str(e)}", None
# Text-to-image generation with the Stable Diffusion pipeline.
def generate_images_from_prompt(prompt, num_images=4):
    """Generate ``num_images`` images for *prompt*.

    Args:
        prompt: Text prompt (or anything the pipeline accepts as a prompt).
        num_images: Images to generate per prompt (default 4).

    Returns:
        The list of generated images, or ``None`` if the prompt is empty or
        generation fails (the error is printed).
    """
    # An empty prompt would still run full diffusion passes and produce
    # unconditioned images; skip the expensive call instead.
    if prompt is None or (isinstance(prompt, str) and not prompt.strip()):
        print("Image generation error: empty prompt")
        return None
    try:
        return pipe(prompt, num_images_per_prompt=num_images).images
    except Exception as e:
        print("Image generation error:", str(e))
        return None
# Gradio handler for the "Generate Caption" button.
def gradio_interface(image):
    """Caption *image* and fill both text boxes of the UI.

    Returns:
        ``(caption_text, prompt_text)``: the Chinese caption (or an error
        message) for the caption box, and the English caption for the
        image-generation prompt box. ``prompt_text`` is ``None`` when
        captioning failed, which clears that box.
    """
    translated_answer, parsed_answer = generate_caption(image)
    # generate_caption reports every failure (download or otherwise) with None
    # as the second element.
    if parsed_answer is None:
        return translated_answer, None
    # The parsed answer may be the raw post-processing result (a dict keyed by
    # the task token); unwrap it so the textbox receives plain text.
    if isinstance(parsed_answer, dict):
        parsed_answer = parsed_answer.get("<MORE_DETAILED_CAPTION>", "")
    # Send the English caption to the prompt box, whose contents are what the
    # "Generate Images" button feeds to Stable Diffusion.
    return translated_answer, parsed_answer
# Gradio handler for the "Generate Images" button.
def generate_images_interface(prompt):
    """Return the generated images for *prompt*, or None on failure.

    generate_images_from_prompt already yields None on failure, and returning
    None clears the gallery, so its result is passed straight through.
    """
    return generate_images_from_prompt(prompt)
# Build the Gradio app. The left column turns an image into a (Chinese)
# caption; the right column turns a text prompt into a gallery of images.
with gr.Blocks() as demo:
    gr.Markdown("# Florence-2 Prompt Generation and Image Generation")
    with gr.Row():
        with gr.Column():
            # NOTE(review): type="filepath" hands the handler a path string;
            # confirm URL entry actually works here, since the label offers it.
            image_input = gr.Image(label="Upload Image or Enter Image URL", type="filepath")
            caption_output = gr.Textbox(label="Generated Caption (Translated to Chinese)")
            generate_caption_button = gr.Button("Generate Caption")
        with gr.Column():
            prompt_input = gr.Textbox(label="Generated Caption (for Image Generation)")
            image_output = gr.Gallery(label="Generated Images")
            generate_images_button = gr.Button("Generate Images")

    # Event wiring: captioning fills both text boxes; image generation reads
    # the prompt box and fills the gallery.
    generate_caption_button.click(
        fn=gradio_interface,
        inputs=image_input,
        outputs=[caption_output, prompt_input],
    )
    generate_images_button.click(
        fn=generate_images_interface,
        inputs=prompt_input,
        outputs=image_output,
    )
# Start the web server only when this file is executed as a script
# (e.g. `python app.py`), so importing the module does not block on launch().
if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch()