ziwaixian009 committed
Commit 8c9c8c2 · verified · 1 Parent(s): afa88bb

Update app.py

Files changed (1)
  1. app.py +71 -11
app.py CHANGED
@@ -7,27 +7,76 @@
 # demo.launch()
 import requests
 from PIL import Image
-from transformers import AutoModelForCausalLM, AutoProcessor
+from transformers import AutoModelForCausalLM, AutoProcessor, MarianMTModel, MarianTokenizer
 import torch
 import gradio as gr
 
+# Check that SentencePiece is installed
+try:
+    import sentencepiece
+    print("SentencePiece is installed successfully!")
+except ImportError:
+    print("SentencePiece is NOT installed!")
+
 # Set the device
 device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Using device: {device}")
 
-# Load the model and processor
+# Load the Florence-2 model and processor
+print("Loading Florence-2 model...")
 model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v1.5", trust_remote_code=True).to(device)
 processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v1.5", trust_remote_code=True)
+print("Florence-2 model loaded successfully.")
+
+# Load the Helsinki-NLP translation model (English to Chinese)
+print("Loading translation model...")
+translation_model_name = "Helsinki-NLP/opus-mt-en-zh"
+translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
+translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)
+print("Translation model loaded successfully.")
 
-def generate_caption(image_url):
+# Translation function
+def translate_to_chinese(text):
     try:
-        # Download and open the image
-        image = Image.open(requests.get(image_url, stream=True).raw)
+        # Make sure the input is a string
+        if not isinstance(text, str):
+            print(f"Input is not a string: {text} (type: {type(text)})")
+            text = str(text)  # force conversion to a string
+
+        print("Input text for translation:", text)
+        tokenized_text = translation_tokenizer(text, return_tensors="pt", max_length=512, truncation=True).to(device)
+        translated_tokens = translation_model.generate(**tokenized_text)
+        translated_text = translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
+        print("Translated text:", translated_text)
+        return translated_text
+    except Exception as e:
+        print("Translation error:", str(e))
+        return f"Translation error: {str(e)}"
+
+# Generate a caption and translate it
+def generate_caption(image):
+    try:
+        # If the input is a URL, download the image
+        if isinstance(image, str):
+            print("Downloading image from URL...")
+            try:
+                response = requests.get(image, stream=True, timeout=10)
+                response.raise_for_status()  # check that the request succeeded
+                image = Image.open(response.raw)
+                print("Image downloaded successfully.")
+            except requests.exceptions.RequestException as e:
+                return f"Failed to download image: {str(e)}"
+        # If the input is a file path, open the image directly
+        else:
+            print("Loading image from file...")
+            image = Image.open(image)
 
         # Prepare the inputs
         prompt = "<MORE_DETAILED_CAPTION>"
         inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
 
         # Generate text
+        print("Generating caption...")
         generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
@@ -36,30 +85,41 @@ def generate_caption(image_url):
            num_beams=3
         )
         generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+        print("Generated text:", generated_text)
+        print("Type of generated text:", type(generated_text))
 
         # Parse the generated text
         parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
+        print("Parsed answer:", parsed_answer)
+        print("Type of parsed answer:", type(parsed_answer))
+
+        # Translate into Chinese
+        print("Translating to Chinese...")
+        translated_answer = translate_to_chinese(parsed_answer)
+        print("Translation completed.")
 
-        return parsed_answer
+        return translated_answer
     except Exception as e:
+        print("Error:", str(e))
         return f"Error: {str(e)}"
 
 # Gradio interface
-def gradio_interface(image_url):
-    result = generate_caption(image_url)
+def gradio_interface(image):
+    result = generate_caption(image)
     return result
 
 # Create the Gradio app
 iface = gr.Interface(
     fn=gradio_interface,  # handler function
-    inputs=gr.Textbox(label="Image URL", placeholder="Enter the URL of the image..."),  # input component
-    outputs=gr.Textbox(label="Generated Caption"),  # output component
+    inputs=gr.Image(label="Upload Image or Enter Image URL", type="filepath"),  # input component
+    outputs=gr.Textbox(label="Generated Caption (Translated to Chinese)"),  # output component
     title="Florence-2 Prompt Generation",  # title
-    description="Generate detailed captions for images using Florence-2 model.",  # description
+    description="Generate detailed captions for images using Florence-2 model and translate them to Chinese. You can upload an image or provide an image URL.",  # description
     examples=[
         ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"]
     ]  # examples
 )
 
 # Launch the Gradio app
+print("Launching Gradio app...")
 iface.launch()
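
As a quick sanity check of the English-to-Chinese translation step introduced in this commit, the Marian model can be exercised on its own, outside the Gradio app. The following is a minimal sketch, assuming the same Helsinki-NLP/opus-mt-en-zh checkpoint used in app.py; the sample caption string is purely illustrative.

# Standalone check of the translation step (sketch, not part of app.py).
# Assumes transformers and sentencepiece are installed; the sample caption is illustrative only.
from transformers import MarianMTModel, MarianTokenizer

model_name = "Helsinki-NLP/opus-mt-en-zh"  # same checkpoint as in app.py
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)

sample_caption = "A red car is parked on the side of the street."
inputs = tokenizer(sample_caption, return_tensors="pt", max_length=512, truncation=True)
translated_ids = model.generate(**inputs)
print(tokenizer.decode(translated_ids[0], skip_special_tokens=True))

Note that in app.py the value passed to translate_to_chinese is whatever post_process_generation returns (not necessarily a plain string), which is presumably why the helper coerces non-string inputs with str(text) before tokenizing.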