# MAmmoTH-VL2 / app_test.py
# from .demo_modelpart import InferenceDemo
import base64
import datetime
import hashlib
import io
import json
import os
import subprocess
import sys
from io import BytesIO
from threading import Thread

import cv2
import gradio as gr
import gradio_client
import numpy as np
import PIL
import requests
import spaces
import torch
from decord import VideoReader, cpu
from huggingface_hub import HfApi, login, revision_exists
from PIL import Image
from transformers import TextStreamer, TextIteratorStreamer

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    tokenizer_image_token,
    get_model_name_from_path,
    KeywordsStoppingCriteria,
)
from serve_constants import html_header, bibtext, learn_more_markdown, tos_markdown
# Authenticate to the Hugging Face Hub with write access so that logs can be
# pushed to the repo named in the LOG_REPO environment variable.
login(token=os.environ["HF_TOKEN"], write_permission=True)

api = HfApi()
repo_name = os.environ["LOG_REPO"]

# Local directories for conversation logs and user votes.
external_log_dir = "./logs"
LOGDIR = external_log_dir
VOTEDIR = "./votes"
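
# Hypothetical vote-logging helper (a sketch, not part of the original file):
# appends one JSON record per vote to a dated file under VOTEDIR. The record
# schema here is an assumption, not the project's actual logging format.
def log_vote(vote_type):
    os.makedirs(VOTEDIR, exist_ok=True)
    fname = os.path.join(VOTEDIR, f"{datetime.datetime.now():%Y-%m-%d}-votes.json")
    with open(fname, "a") as f:
        record = {"type": vote_type, "time": datetime.datetime.now().isoformat()}
        f.write(json.dumps(record) + "\n")
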
with gr.Blocks(
css=".message-wrap.svelte-1lcyrx4>div.svelte-1lcyrx4 img {min-width: 40px}",
) as demo:
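    # Minimal chat UI: page header, sampling-parameter sliders, a chatbot
    # panel, and feedback/control buttons.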
cur_dir = os.path.dirname(os.path.abspath(__file__))
# gr.Markdown(title_markdown)
gr.HTML(html_header)
with gr.Column():
with gr.Accordion("Parameters", open=False) as parameter_row:
temperature = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.7,
step=0.1,
interactive=True,
label="Temperature",
)
top_p = gr.Slider(
minimum=0.0,
maximum=1.0,
value=1,
step=0.1,
interactive=True,
label="Top P",
)
max_output_tokens = gr.Slider(
minimum=0,
maximum=8192,
value=4096,
step=256,
interactive=True,
label="Max output tokens",
)
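
        # NOTE: in this stripped-down test app the sliders above are rendered
        # but not yet passed into a generation callback.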
with gr.Row():
chatbot = gr.Chatbot([], elem_id="MAmmoTH-VL-8B", bubble_full_width=False, height=750)
with gr.Row():
upvote_btn = gr.Button(value="πŸ‘ Upvote", interactive=True)
downvote_btn = gr.Button(value="πŸ‘Ž Downvote", interactive=True)
flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
# stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=True)
regenerate_btn = gr.Button(value="πŸ”„ Regenerate", interactive=True)
clear_btn = gr.Button(value="πŸ—‘οΈ Clear history", interactive=True)
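
            # Hypothetical wiring using the log_vote sketch above; the original
            # test script leaves every button unconnected. Regenerate would
            # additionally need a full inference callback.
            upvote_btn.click(lambda: log_vote("upvote"), None, None)
            downvote_btn.click(lambda: log_vote("downvote"), None, None)
            flag_btn.click(lambda: log_vote("flag"), None, None)
            clear_btn.click(lambda: [], None, chatbot)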
demo.queue()
if __name__ == "__main__":
import argparse
argparser = argparse.ArgumentParser()
argparser.add_argument("--server_name", default="0.0.0.0", type=str)
argparser.add_argument("--model_path", default="TIGER-Lab/MAmmoTH-VL2", type=str)
argparser.add_argument("--model-base", type=str, default=None)
argparser.add_argument("--num-gpus", type=int, default=1)
argparser.add_argument("--conv-mode", type=str, default=None)
argparser.add_argument("--temperature", type=float, default=0.7)
argparser.add_argument("--max-new-tokens", type=int, default=4096)
argparser.add_argument("--num_frames", type=int, default=32)
argparser.add_argument("--load-8bit", action="store_true")
argparser.add_argument("--load-4bit", action="store_true")
argparser.add_argument("--debug", action="store_true")
    args = argparser.parse_args()

    model_path = args.model_path
    filt_invalid = "cut"
    # Load the pretrained MAmmoTH-VL2 checkpoint and move it onto the GPU.
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
    )
    model = model.to(torch.device("cuda"))
chat_image_num = 0
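
    # A minimal generation sketch (hypothetical, not part of the original
    # script) showing how the imported LLaVA utilities are typically combined.
    # The "qwen_1_5" conversation template, the fp16 cast, and the generate()
    # keyword arguments are assumptions based on common llava-onevision
    # serving code, not confirmed details of MAmmoTH-VL2.
    def generate_response(image, question, conv_mode="qwen_1_5"):
        # Build the prompt with the image placeholder token.
        conv = conv_templates[conv_mode].copy()
        conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + question)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        # Tokenize, splicing IMAGE_TOKEN_INDEX where the image token sits.
        input_ids = (
            tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .to(model.device)
        )
        image_tensor = (
            image_processor.preprocess(image, return_tensors="pt")["pixel_values"]
            .half()
            .to(model.device)
        )
        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                do_sample=args.temperature > 0,
                temperature=args.temperature,
                max_new_tokens=args.max_new_tokens,
            )
        return tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()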
    demo.launch(server_name=args.server_name)