File size: 2,029 Bytes
8273d5f
2a438ba
8273d5f
9ce3948
580cc25
 
695df9a
580cc25
 
 
 
 
 
9ce3948
 
 
 
 
 
 
 
 
 
 
 
a41d9f9
 
580cc25
9ce3948
580cc25
 
 
 
 
2a438ba
 
9ce3948
8273d5f
 
9ce3948
 
8273d5f
a41d9f9
8273d5f
 
580cc25
8273d5f
 
2a438ba
d652f80
 
 
8273d5f
a41d9f9
8273d5f
d652f80
8273d5f
695df9a
d652f80
 
 
8273d5f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import subprocess

import torch
import gradio as gr
from transformers import AutoModel, pipeline, AutoTokenizer
import spaces

# from issue: https://discuss.huggingface.co/t/how-to-install-flash-attention-on-hf-gradio-space/70698/2
# InternVL2 depends on flash_attn, which can only be installed at runtime
# on a Gradio Space (no CUDA toolchain at build time, hence SKIP_CUDA_BUILD).
# NOTE: the env must be os.environ PLUS the flag — passing only the flag
# would wipe PATH/HOME and break the pip invocation.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
try:
    model_name = "OpenGVLab/InternVL2-8B"
    # model: <class 'transformers_modules.OpenGVLab.InternVL2-8B.0e6d592d957d9739b6df0f4b90be4cb0826756b9.modeling_internvl_chat.InternVLChatModel'>
    # trust_remote_code is required: InternVL2 ships custom modeling code.
    model = (
        AutoModel.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            # low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
        .cuda()
        .eval()
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # pipeline: <class 'transformers.pipelines.visual_question_answering.VisualQuestionAnsweringPipeline'>
    inference = pipeline(
        task="visual-question-answering", model=model, tokenizer=tokenizer
    )
except Exception as error:
    # Surface startup failures as a Gradio error popup.
    # Fixed: error prefix was "👌" (OK emoji) — use "❌" to match predict().
    raise gr.Error("❌" + str(error), duration=30) from error


@spaces.GPU
def predict(input_img, questions):
    """Answer a free-form question about the uploaded image.

    Runs the module-level VQA pipeline on (input_img, questions) and returns
    the prediction stringified. Any failure is re-raised as a gr.Error so
    Gradio shows it as a popup instead of a silent server-side traceback.
    """
    try:
        # Debug toasts: show which pipeline/model classes are actually in use.
        gr.Info("pipeline: " + str(type(inference)))
        gr.Info("model: " + str(type(model)))
        return str(inference(question=questions, image=input_img))
    except Exception as exc:
        # Convert the failure message to a string and raise a gr.Error
        # so the UI displays an error popup.
        raise gr.Error("❌" + str(exc), duration=25)


# Wire up the UI: an image (upload or webcam, delivered as PIL) plus a text
# question in, the stringified prediction out. The title intentionally shows
# the concrete pipeline class for debugging.
image_input = gr.Image(
    label="Select A Image", sources=["upload", "webcam"], type="pil"
)
gradio_app = gr.Interface(
    predict,
    inputs=[image_input, "text"],
    outputs="text",
    title=str(type(inference)),
)

if __name__ == "__main__":
    # show_error/debug keep failures visible while iterating on the Space.
    gradio_app.launch(show_error=True, debug=True)