File size: 3,632 Bytes
8e07496
 
869a62d
0aeb285
869a62d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0aeb285
869a62d
 
 
 
0aeb285
 
 
869a62d
0aeb285
 
869a62d
 
 
 
c1e5f6a
869a62d
 
0aeb285
 
869a62d
 
 
 
 
 
 
8e07496
869a62d
c1e5f6a
869a62d
 
 
 
0aeb285
869a62d
 
0aeb285
 
869a62d
0aeb285
 
 
 
 
869a62d
0aeb285
869a62d
0aeb285
 
 
 
 
 
 
 
869a62d
0aeb285
869a62d
 
0aeb285
 
 
 
869a62d
 
 
0aeb285
 
 
 
869a62d
 
 
0aeb285
 
 
 
869a62d
 
 
0aeb285
 
 
 
869a62d
 
 
 
 
 
 
 
 
 
 
 
 
 
0aeb285
869a62d
 
0aeb285
 
 
869a62d
 
 
 
 
 
f57f94b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# app.py
import base64
import logging
import os

import gradio as gr
import requests
import torch

# 假设 libra_eval 在你的 python 包 libra.eval 中
from libra.eval import libra_eval

def generate_radiology_description(
    prompt: str,
    uploaded_current: str,
    uploaded_prior: str,
    temperature: float,
    top_p: float,
    num_beams: int,
    max_new_tokens: int
) -> str:
    """Generate a radiology description for an uploaded current/prior image pair.

    Validates the inputs, forwards them to ``libra_eval`` and returns either
    the generated text or a human-readable message. It never raises, so the
    UI always has something to display.

    Args:
        prompt: Free-text query forwarded to ``libra_eval`` as ``query``.
        uploaded_current: Local file path of the current image; empty/None
            when nothing was uploaded.
        uploaded_prior: Local file path of the prior image; empty/None when
            nothing was uploaded.
        temperature: Sampling temperature forwarded to ``libra_eval``.
        top_p: Top-p threshold forwarded to ``libra_eval``.
        num_beams: Number of beams. Coerced to ``int`` because UI sliders
            may deliver numeric values as floats (e.g. ``2.0``).
        max_new_tokens: Maximum number of new tokens. Coerced to ``int``
            for the same reason.

    Returns:
        Whatever ``libra_eval`` returns, or an explanatory message when an
        input is missing or inference fails.
    """
    logger = logging.getLogger(__name__)

    # Both images are required before running inference.
    if not uploaded_current or not uploaded_prior:
        return "Please upload both current and prior images."
    # An empty query would be sent to the model as-is; reject it up front.
    if not prompt or not prompt.strip():
        return "Please enter a prompt."

    model_path = "X-iZhang/libra-v1.0-7b"
    conv_mode = "libra_v1"

    try:
        # Coerce inside the try so a malformed value yields a message, not a crash.
        beams = int(num_beams)
        new_tokens = int(max_new_tokens)
        logger.info(
            "Calling libra_eval (num_beams=%d, max_new_tokens=%d)", beams, new_tokens
        )
        output = libra_eval(
            model_path=model_path,
            model_base=None,  # optional base model; none is used here
            image_file=[uploaded_current, uploaded_prior],  # local paths: current, prior
            query=prompt,
            temperature=temperature,
            top_p=top_p,
            num_beams=beams,
            length_penalty=1.0,
            num_return_sequences=1,
            conv_mode=conv_mode,
            max_new_tokens=new_tokens
        )
        logger.info("libra_eval returned: %s", output)
        return output
    except Exception as e:
        # UI boundary: keep the full traceback in the logs, show a short message to the user.
        logger.exception("libra_eval failed")
        return f"An error occurred: {e}"

# Build the Gradio interface. Components are created in display order:
# headings, prompt box, image uploaders, generation sliders, output box, button.
with gr.Blocks() as demo:
    gr.Markdown("# Libra Radiology Report Generator (Local Upload Only)")
    gr.Markdown("Upload **Current** and **Prior** images below to generate a radiology description using the Libra model.")

    # Free-text query passed to the handler as `prompt`.
    prompt_input = gr.Textbox(label="Prompt", value="Describe the key findings in these two images.")

    # Side-by-side uploaders; type="filepath" hands the handler a local file path.
    with gr.Row():
        uploaded_current = gr.Image(label="Upload Current Image", type="filepath")
        uploaded_prior = gr.Image(label="Upload Prior Image", type="filepath")

    # Generation controls, one slider per (label, minimum, maximum, step, default) row.
    # The generator is consumed by the unpacking while the Row context is still open.
    with gr.Row():
        temperature_slider, top_p_slider, num_beams_slider, max_tokens_slider = (
            gr.Slider(label=label, minimum=low, maximum=high, step=step, value=default)
            for label, low, high, step, default in (
                ("Temperature", 0.1, 1.0, 0.1, 0.7),
                ("Top P", 0.1, 1.0, 0.1, 0.8),
                ("Number of Beams", 1, 20, 1, 2),
                ("Max New Tokens", 10, 4096, 10, 128),
            )
        )

    # Read-only display of the generated text (or error message).
    output_text = gr.Textbox(label="Generated Description", lines=10)

    generate_button = gr.Button("Generate Description")
    # Inputs follow the positional parameter order of generate_radiology_description.
    generate_button.click(
        fn=generate_radiology_description,
        inputs=[
            prompt_input, uploaded_current, uploaded_prior,
            temperature_slider, top_p_slider, num_beams_slider, max_tokens_slider,
        ],
        outputs=output_text,
    )

if __name__ == "__main__":
    # Configure logging at the application entry point so INFO-level records
    # from the app are emitted instead of being dropped by the default WARNING level.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )
    demo.launch()