|
import base64
import mimetypes
import os

import gradio as gr
from openai import OpenAI
|
|
|
# OpenAI SDK client pointed at the Hunyuan endpoint. The API key is read from
# the HUNYUAN_API_KEY environment variable (None is passed when it is unset).
client = OpenAI(
    base_url="https://api.hunyuan.cloud.tencent.com/v1",
    api_key=os.environ.get("HUNYUAN_API_KEY"),
)
|
|
|
def generate_caption(image_path, question):
    """Stream the model's answer to ``question`` about the image at ``image_path``.

    The image file is read, base64-encoded into a ``data:`` URL and sent with
    the question as a single user message to the ``hunyuan-vision`` model via
    the module-level ``client``. The completion is requested in streaming
    mode.

    Args:
        image_path: Filesystem path of the image to send. A falsy value
            (``None`` or ``""``, i.e. nothing uploaded) is rejected.
        question: Text prompt sent alongside the image.

    Yields:
        str: The answer accumulated so far, emitted again every time a new
        piece of text arrives.

    Raises:
        gr.Error: If no image path was supplied.
        OSError: If the image file cannot be opened or read.
    """
    if not image_path:
        # Without this guard open(None) fails with an opaque TypeError.
        raise gr.Error("请先上传图片")

    with open(image_path, "rb") as image_file:
        base64_image = base64.b64encode(image_file.read()).decode("utf-8")

    # Declare the media type that matches the file extension rather than
    # always claiming JPEG; keep JPEG as the fallback for unknown extensions.
    mime_type = mimetypes.guess_type(image_path)[0]
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    messages = [{
        "role": "user",
        "content": [
            {"type": "text", "text": question},
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:{mime_type};base64,{base64_image}"
                },
            },
        ],
    }]

    response = client.chat.completions.create(
        model="hunyuan-vision",
        messages=messages,
        stream=True,
        # Extra request fields forwarded as-is in the request body.
        extra_body={
            "stream_moderation": True,
            "enable_enhancement": False,
        },
    )

    full_response = ""
    for chunk in response:
        # Skip chunks whose ``choices`` list is empty instead of raising
        # IndexError on ``choices[0]``.
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        if token:
            full_response += token
            yield full_response
|
|
|
|
|
|
|
|
|
# Page title; passed to gr.Blocks and interpolated into the on-page heading.
title = "Hunyuan-Vision图生文Demo"

# Soft base theme with teal/blue hues. The font list is tried in order:
# the Google-hosted "Noto Sans SC" first, then Arial, then generic sans-serif.
theme = gr.themes.Soft(
    font=[gr.themes.GoogleFont("Noto Sans SC"), "Arial", "sans-serif"],
    primary_hue="teal",
    secondary_hue="blue",
)
|
|
|
# Build the UI: a header, an input column (image + question + buttons), an
# output column (streamed description), a collapsible examples panel, and the
# event wiring for the two buttons.
with gr.Blocks(title=title, theme=theme) as demo:
    # Page header: centred title and a one-line usage hint.
    gr.Markdown(f"""
    <div style="text-align: center;">
    <h1 style="color: #2E86C1; border-bottom: 3px solid #AED6F1; padding-bottom: 10px;">🖼️ {title}</h1>
    <p style="color: #616A6B;">上传图片并输入问题,体验腾讯混元视觉大模型的图像理解能力</p>
    </div>
    """)

    with gr.Row(variant="panel"):
        # Left column: input area.
        with gr.Column(scale=3):
            # gr.Group has no ``label`` keyword, so the former
            # label="输入区域" argument is dropped; passing it raises
            # TypeError when the layout is built.
            with gr.Group():
                image_input = gr.Image(
                    type="filepath",
                    label="上传图片",
                    height=400,
                    show_download_button=False,
                    elem_classes="preview-box",
                )
                question_input = gr.Textbox(
                    label="问题描述",
                    placeholder="请输入关于图片的问题...",
                    value="请详细描述图片中的场景、人物和细节",
                    lines=2,
                )
                with gr.Row():
                    clear_btn = gr.Button("清空", variant="secondary")
                    # elem_id matches the ``button#generate`` rule in the
                    # page stylesheet, which otherwise selects nothing.
                    submit_btn = gr.Button(
                        "生成描述",
                        variant="primary",
                        elem_id="generate",
                    )

        # Right column: generated result.
        with gr.Column(scale=4):
            # Same as above: the former label="生成结果" is not a valid
            # gr.Group argument and is therefore omitted.
            with gr.Group():
                output = gr.Textbox(
                    label="描述内容",
                    interactive=False,
                    show_copy_button=True,
                    lines=12,
                    max_lines=20,
                    autoscroll=True,
                )

    # Collapsed panel of ready-made (image, question) pairs that fill the
    # two input components when clicked.
    with gr.Accordion("🖼️ 点击查看示例", open=False):
        with gr.Row():
            gr.Examples(
                examples=[
                    ["tencent.png", "图片中的天气状况如何?"],
                    ["tencent.png", "描述参会人员的衣着特征"],
                ],
                inputs=[image_input, question_input],
                label="快速示例",
            )

    # Submit: stream generate_caption's output into the result textbox.
    submit_btn.click(
        fn=generate_caption,
        inputs=[image_input, question_input],
        outputs=output,
        api_name="generate",
    )

    # Clear: reset the image to None and both textboxes to empty strings,
    # bypassing the request queue.
    clear_btn.click(
        fn=lambda: [None, "", ""],
        outputs=[image_input, question_input, output],
        queue=False,
    )
|
|
|
|
|
# Custom stylesheet: rounded, shadowed image preview that grows slightly on
# hover, plus a transition on the button with id "generate".
css = """
.preview-box img {border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.1);}
.preview-box:hover img {transform: scale(1.02);}
button#generate {transition: all 0.3s ease;}
"""

# NOTE(review): the stylesheet is attached by assigning the attribute after
# the Blocks object has been built; the documented route is the ``css``
# constructor argument — confirm this still takes effect on the installed
# Gradio version.
demo.css = css
|
|
|
if __name__ == "__main__":
    # Turn on the request queue with a default concurrency limit of 100.
    demo.queue(default_concurrency_limit=100)

    # Server settings, gathered in one place and passed to launch() below.
    launch_options = {
        "server_port": 7860,
        "show_error": True,
        "favicon_path": "favicon.ico",
        "max_threads": 100,
    }
    demo.launch(**launch_options)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|