# import gradio as gr
# def greet(name):
# return "Hello " + name + "!!"
# demo = gr.Interface(fn=greet, inputs="text", outputs="text")
# demo.launch()
import requests
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, MarianMTModel, MarianTokenizer
import torch
import gradio as gr
# Verify that SentencePiece is installed (required by the Marian tokenizer)
try:
    import sentencepiece
    print("SentencePiece is installed successfully!")
except ImportError:
    print("SentencePiece is NOT installed!")

# Select the compute device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load the Florence-2 model and processor
print("Loading Florence-2 model...")
model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v1.5", trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v1.5", trust_remote_code=True)
print("Florence-2 model loaded successfully.")

# Load the Helsinki-NLP translation model (English to Chinese)
print("Loading translation model...")
translation_model_name = "Helsinki-NLP/opus-mt-en-zh"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)
print("Translation model loaded successfully.")
# Translation function: English -> Chinese
def translate_to_chinese(text):
    try:
        # Make sure the input is a string
        if not isinstance(text, str):
            print(f"Input is not a string: {text} (type: {type(text)})")
            text = str(text)  # force-convert to a string
        print("Input text for translation:", text)
        tokenized_text = translation_tokenizer(text, return_tensors="pt", max_length=512, truncation=True).to(device)
        translated_tokens = translation_model.generate(**tokenized_text)
        translated_text = translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
        print("Translated text:", translated_text)
        return translated_text
    except Exception as e:
        print("Translation error:", str(e))
        return f"Translation error: {str(e)}"
# Generate a caption for the image and translate it into Chinese
def generate_caption(image):
    try:
        # If the input is a URL, download the image
        if isinstance(image, str) and (image.startswith("http://") or image.startswith("https://")):
            print("Downloading image from URL...")
            try:
                response = requests.get(image, stream=True, timeout=10)
                response.raise_for_status()  # make sure the request succeeded
                image = Image.open(response.raw)
                print("Image downloaded successfully.")
            except requests.exceptions.RequestException as e:
                return f"Failed to download image: {str(e)}"
        # Otherwise treat the input as a file path and open the image directly
        else:
            print("Loading image from file...")
            image = Image.open(image)

        # Prepare the model inputs
        prompt = "<MORE_DETAILED_CAPTION>"
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)

        # Generate the caption tokens
        print("Generating caption...")
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        print("Generated text:", generated_text)
        print("Type of generated text:", type(generated_text))

        # Parse the generated text into the final caption
        parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
        print("Parsed answer:", parsed_answer)
        print("Type of parsed answer:", type(parsed_answer))
        # post_process_generation typically returns a dict keyed by the task token;
        # pull out the caption string so the translator is not fed a stringified dict.
        if isinstance(parsed_answer, dict):
            parsed_answer = parsed_answer.get(prompt, parsed_answer)

        # Translate into Chinese
        print("Translating to Chinese...")
        translated_answer = translate_to_chinese(parsed_answer)
        print("Translation completed.")
        return translated_answer
    except Exception as e:
        print("Error:", str(e))
        return f"Error: {str(e)}"
# Gradio handler
def gradio_interface(image):
    result = generate_caption(image)
    return result

# Create the Gradio app
iface = gr.Interface(
    fn=gradio_interface,  # handler function
    inputs=gr.Image(label="Upload Image or Enter Image URL", type="filepath"),  # input component
    outputs=gr.Textbox(label="Generated Caption (Translated to Chinese)"),  # output component
    title="紫外线",  # title
    description="Generate detailed captions for images using Florence-2 model and translate them to Chinese. You can upload an image or provide an image URL.",  # description
    examples=[
        ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"]
    ]  # examples
)
# Launch the Gradio app
print("Launching Gradio app...")
iface.launch()
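# Remote usage sketch (assumptions: the Space is public, gradio_client is installed on the
# client machine, and "user/space-name" is a placeholder for the actual Space id):
#   from gradio_client import Client
#   client = Client("user/space-name")
#   result = client.predict("path/or/url/to/image.jpg", api_name="/predict")
#   print(result)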