Spaces:
Running
Running
File size: 7,053 Bytes
869a62d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
import gradio as gr
import requests
import base64
import os
# argparse 的简单替代,如有需要可替换为 argparse
from argparse import Namespace
# 假设 libra_eval 在你的 python 包 libra.eval 中
from libra.eval import libra_eval
# 预定义图像及其链接(或者本地路径)
DEFAULT_IMAGES = {
"Image 1": "https://drive.google.com/uc?export=view&id=10bvR7a4WSyDAtWsNQUjPSs1GlcSxtP81",
"Image 2": "https://drive.google.com/uc?export=view&id=1yzKM1eo8yBAGRcm7ayqUhxASXQHNANUa"
}
###############################################################################
# 如果需要直接使用本地文件,可将以上链接替换为本地路径,比如:
# DEFAULT_IMAGES = {
# "Image 1": "/path/to/local/file1.jpg",
# "Image 2": "/path/to/local/file2.jpg"
# }
###############################################################################
def image_url_to_base64(image_url: str) -> str:
"""
将远程图片 URL 转换为 Base64 数据 URI。
如果请求失败,则返回提示文本。
"""
try:
response = requests.get(image_url)
response.raise_for_status()
base64_image = base64.b64encode(response.content).decode("utf-8")
return f"data:image/jpeg;base64,{base64_image}"
except Exception as e:
return f"<p style='color: red;'>Failed to load image: {e}</p>"
def generate_image_html(image_url: str) -> str:
"""
生成一个 <img> 标签的 HTML,用于在 Gradio 中以预览形式显示图片。
如果是 http(s) 链接,则尝试转换为 Base64;如果是本地路径,直接使用 file://。
"""
# 判断是否以 http(s) 开头
if image_url.startswith("http"):
base64_image = image_url_to_base64(image_url)
return f'<img src="{base64_image}" style="width: 200px; height: auto; display: inline-block; margin: 10px; border-radius: 10px;" />'
else:
# 直接使用本地路径
return f'<img src="file://{image_url}" style="width: 200px; height: auto; display: inline-block; margin: 10px; border-radius: 10px;" />'
def generate_radiology_description(
prompt: str,
selected_current: str,
uploaded_current: str,
selected_prior: str,
uploaded_prior: str,
temperature: float,
top_p: float,
num_beams: int,
max_new_tokens: int
) -> str:
"""
核心推理函数:
1. 获取用户输入或默认图片
2. 调用 libra_eval 来生成报告描述
3. 返回生成的结果或错误消息
"""
# 如果用户上传了图片,则优先使用上传的图片;否则使用默认图片
current_image = uploaded_current if uploaded_current else DEFAULT_IMAGES.get(selected_current)
prior_image = uploaded_prior if uploaded_prior else DEFAULT_IMAGES.get(selected_prior)
# 确保用户选择或上传了两张图片
if not current_image or not prior_image:
return "Please select or upload both current and prior images."
# 模型路径(示例)
model_path = "/nfs/LLaVA-ai4bio/gla-biomed-playground/final_model/finetuned_model/llava-libra-test"
conv_mode = "libra_v1"
try:
# 调用 libra_eval 进行推理
output = libra_eval(
model_path=model_path,
model_base=None,
image_file=[current_image, prior_image],
query=prompt,
temperature=temperature,
top_p=top_p,
num_beams=num_beams,
length_penalty=1.0,
num_return_sequences=1,
conv_mode=conv_mode,
max_new_tokens=max_new_tokens
)
return output
except Exception as e:
return f"An error occurred: {str(e)}"
# 在 Gradio 中构建 UI
# Blocks 为最新的容器API,可以更好地对布局进行控制
with gr.Blocks() as demo:
# 标题和简单说明
gr.Markdown("# Libra Radiology Report Generator")
gr.Markdown("Use **Libra** to generate radiology image descriptions. Provide a **Current** and a **Prior** image below.")
# 用户输入的文本
with gr.Row():
prompt_input = gr.Textbox(
label="Prompt",
value="Provide a detailed description of the findings in the radiology image."
)
# 当前图像(Current Image)和历史对比图像(Prior Image)
with gr.Row():
with gr.Column():
gr.Markdown("### Current Image")
# 预览默认图像
for img in DEFAULT_IMAGES.values():
gr.HTML(generate_image_html(img))
# 在Radio中选择
selected_current = gr.Radio(
label="Select Current Image",
choices=list(DEFAULT_IMAGES.keys()),
value="Image 1"
)
# 或者上传一张新的
uploaded_current = gr.Image(
label="Or Upload Current Image",
type="filepath",
tool="editor"
)
with gr.Column():
gr.Markdown("### Prior Image")
# 同样显示默认图像
for img in DEFAULT_IMAGES.values():
gr.HTML(generate_image_html(img))
selected_prior = gr.Radio(
label="Select Prior Image",
choices=list(DEFAULT_IMAGES.keys()),
value="Image 2"
)
uploaded_prior = gr.Image(
label="Or Upload Prior Image",
type="filepath",
tool="editor"
)
# 一些可调参数
with gr.Row():
temperature_slider = gr.Slider(
label="Temperature",
minimum=0.1,
maximum=1.0,
step=0.1,
value=0.7
)
top_p_slider = gr.Slider(
label="Top P",
minimum=0.1,
maximum=1.0,
step=0.1,
value=0.8
)
num_beams_slider = gr.Slider(
label="Number of Beams",
minimum=1,
maximum=20,
step=1,
value=2
)
max_tokens_slider = gr.Slider(
label="Max New Tokens",
minimum=10,
maximum=4096,
step=10,
value=128
)
# 用于显示模型生成的结果
output_text = gr.Textbox(
label="Generated Description",
lines=10
)
# 点击按钮时触发的推理逻辑
generate_button = gr.Button("Generate Description")
generate_button.click(
fn=generate_radiology_description,
inputs=[
prompt_input,
selected_current,
uploaded_current,
selected_prior,
uploaded_prior,
temperature_slider,
top_p_slider,
num_beams_slider,
max_tokens_slider
],
outputs=output_text
)
# 启动 Gradio 应用(将 share 设置为 True 以便在 Hugging Face Spaces 中分享)
if __name__ == "__main__":
demo.launch(share=True) |