Spaces:
Running
Running
import gradio as gr | |
import requests | |
import base64 | |
import os | |
# argparse 的简单替代,如有需要可替换为 argparse | |
from argparse import Namespace | |
# 假设 libra_eval 在你的 python 包 libra.eval 中 | |
from libra.eval import libra_eval | |
# Built-in sample images offered in the UI. Keys are the labels shown in the
# "Select ... Image" radio buttons; values are the image sources (remote
# Google Drive direct-view URLs here, or local file paths).
_DEFAULT_IMAGE_SOURCES = (
    ("Image 1", "https://drive.google.com/uc?export=view&id=10bvR7a4WSyDAtWsNQUjPSs1GlcSxtP81"),
    ("Image 2", "https://drive.google.com/uc?export=view&id=1yzKM1eo8yBAGRcm7ayqUhxASXQHNANUa"),
)
DEFAULT_IMAGES = dict(_DEFAULT_IMAGE_SOURCES)
###############################################################################
# To use local files instead, replace the URLs above with local paths, e.g.:
# DEFAULT_IMAGES = {
#     "Image 1": "/path/to/local/file1.jpg",
#     "Image 2": "/path/to/local/file2.jpg"
# }
###############################################################################
def image_url_to_base64(image_url: str, timeout: float = 10.0) -> str:
    """Download a remote image and return it as a base64 ``data:`` URI.

    Args:
        image_url: HTTP(S) URL of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` can block indefinitely, which would hang the
            app while the UI (and its previews) is being built.

    Returns:
        ``"data:<mime>;base64,<payload>"`` on success. ``<mime>`` is the
        response's ``Content-Type`` when that is an ``image/*`` type, and
        ``image/jpeg`` otherwise. On a request/HTTP failure, an HTML ``<p>``
        snippet describing the error is returned instead, so callers must check
        for the ``data:`` prefix before using the value as an image source.
    """
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
    except requests.RequestException as e:
        return f"<p style='color: red;'>Failed to load image: {e}</p>"

    # Use the server-reported image type rather than assuming JPEG; drop any
    # parameters such as "; charset=...".
    content_type = response.headers.get("Content-Type", "").split(";", 1)[0].strip().lower()
    mime_type = content_type if content_type.startswith("image/") else "image/jpeg"

    base64_image = base64.b64encode(response.content).decode("utf-8")
    return f"data:{mime_type};base64,{base64_image}"
# Inline CSS shared by every preview thumbnail.
_PREVIEW_IMG_STYLE = (
    "width: 200px; height: auto; display: inline-block; "
    "margin: 10px; border-radius: 10px;"
)


def _local_file_to_data_uri(path: str) -> str:
    """Read a local image file and return it as a base64 ``data:`` URI.

    The MIME type is guessed from the file extension, falling back to
    ``image/jpeg``. If the file cannot be read, an HTML ``<p>`` error snippet
    is returned instead, in the same format used by ``image_url_to_base64``.
    """
    import mimetypes

    mime_type, _ = mimetypes.guess_type(path)
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"
    try:
        with open(path, "rb") as fh:
            encoded = base64.b64encode(fh.read()).decode("utf-8")
    except OSError as e:
        return f"<p style='color: red;'>Failed to load image: {e}</p>"
    return f"data:{mime_type};base64,{encoded}"


def generate_image_html(image_url: str) -> str:
    """Return an HTML snippet previewing ``image_url`` inside a ``gr.HTML``.

    Both remote (``http://`` / ``https://``) and local images are inlined as
    base64 data URIs. A ``file://`` URL is not used for local paths because the
    page is rendered by the user's browser, which cannot read the server's
    filesystem and typically blocks ``file://`` resources on http(s) pages.

    Args:
        image_url: Remote URL or local filesystem path of the image.

    Returns:
        An ``<img>`` tag on success. If the image could not be loaded, the
        error ``<p>`` snippet is returned on its own rather than being stuffed
        into the ``src`` attribute, where it would produce broken markup.
    """
    if image_url.startswith(("http://", "https://")):
        src = image_url_to_base64(image_url)
    else:
        src = _local_file_to_data_uri(image_url)

    # Both loaders return either a data URI or an HTML error snippet.
    if not src.startswith("data:"):
        return src
    return f'<img src="{src}" style="{_PREVIEW_IMG_STYLE}" />'
def generate_radiology_description(
    prompt: str,
    selected_current: str,
    uploaded_current: str,
    selected_prior: str,
    uploaded_prior: str,
    temperature: float,
    top_p: float,
    num_beams: int,
    max_new_tokens: int
) -> str:
    """Run Libra on a (current, prior) image pair and return the generated text.

    For each slot, an uploaded image takes precedence; otherwise the radio
    selection is resolved through ``DEFAULT_IMAGES``.

    Args:
        prompt: Text query passed to the model.
        selected_current: ``DEFAULT_IMAGES`` key chosen for the current image.
        uploaded_current: Uploaded current-image file path, or a falsy value
            when nothing was uploaded.
        selected_prior: ``DEFAULT_IMAGES`` key chosen for the prior image.
        uploaded_prior: Uploaded prior-image file path, or a falsy value.
        temperature: Sampling temperature forwarded to ``libra_eval``.
        top_p: Nucleus-sampling threshold forwarded to ``libra_eval``.
        num_beams: Beam count; converted to ``int`` before use.
        max_new_tokens: Generation length budget; converted to ``int``.

    Returns:
        The model output; a request to supply both images when either is
        missing; or ``"An error occurred: ..."`` if inference raised. Errors
        are logged and reported as text rather than raised so the UI stays
        responsive.
    """
    import logging

    # An uploaded image overrides the radio selection for the same slot.
    current_image = uploaded_current if uploaded_current else DEFAULT_IMAGES.get(selected_current)
    prior_image = uploaded_prior if uploaded_prior else DEFAULT_IMAGES.get(selected_prior)
    # The model needs both a current and a prior image.
    if not current_image or not prior_image:
        return "Please select or upload both current and prior images."

    # Checkpoint location. The default is a cluster-specific path; set the
    # LIBRA_MODEL_PATH environment variable to point elsewhere.
    model_path = os.environ.get("LIBRA_MODEL_PATH") or (
        "/nfs/LLaVA-ai4bio/gla-biomed-playground/final_model/finetuned_model/llava-libra-test"
    )
    conv_mode = "libra_v1"
    try:
        output = libra_eval(
            model_path=model_path,
            model_base=None,
            image_file=[current_image, prior_image],
            query=prompt,
            temperature=temperature,
            top_p=top_p,
            # Gradio sliders may deliver whole numbers as floats (e.g. 2.0);
            # libra_eval presumably forwards these to a generate() call that
            # requires integers, so normalize them here.
            num_beams=int(num_beams),
            length_penalty=1.0,
            num_return_sequences=1,
            conv_mode=conv_mode,
            max_new_tokens=int(max_new_tokens)
        )
    except Exception as e:
        # UI boundary: keep the app alive, but keep the traceback for debugging.
        logging.getLogger(__name__).exception("Libra inference failed")
        return f"An error occurred: {e}"
    return output
# Build the Gradio UI. gr.Blocks gives finer control over layout than the
# higher-level gr.Interface.


def _build_image_picker(
    title: str,
    radio_label: str,
    default_choice: str,
    upload_label: str,
    preview_html: list,
):
    """Render one image slot: preview thumbnails, default-image radio, upload box.

    Must be called inside an open Gradio layout context (``gr.Blocks`` /
    ``gr.Column``), since components attach to whichever context is open
    when they are created.

    Args:
        title: Heading shown above the slot.
        radio_label: Label of the ``DEFAULT_IMAGES`` radio selector.
        default_choice: ``DEFAULT_IMAGES`` key selected initially.
        upload_label: Label of the upload component.
        preview_html: Pre-rendered preview HTML snippets to display.

    Returns:
        ``(radio, upload)``: the ``gr.Radio`` and the ``gr.Image`` component
        (``type="filepath"``).
    """
    gr.Markdown(f"### {title}")
    # Preview thumbnails of the built-in sample images.
    for html in preview_html:
        gr.HTML(html)
    radio = gr.Radio(
        label=radio_label,
        choices=list(DEFAULT_IMAGES.keys()),
        value=default_choice
    )
    # `tool="editor"` is deliberately not passed. Gradio 4 removed the `tool`
    # argument (passing it raises TypeError), and Gradio 3 documents "editor"
    # as the default tool for uploaded images.
    upload = gr.Image(
        label=upload_label,
        type="filepath"
    )
    return radio, upload


# Preview HTML for each default image, rendered once and shared by both
# columns. Rendering a remote image downloads it, so doing this per column
# doubled the startup network traffic.
_DEFAULT_PREVIEW_HTML = [generate_image_html(img) for img in DEFAULT_IMAGES.values()]

with gr.Blocks() as demo:
    # Title and a short usage note.
    gr.Markdown("# Libra Radiology Report Generator")
    gr.Markdown("Use **Libra** to generate radiology image descriptions. Provide a **Current** and a **Prior** image below.")
    # Free-text prompt sent to the model.
    with gr.Row():
        prompt_input = gr.Textbox(
            label="Prompt",
            value="Provide a detailed description of the findings in the radiology image."
        )
    # Current image and prior (comparison) image, side by side.
    with gr.Row():
        with gr.Column():
            selected_current, uploaded_current = _build_image_picker(
                title="Current Image",
                radio_label="Select Current Image",
                default_choice="Image 1",
                upload_label="Or Upload Current Image",
                preview_html=_DEFAULT_PREVIEW_HTML,
            )
        with gr.Column():
            selected_prior, uploaded_prior = _build_image_picker(
                title="Prior Image",
                radio_label="Select Prior Image",
                default_choice="Image 2",
                upload_label="Or Upload Prior Image",
                preview_html=_DEFAULT_PREVIEW_HTML,
            )
    # Generation parameters.
    with gr.Row():
        temperature_slider = gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=1.0,
            step=0.1,
            value=0.7
        )
        top_p_slider = gr.Slider(
            label="Top P",
            minimum=0.1,
            maximum=1.0,
            step=0.1,
            value=0.8
        )
        num_beams_slider = gr.Slider(
            label="Number of Beams",
            minimum=1,
            maximum=20,
            step=1,
            value=2
        )
        max_tokens_slider = gr.Slider(
            label="Max New Tokens",
            minimum=10,
            maximum=4096,
            step=10,
            value=128
        )
    # Model output.
    output_text = gr.Textbox(
        label="Generated Description",
        lines=10
    )
    # Run inference when the button is clicked. The input order must match
    # the parameter order of generate_radiology_description.
    generate_button = gr.Button("Generate Description")
    generate_button.click(
        fn=generate_radiology_description,
        inputs=[
            prompt_input,
            selected_current,
            uploaded_current,
            selected_prior,
            uploaded_prior,
            temperature_slider,
            top_p_slider,
            num_beams_slider,
            max_tokens_slider
        ],
        outputs=output_text
    )
# Script entry point: start the Gradio server. share=True asks Gradio for a
# temporary public link when running locally; a Hugging Face Space is already
# publicly reachable, so the flag is not what makes the app visible there.
if __name__ == "__main__":
    create_public_link = True
    demo.launch(share=create_public_link)