# Libra / app.py (scraped from the Hugging Face Space page; author: X-iZhang)
# Last commit: "Update app.py" (3f3122f, verified); file size 4.27 kB
# app.py
import torch
import gradio as gr
import os
import requests
import base64
# Assumes libra_eval is importable from your installed libra.eval package
from libra.eval import libra_eval
# Smoke-test script: runs a single Libra inference on the bundled example
# images and prints the generated report. This runs at module level, so
# importing or running this file immediately calls libra_eval (and presumably
# loads the model).

# Presumably a Hugging Face Hub repo ID for the Libra checkpoint -- confirm.
model_path = "X-iZhang/libra-v1.0-7b"

# Two images: the current study first, then the prior study (the same order the
# commented-out Gradio handler in this file passes them in).
# NOTE(review): "curent.jpg" looks like a misspelling of "current.jpg"; it is
# kept as-is because it must match the actual file name under ./examples.
image_files = [
    "./examples/curent.jpg",
    "./examples/prior.jpg",
]

prompt = "Provide a detailed description of the findings in the radiology image."

# Conversation template name. It used to be defined but never passed to
# libra_eval, so the call silently fell back to libra_eval's own default.
conv_mode = "libra_v1"

result = libra_eval(
    model_path=model_path,
    model_base=None,  # optionally specify a base model here if needed
    image_file=image_files,
    query=prompt,
    conv_mode=conv_mode,
    temperature=0.9,
    top_p=0.8,
    max_new_tokens=512,
)
print(result)
# def generate_radiology_description(
# prompt: str,
# uploaded_current: str,
# uploaded_prior: str,
# temperature: float,
# top_p: float,
# num_beams: int,
# max_new_tokens: int
# ) -> str:
# """
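# Core inference function:
# 1. Obtain the image file paths solely from the user-uploaded images
# 2. Call libra_eval to generate the report description
# 3. Return the generated result or an error message
# """
# # Make sure the user uploaded both images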
# if not uploaded_current or not uploaded_prior:
# return "Please upload both current and prior images."
# # Model path
# model_path = "X-iZhang/libra-v1.0-7b"
# conv_mode = "libra_v1"
# try:
# # Call libra_eval to run inference
# print("Before calling libra_eval")
# output = libra_eval(
# model_path=model_path,
# model_base=None, # Optionally specify a base model here if needed
# image_file=[uploaded_current, uploaded_prior], # Two local image file paths
# query=prompt,
# temperature=temperature,
# top_p=top_p,
# num_beams=num_beams,
# length_penalty=1.0,
# num_return_sequences=1,
# conv_mode=conv_mode,
# max_new_tokens=max_new_tokens
# )
# print("After calling libra_eval, result:", output)
# return output
# except Exception as e:
# return f"An error occurred: {str(e)}"
# # Build the Gradio interface
# with gr.Blocks() as demo:
# # Title and brief description
# gr.Markdown("# Libra Radiology Report Generator (Local Upload Only)")
# gr.Markdown("Upload **Current** and **Prior** images below to generate a radiology description using the Libra model.")
# # User input: text prompt
# prompt_input = gr.Textbox(
# label="Prompt",
# value="Describe the key findings in these two images."
# )
# # Upload local images (Current & Prior)
# with gr.Row():
# uploaded_current = gr.Image(
# label="Upload Current Image",
# type="filepath"
# )
# uploaded_prior = gr.Image(
# label="Upload Prior Image",
# type="filepath"
# )
# # Generation parameter controls
# with gr.Row():
# temperature_slider = gr.Slider(
# label="Temperature",
# minimum=0.1,
# maximum=1.0,
# step=0.1,
# value=0.7
# )
# top_p_slider = gr.Slider(
# label="Top P",
# minimum=0.1,
# maximum=1.0,
# step=0.1,
# value=0.8
# )
# num_beams_slider = gr.Slider(
# label="Number of Beams",
# minimum=1,
# maximum=20,
# step=1,
# value=2
# )
# max_tokens_slider = gr.Slider(
# label="Max New Tokens",
# minimum=10,
# maximum=4096,
# step=10,
# value=128
# )
# # Displays the text generated by the model
# output_text = gr.Textbox(
# label="Generated Description",
# lines=10
# )
# # Inference logic triggered when the button is clicked
# generate_button = gr.Button("Generate Description")
# generate_button.click(
# fn=generate_radiology_description,
# inputs=[
# prompt_input,
# uploaded_current,
# uploaded_prior,
# temperature_slider,
# top_p_slider,
# num_beams_slider,
# max_tokens_slider
# ],
# outputs=output_text
# )
# if __name__ == "__main__":
# demo.launch()