X-iZhang commited on
Commit
869a62d
·
verified ·
1 Parent(s): 1bd7c26

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +207 -0
app.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ import base64
4
+ import os
5
+
6
+ # argparse 的简单替代,如有需要可替换为 argparse
7
+ from argparse import Namespace
8
+
9
+ # 假设 libra_eval 在你的 python 包 libra.eval 中
10
+ from libra.eval import libra_eval
11
+
12
+ # 预定义图像及其链接(或者本地路径)
13
+ DEFAULT_IMAGES = {
14
+ "Image 1": "https://drive.google.com/uc?export=view&id=10bvR7a4WSyDAtWsNQUjPSs1GlcSxtP81",
15
+ "Image 2": "https://drive.google.com/uc?export=view&id=1yzKM1eo8yBAGRcm7ayqUhxASXQHNANUa"
16
+ }
17
+
18
+ ###############################################################################
19
+ # 如果需要直接使用本地文件,可将以上链接替换为本地路径,比如:
20
+ # DEFAULT_IMAGES = {
21
+ # "Image 1": "/path/to/local/file1.jpg",
22
+ # "Image 2": "/path/to/local/file2.jpg"
23
+ # }
24
+ ###############################################################################
25
+
26
+ def image_url_to_base64(image_url: str) -> str:
27
+ """
28
+ 将远程图片 URL 转换为 Base64 数据 URI。
29
+ 如果请求失败,则返回提示文本。
30
+ """
31
+ try:
32
+ response = requests.get(image_url)
33
+ response.raise_for_status()
34
+ base64_image = base64.b64encode(response.content).decode("utf-8")
35
+ return f"data:image/jpeg;base64,{base64_image}"
36
+ except Exception as e:
37
+ return f"<p style='color: red;'>Failed to load image: {e}</p>"
38
+
39
+ def generate_image_html(image_url: str) -> str:
40
+ """
41
+ 生成一个 <img> 标签的 HTML,用于在 Gradio 中以预览形式显示图片。
42
+ 如果是 http(s) 链接,则尝试转换为 Base64;如果是本地路径,直接使用 file://。
43
+ """
44
+ # 判断是否以 http(s) 开头
45
+ if image_url.startswith("http"):
46
+ base64_image = image_url_to_base64(image_url)
47
+ return f'<img src="{base64_image}" style="width: 200px; height: auto; display: inline-block; margin: 10px; border-radius: 10px;" />'
48
+ else:
49
+ # 直接使用本地路径
50
+ return f'<img src="file://{image_url}" style="width: 200px; height: auto; display: inline-block; margin: 10px; border-radius: 10px;" />'
51
+
52
+ def generate_radiology_description(
53
+ prompt: str,
54
+ selected_current: str,
55
+ uploaded_current: str,
56
+ selected_prior: str,
57
+ uploaded_prior: str,
58
+ temperature: float,
59
+ top_p: float,
60
+ num_beams: int,
61
+ max_new_tokens: int
62
+ ) -> str:
63
+ """
64
+ 核心推理函数:
65
+ 1. 获取用户输入或默认图片
66
+ 2. 调用 libra_eval 来生成报告描述
67
+ 3. 返回生成的结果或错误消息
68
+ """
69
+ # 如果用户上传了图片,则优先使用上传的图片;否则使用默认图片
70
+ current_image = uploaded_current if uploaded_current else DEFAULT_IMAGES.get(selected_current)
71
+ prior_image = uploaded_prior if uploaded_prior else DEFAULT_IMAGES.get(selected_prior)
72
+
73
+ # 确保用户选择或上传了两张图片
74
+ if not current_image or not prior_image:
75
+ return "Please select or upload both current and prior images."
76
+
77
+ # 模型路径(示例)
78
+ model_path = "/nfs/LLaVA-ai4bio/gla-biomed-playground/final_model/finetuned_model/llava-libra-test"
79
+ conv_mode = "libra_v1"
80
+
81
+ try:
82
+ # 调用 libra_eval 进行推理
83
+ output = libra_eval(
84
+ model_path=model_path,
85
+ model_base=None,
86
+ image_file=[current_image, prior_image],
87
+ query=prompt,
88
+ temperature=temperature,
89
+ top_p=top_p,
90
+ num_beams=num_beams,
91
+ length_penalty=1.0,
92
+ num_return_sequences=1,
93
+ conv_mode=conv_mode,
94
+ max_new_tokens=max_new_tokens
95
+ )
96
+ return output
97
+ except Exception as e:
98
+ return f"An error occurred: {str(e)}"
99
+
100
+ # 在 Gradio 中构建 UI
101
+ # Blocks 为最新的容器API,可以更好地对布局进行控制
102
+ with gr.Blocks() as demo:
103
+ # 标题和简单说明
104
+ gr.Markdown("# Libra Radiology Report Generator")
105
+ gr.Markdown("Use **Libra** to generate radiology image descriptions. Provide a **Current** and a **Prior** image below.")
106
+
107
+ # 用户输入的文本
108
+ with gr.Row():
109
+ prompt_input = gr.Textbox(
110
+ label="Prompt",
111
+ value="Provide a detailed description of the findings in the radiology image."
112
+ )
113
+
114
+ # 当前图像(Current Image)和历史对比图像(Prior Image)
115
+ with gr.Row():
116
+ with gr.Column():
117
+ gr.Markdown("### Current Image")
118
+ # 预览默认图像
119
+ for img in DEFAULT_IMAGES.values():
120
+ gr.HTML(generate_image_html(img))
121
+ # 在Radio中选择
122
+ selected_current = gr.Radio(
123
+ label="Select Current Image",
124
+ choices=list(DEFAULT_IMAGES.keys()),
125
+ value="Image 1"
126
+ )
127
+ # 或者上传一张新的
128
+ uploaded_current = gr.Image(
129
+ label="Or Upload Current Image",
130
+ type="filepath",
131
+ tool="editor"
132
+ )
133
+
134
+ with gr.Column():
135
+ gr.Markdown("### Prior Image")
136
+ # 同样显示默认图像
137
+ for img in DEFAULT_IMAGES.values():
138
+ gr.HTML(generate_image_html(img))
139
+ selected_prior = gr.Radio(
140
+ label="Select Prior Image",
141
+ choices=list(DEFAULT_IMAGES.keys()),
142
+ value="Image 2"
143
+ )
144
+ uploaded_prior = gr.Image(
145
+ label="Or Upload Prior Image",
146
+ type="filepath",
147
+ tool="editor"
148
+ )
149
+
150
+ # 一些可调参数
151
+ with gr.Row():
152
+ temperature_slider = gr.Slider(
153
+ label="Temperature",
154
+ minimum=0.1,
155
+ maximum=1.0,
156
+ step=0.1,
157
+ value=0.7
158
+ )
159
+ top_p_slider = gr.Slider(
160
+ label="Top P",
161
+ minimum=0.1,
162
+ maximum=1.0,
163
+ step=0.1,
164
+ value=0.8
165
+ )
166
+ num_beams_slider = gr.Slider(
167
+ label="Number of Beams",
168
+ minimum=1,
169
+ maximum=20,
170
+ step=1,
171
+ value=2
172
+ )
173
+ max_tokens_slider = gr.Slider(
174
+ label="Max New Tokens",
175
+ minimum=10,
176
+ maximum=4096,
177
+ step=10,
178
+ value=128
179
+ )
180
+
181
+ # 用于显示模型生成的结果
182
+ output_text = gr.Textbox(
183
+ label="Generated Description",
184
+ lines=10
185
+ )
186
+
187
+ # 点击按钮时触发的推理逻辑
188
+ generate_button = gr.Button("Generate Description")
189
+ generate_button.click(
190
+ fn=generate_radiology_description,
191
+ inputs=[
192
+ prompt_input,
193
+ selected_current,
194
+ uploaded_current,
195
+ selected_prior,
196
+ uploaded_prior,
197
+ temperature_slider,
198
+ top_p_slider,
199
+ num_beams_slider,
200
+ max_tokens_slider
201
+ ],
202
+ outputs=output_text
203
+ )
204
+
205
+ # 启动 Gradio 应用(将 share 设置为 True 以便在 Hugging Face Spaces 中分享)
206
+ if __name__ == "__main__":
207
+ demo.launch(share=True)