File size: 5,029 Bytes
5cf2b53
68c33ac
5cf2b53
 
68c33ac
5cf2b53
 
 
 
8c9c8c2
5cf2b53
b71ac5f
5cf2b53
8c9c8c2
 
 
 
 
 
 
5cf2b53
 
8c9c8c2
5cf2b53
8c9c8c2
 
5cf2b53
 
8c9c8c2
 
 
 
 
 
 
 
8f11c9f
8c9c8c2
 
b71ac5f
8c9c8c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ded79ec
8c9c8c2
 
 
 
 
 
 
 
 
 
 
 
b71ac5f
 
 
 
 
 
8c9c8c2
b71ac5f
 
 
 
 
 
 
 
8c9c8c2
 
b71ac5f
 
 
8c9c8c2
 
 
 
 
 
 
8f11c9f
8c9c8c2
b71ac5f
8c9c8c2
b71ac5f
 
 
8c9c8c2
 
b71ac5f
 
 
 
 
8c9c8c2
 
d53bb63
8c9c8c2
b71ac5f
 
 
 
 
 
8c9c8c2
b71ac5f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# import gradio as gr

# def greet(name):
#     return "Hello " + name + "!!"

# demo = gr.Interface(fn=greet, inputs="text", outputs="text")
# demo.launch()
import io

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, MarianMTModel, MarianTokenizer

# Report whether SentencePiece is importable (MarianTokenizer is SentencePiece-based).
try:
    import sentencepiece  # noqa: F401  (imported only to probe availability)
except ImportError:
    print("SentencePiece is NOT installed!")
else:
    print("SentencePiece is installed successfully!")

# Pick the compute device: CUDA when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the Florence-2 captioning model and its processor.
# Both must come from the same checkpoint, so the id is kept in one place.
# trust_remote_code=True lets transformers run the modeling/processing code
# bundled with the checkpoint.
florence_model_name = "MiaoshouAI/Florence-2-base-PromptGen-v1.5"
print("Loading Florence-2 model...")
model = AutoModelForCausalLM.from_pretrained(florence_model_name, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained(florence_model_name, trust_remote_code=True)
print("Florence-2 model loaded successfully.")

# Load the Helsinki-NLP English -> Chinese translation model and tokenizer.
print("Loading translation model...")
translation_model_name = "Helsinki-NLP/opus-mt-en-zh"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)
print("Translation model loaded successfully.")

# Translation helper (English -> Chinese).
def translate_to_chinese(text):
    """Translate English *text* to Chinese using the module-level Marian model.

    Non-string input is logged and coerced with ``str()`` before translation.
    Input longer than 512 tokens is truncated by the tokenizer.

    Returns:
        The translated string, or ``"Translation error: <message>"`` if any
        step raises. Exceptions are never propagated to the caller.
    """
    try:
        if isinstance(text, str):
            source = text
        else:
            print(f"Input is not a string: {text} (type: {type(text)})")
            source = str(text)  # coerce to a string before tokenizing

        print("Input text for translation:", source)
        encoded = translation_tokenizer(
            source,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        ).to(device)
        output_ids = translation_model.generate(**encoded)
        # A single input sentence was encoded, so take the first sequence.
        result = translation_tokenizer.decode(output_ids[0], skip_special_tokens=True)
        print("Translated text:", result)
        return result
    except Exception as err:
        print("Translation error:", str(err))
        return f"Translation error: {str(err)}"

# Caption generation and translation.
def _load_image(image):
    """Return an RGB PIL image from a URL string, a local file path, or a PIL image.

    Raises:
        requests.exceptions.RequestException: if downloading a URL fails.
        OSError: if the file or downloaded bytes cannot be opened as an image.
    """
    if isinstance(image, Image.Image):
        return image.convert("RGB")
    if isinstance(image, str) and image.startswith(("http://", "https://")):
        print("Downloading image from URL...")
        # Read the full (Content-Encoding-decoded) body and close the
        # connection. Image.open is lazy, so it must not be handed a raw stream
        # that is consumed or closed later.
        with requests.get(image, timeout=10) as response:
            response.raise_for_status()
            payload = response.content
        with Image.open(io.BytesIO(payload)) as opened:
            rgb = opened.convert("RGB")
        print("Image downloaded successfully.")
        return rgb
    print("Loading image from file...")
    # The context manager closes the file handle once the pixels are copied out.
    with Image.open(image) as opened:
        # Normalise RGBA / greyscale / palette images to 3 channels for the processor.
        return opened.convert("RGB")


def _extract_caption(parsed_answer, task):
    """Return the caption text from Florence-2's post-processed output.

    The checkpoint's ``post_process_generation`` is expected to return a dict
    keyed by the task token (e.g. ``{"<MORE_DETAILED_CAPTION>": "..."}``).
    Forwarding that dict would translate its Python repr, braces and all.
    Any other value is stringified as-is.
    """
    if isinstance(parsed_answer, dict) and task in parsed_answer:
        parsed_answer = parsed_answer[task]
    return str(parsed_answer).strip()


def generate_caption(image):
    """Caption *image* with Florence-2 and return the caption translated to Chinese.

    Args:
        image: An http(s) URL, a local file path, or a PIL image.

    Returns:
        The Chinese caption. On failure, returns an error string instead of
        raising: ``"Error: No image provided."`` for ``None``,
        ``"Failed to download image: ..."`` for network errors, and
        ``"Error: ..."`` for anything else.
    """
    try:
        if image is None:
            return "Error: No image provided."
        try:
            image = _load_image(image)
        except requests.exceptions.RequestException as e:
            return f"Failed to download image: {str(e)}"

        # Prepare model inputs for the detailed-caption task.
        prompt = "<MORE_DETAILED_CAPTION>"
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)

        # Generate the caption deterministically with beam search.
        print("Generating caption...")
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )
        # Keep special tokens: post_process_generation parses the raw output.
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        print("Generated text:", generated_text)

        parsed_answer = processor.post_process_generation(
            generated_text, task=prompt, image_size=(image.width, image.height)
        )
        print("Parsed answer:", parsed_answer)
        caption = _extract_caption(parsed_answer, prompt)

        # Translate only the caption text, not the wrapping dict.
        print("Translating to Chinese...")
        translated_answer = translate_to_chinese(caption)
        print("Translation completed.")

        return translated_answer
    except Exception as e:
        print("Error:", str(e))
        return f"Error: {str(e)}"

# Gradio callback.
def gradio_interface(image):
    """Gradio handler: caption *image* and return the Chinese translation.

    This is a thin delegate to ``generate_caption``. ``image`` is whatever the
    input component supplies; with ``type="filepath"`` that is a file path.
    """
    return generate_caption(image)

# Build the Gradio app. The image input is handed to the callback as a file
# path; the text output shows the Chinese caption.
image_input = gr.Image(label="Upload Image or Enter Image URL", type="filepath")
caption_output = gr.Textbox(label="Generated Caption (Translated to Chinese)")
example_inputs = [
    ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"],
]

iface = gr.Interface(
    fn=gradio_interface,
    inputs=image_input,
    outputs=caption_output,
    title="紫外线",
    description="Generate detailed captions for images using Florence-2 model and translate them to Chinese. You can upload an image or provide an image URL.",
    examples=example_inputs,
)

# Start the Gradio server (runs at import/execution time of this script).
print("Launching Gradio app...")
iface.launch()