File size: 1,644 Bytes
8273d5f
2a438ba
8273d5f
2a438ba
580cc25
 
695df9a
580cc25
 
 
 
 
 
 
 
a41d9f9
 
580cc25
 
 
 
a41d9f9
 
 
 
 
580cc25
 
 
 
 
 
 
2a438ba
 
8273d5f
 
695df9a
8273d5f
a41d9f9
8273d5f
 
580cc25
8273d5f
 
2a438ba
d652f80
 
 
8273d5f
a41d9f9
8273d5f
d652f80
8273d5f
695df9a
d652f80
 
 
8273d5f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import subprocess

import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer, pipeline

# from issue: https://discuss.huggingface.co/t/how-to-install-flash-attention-on-hf-gradio-space/70698/2
# The flash_attn dependency required by InternVL2 can only be installed at
# runtime like this (HF Spaces cannot build it at requirements-install time).
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    # Merge with the current environment instead of replacing it: a bare
    # env dict would drop PATH/HOME etc. and break the pip subprocess.
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

# Load InternVL2-8B with its custom (remote) modelling code, in bfloat16,
# then switch to inference mode and move the weights onto the GPU.
model_name = "OpenGVLab/InternVL2-8B"
_loaded = AutoModel.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    # low_cpu_mem_usage=True,
    trust_remote_code=True,
)
model = _loaded.eval().cuda()

# Build the tokenizer and the visual-question-answering pipeline; any failure
# here is surfaced to the user as a Gradio error popup instead of a raw crash.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    inference = pipeline(
        task="visual-question-answering", model=model, tokenizer=tokenizer
    )
except Exception as error:
    # "❌" matches the error prefix used in predict(); chain the original
    # exception so the full traceback is preserved in the logs.
    raise gr.Error("❌" + str(error), duration=30) from error


def predict(input_img, questions):
    """Run visual question answering on an uploaded image.

    Args:
        input_img: PIL image supplied by the Gradio image component.
        questions: Free-text question string from the text input.

    Returns:
        The pipeline's prediction rendered as a string.

    Raises:
        gr.Error: Wraps any inference failure so Gradio shows an error popup.
    """
    try:
        # NOTE(review): debug toast showing the pipeline type — consider
        # removing once the Space is confirmed working.
        gr.Info(str(type(inference)))
        predictions = inference(question=questions, image=input_img)
        return str(predictions)
    except Exception as e:
        # Show the failure as a Gradio error popup; chain the cause so the
        # original traceback is kept in the server logs.
        raise gr.Error("❌" + str(e), duration=25) from e


# Wire predict() into a simple (image, text) -> text Gradio interface.
_input_components = [
    gr.Image(label="Select A Image", sources=["upload", "webcam"], type="pil"),
    "text",
]
gradio_app = gr.Interface(
    predict,
    inputs=_input_components,
    outputs="text",
    # NOTE(review): the title is the pipeline's type name — debug leftover?
    title=str(type(inference)),
)

# Launch the app when run as a script. show_error=True surfaces server-side
# exceptions to the browser; debug=True keeps the process attached with logs.
if __name__ == "__main__":
    gradio_app.launch(show_error=True, debug=True)