# Hugging Face Spaces status: Running
from io import BytesIO

import gradio as gr
import requests
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    MarianMTModel,
    MarianTokenizer,
)
# Verify that SentencePiece is installed (presumably required by the Marian
# tokenizer loaded further down -- confirm against the transformers docs).
try:
    import sentencepiece  # noqa: F401  -- imported only to probe availability
except ImportError:
    print("SentencePiece is NOT installed!")
else:
    print("SentencePiece is installed successfully!")
# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
# Load the Florence-2 captioning model and its processor.
# NOTE(review): trust_remote_code=True executes Python code shipped in the
# model repository; keep the repository id pinned to a source you trust.
print("Loading Florence-2 model...")
_FLORENCE_MODEL_ID = "MiaoshouAI/Florence-2-base-PromptGen-v1.5"
model = AutoModelForCausalLM.from_pretrained(_FLORENCE_MODEL_ID, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained(_FLORENCE_MODEL_ID, trust_remote_code=True)
print("Florence-2 model loaded successfully.")

# Load the Helsinki-NLP translation model (English -> Chinese).
print("Loading translation model...")
translation_model_name = "Helsinki-NLP/opus-mt-en-zh"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)
print("Translation model loaded successfully.")

# Load the Stable Diffusion pipeline.
# Half precision is only requested on CUDA: on CPU, float16 weights are either
# unsupported by some operators or impractically slow, so fall back to float32
# when no GPU is available.
print("Loading Stable Diffusion model...")
_sd_dtype = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=_sd_dtype).to(device)
print("Stable Diffusion model loaded successfully.")
# Translation helper
def translate_to_chinese(text):
    """Translate *text* with the module-level Marian English-to-Chinese model.

    Input that is not a ``str`` is converted with ``str()`` first. The input
    is tokenized with truncation at 512 tokens. Any exception raised along the
    way is caught; in that case the returned string is
    ``"Translation error: <message>"`` instead of a translation.
    """
    try:
        # Coerce anything that is not already a string.
        if not isinstance(text, str):
            print(f"Input is not a string: {text} (type: {type(text)})")
            text = str(text)
        print("Input text for translation:", text)

        encoded = translation_tokenizer(
            text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        ).to(device)
        output_ids = translation_model.generate(**encoded)
        result = translation_tokenizer.decode(output_ids[0], skip_special_tokens=True)

        print("Translated text:", result)
        return result
    except Exception as err:
        failure = f"Translation error: {str(err)}"
        print(failure)
        return failure
# Caption generation and translation
def _load_rgb_image(source):
    """Return *source* as an RGB ``PIL.Image``.

    *source* may be an already-open PIL image, an ``http(s)://`` URL, or
    anything ``Image.open`` accepts (typically a local file path).
    """
    if isinstance(source, Image.Image):
        return source.convert("RGB")

    if isinstance(source, str) and source.startswith(("http://", "https://")):
        print("Downloading image from URL...")
        response = requests.get(source, timeout=10)
        response.raise_for_status()  # surface HTTP errors as RequestException
        # Use the decoded body: ``response.raw`` would hand PIL the bytes
        # without undoing any Content-Encoding applied by the server.
        opened = Image.open(BytesIO(response.content))
        print("Image downloaded successfully.")
    else:
        print("Loading image from file...")
        opened = Image.open(source)

    # convert() returns an independent copy, so the original (and its file
    # handle) can be closed right away.
    with opened:
        return opened.convert("RGB")


def _caption_text(parsed_answer, task):
    """Extract the plain caption string from the processor's parsed output."""
    # Florence-2 post-processing is expected to return ``{task: caption}``;
    # fall back gracefully if the shape differs.
    if isinstance(parsed_answer, dict):
        if task in parsed_answer:
            return str(parsed_answer[task])
        return " ".join(str(value) for value in parsed_answer.values())
    return str(parsed_answer)


def generate_caption(image):
    """Caption *image* with Florence-2 and translate the caption to Chinese.

    Args:
        image: an ``http(s)://`` URL, a local file path, or a PIL image.

    Returns:
        ``(translated_text, parsed_answer)`` on success, where
        ``parsed_answer`` is whatever ``processor.post_process_generation``
        produced. On failure the first element is an error message and the
        second is ``None``; no exception propagates to the caller.
    """
    try:
        if image is None:
            return "Error: no image provided", None

        try:
            image = _load_rgb_image(image)
        except requests.exceptions.RequestException as e:
            return f"Failed to download image: {str(e)}", None

        # Prepare the model inputs.
        prompt = "<MORE_DETAILED_CAPTION>"
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)

        # Generate the caption tokens.
        print("Generating caption...")
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        print("Generated text:", generated_text)
        print("Type of generated text:", type(generated_text))

        # Parse the raw generation into the task result.
        parsed_answer = processor.post_process_generation(
            generated_text, task=prompt, image_size=(image.width, image.height)
        )
        print("Parsed answer:", parsed_answer)
        print("Type of parsed answer:", type(parsed_answer))

        # Translate only the caption text, not the repr of the parsed
        # container (which would include braces and the task key).
        print("Translating to Chinese...")
        translated_answer = translate_to_chinese(_caption_text(parsed_answer, prompt))
        print("Translation completed.")
        return translated_answer, parsed_answer
    except Exception as e:
        print("Error:", str(e))
        return f"Error: {str(e)}", None
# Image generation
def generate_images_from_prompt(prompt):
    """Run the module-level Stable Diffusion pipeline on *prompt*.

    Requests four images for the prompt and returns the pipeline result's
    ``images``. If anything raises, the error is printed and ``None`` is
    returned instead.
    """
    try:
        result = pipe(prompt, num_images_per_prompt=4)
        return result.images
    except Exception as err:
        print("Image generation error:", str(err))
        return None
# Gradio callback: caption an image
def gradio_interface(image):
    """Gradio handler for the "Generate Caption" button.

    Returns a pair ``(translated_caption, english_caption)``: the first value
    feeds the translated-caption textbox, the second pre-fills the prompt
    textbox used for image generation. When captioning fails, the error
    message is returned as the first value and ``None`` as the second.
    """
    translated_answer, parsed_answer = generate_caption(image)

    # generate_caption reports every failure (download error, model error, ...)
    # with None as its second element, which is more reliable than matching
    # on the message prefix.
    if parsed_answer is None:
        return translated_answer, None

    # The parsed answer may be a mapping of task -> caption; flatten it to the
    # plain caption text so it can be used directly as a generation prompt.
    if isinstance(parsed_answer, dict):
        english_caption = " ".join(str(value) for value in parsed_answer.values())
    else:
        english_caption = str(parsed_answer)

    return translated_answer, english_caption
# Gradio callback: generate images
def generate_images_interface(prompt):
    """Gradio handler for the "Generate Images" button.

    Delegates to ``generate_images_from_prompt`` and hands its result back
    unchanged: the generated images on success, ``None`` when generation
    failed.
    """
    return generate_images_from_prompt(prompt)
# Build the Gradio application
with gr.Blocks() as demo:
    gr.Markdown("# Florence-2 Prompt Generation and Image Generation")

    with gr.Row():
        # Left column: image in, translated caption out.
        with gr.Column():
            image_input = gr.Image(
                label="Upload Image or Enter Image URL",
                type="filepath",
            )
            caption_output = gr.Textbox(
                label="Generated Caption (Translated to Chinese)",
            )
            generate_caption_button = gr.Button("Generate Caption")

        # Right column: prompt in, generated images out.
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Generated Caption (for Image Generation)",
            )
            image_output = gr.Gallery(label="Generated Images")
            generate_images_button = gr.Button("Generate Images")

    # Wire the buttons to their handlers.
    generate_caption_button.click(
        fn=gradio_interface,
        inputs=image_input,
        outputs=[caption_output, prompt_input],
    )
    generate_images_button.click(
        fn=generate_images_interface,
        inputs=prompt_input,
        outputs=image_output,
    )

# Start the Gradio server
print("Launching Gradio app...")
demo.launch()