File size: 7,053 Bytes
869a62d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import gradio as gr
import requests
import base64
import os

# Lightweight stand-in for argparse; swap in argparse itself if needed
from argparse import Namespace

# Assumes libra_eval is importable from your Python package libra.eval
from libra.eval import libra_eval

# Predefined images: maps the label offered in the UI radio buttons to the
# image's URL (or local path).
DEFAULT_IMAGES = {
    "Image 1": "https://drive.google.com/uc?export=view&id=10bvR7a4WSyDAtWsNQUjPSs1GlcSxtP81",
    "Image 2": "https://drive.google.com/uc?export=view&id=1yzKM1eo8yBAGRcm7ayqUhxASXQHNANUa"
}

###############################################################################
# To use local files directly, replace the links above with local paths, e.g.:
# DEFAULT_IMAGES = {
#     "Image 1": "/path/to/local/file1.jpg",
#     "Image 2": "/path/to/local/file2.jpg"
# }
###############################################################################

def image_url_to_base64(image_url: str, timeout: float = 10.0) -> str:
    """
    Download a remote image and return it as a Base64 data URI.

    Args:
        image_url: HTTP(S) URL of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests.get`` may block indefinitely on a stalled
            connection.

    Returns:
        A ``data:<mime>;base64,<payload>`` URI on success. The MIME type is
        taken from the response's ``Content-Type`` header when that header
        names an image type, and falls back to ``image/jpeg`` otherwise.
        If the request fails, a red HTML ``<p>`` snippet describing the
        error is returned instead of a data URI.
    """
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        payload = response.content
    except requests.RequestException as e:
        # Best-effort: report the failure as displayable HTML, do not raise.
        return f"<p style='color: red;'>Failed to load image: {e}</p>"

    # Drop any parameters (e.g. "; charset=...") from the Content-Type header.
    content_type = response.headers.get("Content-Type", "")
    mime_type = content_type.split(";", 1)[0].strip().lower()
    if not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    base64_image = base64.b64encode(payload).decode("utf-8")
    return f"data:{mime_type};base64,{base64_image}"

def generate_image_html(image_url: str) -> str:
    """
    Build the HTML used to preview an image inside Gradio.

    Args:
        image_url: Either an ``http://`` / ``https://`` URL or a local path.

    Returns:
        For a remote URL, an ``<img>`` tag whose ``src`` is the Base64 data
        URI produced by ``image_url_to_base64``; if that helper returns
        something other than a data URI (its error snippet), that text is
        returned unchanged so the error is shown instead of a broken image.
        For anything else, an ``<img>`` tag pointing at ``file://<path>``.
    """
    from html import escape  # function-scope import keeps the block self-contained

    img_style = "width: 200px; height: auto; display: inline-block; margin: 10px; border-radius: 10px;"

    # Match full scheme prefixes so a local path such as "http_cache/x.png"
    # is not mistaken for a remote URL.
    if image_url.startswith(("http://", "https://")):
        data_uri = image_url_to_base64(image_url)
        if not data_uri.startswith("data:"):
            # The helper reported a failure; do not embed it in a src attribute.
            return data_uri
        return f'<img src="{data_uri}" style="{img_style}" />'

    # Local path: escape it so quotes or ampersands cannot break the attribute.
    return f'<img src="file://{escape(image_url, quote=True)}" style="{img_style}" />'

def generate_radiology_description(
    prompt: str,
    selected_current: str,
    uploaded_current: str,
    selected_prior: str,
    uploaded_prior: str,
    temperature: float,
    top_p: float,
    num_beams: int,
    max_new_tokens: int
) -> str:
    """
    Run Libra on a current/prior image pair and return the generated text.

    For each of the two images, an uploaded file path wins over the radio
    selection; the selection is looked up in ``DEFAULT_IMAGES``. If either
    image is still missing, a prompt asking for both images is returned.
    Otherwise ``libra_eval`` is called with the chosen images, the prompt and
    the sampling settings, and its result is returned as-is. Any exception
    raised by that call is reported as an "An error occurred: ..." string
    rather than propagated.
    """
    # Uploaded image takes precedence; fall back to the selected default.
    current_src = uploaded_current or DEFAULT_IMAGES.get(selected_current)
    prior_src = uploaded_prior or DEFAULT_IMAGES.get(selected_prior)

    # Both images are required.
    if not (current_src and prior_src):
        return "Please select or upload both current and prior images."

    # Inference settings; the model path is an example location.
    eval_kwargs = dict(
        model_path="/nfs/LLaVA-ai4bio/gla-biomed-playground/final_model/finetuned_model/llava-libra-test",
        model_base=None,
        image_file=[current_src, prior_src],
        query=prompt,
        temperature=temperature,
        top_p=top_p,
        num_beams=num_beams,
        length_penalty=1.0,
        num_return_sequences=1,
        conv_mode="libra_v1",
        max_new_tokens=max_new_tokens,
    )

    try:
        return libra_eval(**eval_kwargs)
    except Exception as err:
        return f"An error occurred: {str(err)}"

# Build the UI with Gradio.
# Blocks is the container API that gives finer control over layout.
with gr.Blocks() as demo:
    # Title and short description
    gr.Markdown("# Libra Radiology Report Generator")
    gr.Markdown("Use **Libra** to generate radiology image descriptions. Provide a **Current** and a **Prior** image below.")

    # Free-text prompt from the user
    with gr.Row():
        prompt_input = gr.Textbox(
            label="Prompt",
            value="Provide a detailed description of the findings in the radiology image."
        )

    # Render each default-image preview once and reuse it in both columns,
    # so remote images are not downloaded twice while the UI is being built.
    default_previews = [generate_image_html(url) for url in DEFAULT_IMAGES.values()]

    # Current image and the prior image it is compared against
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Current Image")
            # Preview the default images
            for preview_html in default_previews:
                gr.HTML(preview_html)
            # Pick one of the defaults via radio buttons
            selected_current = gr.Radio(
                label="Select Current Image",
                choices=list(DEFAULT_IMAGES.keys()),
                value="Image 1"
            )
            # ...or upload a new one
            # NOTE(review): `tool` appears to have been removed from gr.Image
            # in Gradio 4.x — confirm the pinned Gradio version.
            uploaded_current = gr.Image(
                label="Or Upload Current Image",
                type="filepath",
                tool="editor"
            )

        with gr.Column():
            gr.Markdown("### Prior Image")
            # Same default-image previews as in the other column
            for preview_html in default_previews:
                gr.HTML(preview_html)
            selected_prior = gr.Radio(
                label="Select Prior Image",
                choices=list(DEFAULT_IMAGES.keys()),
                value="Image 2"
            )
            uploaded_prior = gr.Image(
                label="Or Upload Prior Image",
                type="filepath",
                tool="editor"
            )

    # Tunable generation parameters
    with gr.Row():
        temperature_slider = gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=1.0,
            step=0.1,
            value=0.7
        )
        top_p_slider = gr.Slider(
            label="Top P",
            minimum=0.1,
            maximum=1.0,
            step=0.1,
            value=0.8
        )
        num_beams_slider = gr.Slider(
            label="Number of Beams",
            minimum=1,
            maximum=20,
            step=1,
            value=2
        )
        max_tokens_slider = gr.Slider(
            label="Max New Tokens",
            minimum=10,
            maximum=4096,
            step=10,
            value=128
        )

    # Shows the text generated by the model
    output_text = gr.Textbox(
        label="Generated Description",
        lines=10
    )

    # Run inference when the button is clicked; the input order must match
    # the parameter order of generate_radiology_description.
    generate_button = gr.Button("Generate Description")
    generate_button.click(
        fn=generate_radiology_description,
        inputs=[
            prompt_input,
            selected_current,
            uploaded_current,
            selected_prior,
            uploaded_prior,
            temperature_slider,
            top_p_slider,
            num_beams_slider,
            max_tokens_slider
        ],
        outputs=output_text
    )

# Launch the Gradio app; share=True requests a public share link.
# NOTE(review): Gradio reports share=True as unsupported when running on
# Hugging Face Spaces — confirm whether it is still wanted there.
if __name__ == "__main__":
    demo.launch(share=True)