# import gradio as gr
# def greet(name):
#     return "Hello " + name + "!!"
# demo = gr.Interface(fn=greet, inputs="text", outputs="text")
# demo.launch()
import requests
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, MarianMTModel, MarianTokenizer
import torch
import gradio as gr
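# Note (assumption, not stated in the original file): loading Florence-2 with
# trust_remote_code=True typically also needs the `einops` and `timm` packages
# at runtime, and the Marian tokenizer needs `sentencepiece` (checked below);
# listing them in the Space's requirements.txt avoids import errors.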
# Verify that SentencePiece is installed (required by the Marian tokenizer)
try:
    import sentencepiece
    print("SentencePiece is installed successfully!")
except ImportError:
    print("SentencePiece is NOT installed!")
# Select the device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load the Florence-2 model and processor
print("Loading Florence-2 model...")
model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v1.5", trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v1.5", trust_remote_code=True)
print("Florence-2 model loaded successfully.")
# Load the Helsinki-NLP translation model (English to Chinese)
print("Loading translation model...")
translation_model_name = "Helsinki-NLP/opus-mt-en-zh"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)
print("Translation model loaded successfully.")
# Translation function
def translate_to_chinese(text):
    try:
        # Make sure the input is a string
        if not isinstance(text, str):
            print(f"Input is not a string: {text} (type: {type(text)})")
            text = str(text)  # force conversion to a string
        print("Input text for translation:", text)
        tokenized_text = translation_tokenizer(text, return_tensors="pt", max_length=512, truncation=True).to(device)
        translated_tokens = translation_model.generate(**tokenized_text)
        translated_text = translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
        print("Translated text:", translated_text)
        return translated_text
    except Exception as e:
        print("Translation error:", str(e))
        return f"Translation error: {str(e)}"
# Generate a caption and translate it
def generate_caption(image):
    try:
        # If the input is a URL, download the image
        if isinstance(image, str) and (image.startswith("http://") or image.startswith("https://")):
            print("Downloading image from URL...")
            try:
                response = requests.get(image, stream=True, timeout=10)
                response.raise_for_status()  # check that the request succeeded
                image = Image.open(response.raw)
                print("Image downloaded successfully.")
            except requests.exceptions.RequestException as e:
                return f"Failed to download image: {str(e)}"
        # If the input is a file path, open the image directly
        else:
            print("Loading image from file...")
            image = Image.open(image)
        # Prepare the inputs
        prompt = "<MORE_DETAILED_CAPTION>"
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
        # Generate the caption tokens
        print("Generating caption...")
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        print("Generated text:", generated_text)
        print("Type of generated text:", type(generated_text))
        # Parse the generated text
        parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
        print("Parsed answer:", parsed_answer)
        print("Type of parsed answer:", type(parsed_answer))
        # post_process_generation returns a dict keyed by the task prompt;
        # extract the caption string so the translator receives plain text
        if isinstance(parsed_answer, dict):
            caption = parsed_answer.get(prompt, str(parsed_answer))
        else:
            caption = str(parsed_answer)
        # Translate into Chinese
        print("Translating to Chinese...")
        translated_answer = translate_to_chinese(caption)
        print("Translation completed.")
        return translated_answer
    except Exception as e:
        print("Error:", str(e))
        return f"Error: {str(e)}"
# Gradio interface
def gradio_interface(image):
    result = generate_caption(image)
    return result
# Create the Gradio app
iface = gr.Interface(
    fn=gradio_interface,  # handler function
    inputs=gr.Image(label="Upload Image or Enter Image URL", type="filepath"),  # input component
    outputs=gr.Textbox(label="Generated Caption (Translated to Chinese)"),  # output component
    title="紫外线",  # title
    description="Generate detailed captions for images using the Florence-2 model and translate them to Chinese. You can upload an image or provide an image URL.",  # description
    examples=[
        ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"]
    ]  # examples
)
# Launch the Gradio app
print("Launching Gradio app...")
iface.launch()
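# Optional (assumption, not in the original): when running locally rather than
# on Spaces, iface.launch(share=True) creates a temporary public link.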