# text_pisture / app.py
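"""Gradio demo that captions an image with Florence-2 (MiaoshouAI PromptGen),
translates the caption to Chinese with Helsinki-NLP/opus-mt-en-zh, and can then
generate new images from the caption with Stable Diffusion v1.5."""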
import requests
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, MarianMTModel, MarianTokenizer
from diffusers import StableDiffusionPipeline
import torch
import gradio as gr
# Verify that SentencePiece is installed (required by MarianTokenizer).
try:
    import sentencepiece
    print("SentencePiece is installed successfully!")
except ImportError:
    print("SentencePiece is NOT installed!")
# Select the device: CUDA GPU if available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load the Florence-2 captioning model and its processor.
print("Loading Florence-2 model...")
model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v1.5", trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v1.5", trust_remote_code=True)
print("Florence-2 model loaded successfully.")
# Load the Helsinki-NLP translation model (English to Chinese).
print("Loading translation model...")
translation_model_name = "Helsinki-NLP/opus-mt-en-zh"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)
print("Translation model loaded successfully.")
# Load the Stable Diffusion model (fp16 on GPU; fp32 on CPU, where fp16 is generally not supported).
print("Loading Stable Diffusion model...")
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)
print("Stable Diffusion model loaded successfully.")
# Translate English text to Chinese.
def translate_to_chinese(text):
    try:
        # Make sure the input is a string.
        if not isinstance(text, str):
            print(f"Input is not a string: {text} (type: {type(text)})")
            text = str(text)  # Force conversion to string
        print("Input text for translation:", text)
        tokenized_text = translation_tokenizer(text, return_tensors="pt", max_length=512, truncation=True).to(device)
        translated_tokens = translation_model.generate(**tokenized_text)
        translated_text = translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
        print("Translated text:", translated_text)
        return translated_text
    except Exception as e:
        print("Translation error:", str(e))
        return f"Translation error: {str(e)}"
# Generate a caption for an image and translate it to Chinese.
def generate_caption(image):
    try:
        # If the input is a URL, download the image.
        if isinstance(image, str) and (image.startswith("http://") or image.startswith("https://")):
            print("Downloading image from URL...")
            try:
                response = requests.get(image, stream=True, timeout=10)
                response.raise_for_status()  # Check that the request succeeded
                image = Image.open(response.raw)
                print("Image downloaded successfully.")
            except requests.exceptions.RequestException as e:
                return f"Failed to download image: {str(e)}", None
        # Otherwise treat the input as a file path and open it directly.
        else:
            print("Loading image from file...")
            image = Image.open(image)
        # Prepare the model inputs.
        prompt = "<MORE_DETAILED_CAPTION>"
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
        # Generate the caption tokens.
        print("Generating caption...")
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        print("Generated text:", generated_text)
        # Parse the generated text; post_process_generation returns a dict keyed by the task token.
        parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
        print("Parsed answer:", parsed_answer)
        # Extract the plain caption string from the parsed result.
        caption = parsed_answer.get(prompt, "") if isinstance(parsed_answer, dict) else str(parsed_answer)
        # Translate the caption to Chinese.
        print("Translating to Chinese...")
        translated_answer = translate_to_chinese(caption)
        print("Translation completed.")
        return translated_answer, caption
    except Exception as e:
        print("Error:", str(e))
        return f"Error: {str(e)}", None
# Generate images from a text prompt.
def generate_images_from_prompt(prompt):
    try:
        # Generate 4 images for the prompt.
        images = pipe(prompt, num_images_per_prompt=4).images
        return images
    except Exception as e:
        print("Image generation error:", str(e))
        return None
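# Note: producing 4 images in a single call needs enough GPU memory; if memory is tight,
# lowering num_images_per_prompt (or generating the images one at a time) is a workable fallback.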
# Gradio handler: caption an image and return both the Chinese translation and the English caption.
def gradio_interface(image):
    translated_answer, parsed_answer = generate_caption(image)
    if translated_answer.startswith("Error"):
        return translated_answer, None
    # Return the translated caption and the English caption (reused as the image-generation prompt).
    return translated_answer, parsed_answer

# Gradio handler: generate images from the caption.
def generate_images_interface(prompt):
    images = generate_images_from_prompt(prompt)
    if images is None:
        return None
    # Return the 4 generated images.
    return images
# Build the Gradio app.
with gr.Blocks() as demo:
    gr.Markdown("# Florence-2 Prompt Generation and Image Generation")
    with gr.Row():
        with gr.Column():
            # Input: upload an image or enter an image URL.
            image_input = gr.Image(label="Upload Image or Enter Image URL", type="filepath")
            # Output: the generated caption, translated to Chinese.
            caption_output = gr.Textbox(label="Generated Caption (Translated to Chinese)")
            # Button: generate the caption.
            generate_caption_button = gr.Button("Generate Caption")
        with gr.Column():
            # Input: the English caption used as the image-generation prompt.
            prompt_input = gr.Textbox(label="Generated Caption (for Image Generation)")
            # Output: the generated images.
            image_output = gr.Gallery(label="Generated Images")
            # Button: generate the images.
            generate_images_button = gr.Button("Generate Images")
    # Wire up the button events.
    generate_caption_button.click(gradio_interface, inputs=image_input, outputs=[caption_output, prompt_input])
    generate_images_button.click(generate_images_interface, inputs=prompt_input, outputs=image_output)

# Launch the Gradio app.
print("Launching Gradio app...")
demo.launch()
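# When running outside Hugging Face Spaces, demo.launch(share=True) can be used to get a public link.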