# Linly-ChatFlow / app.py
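"""Gradio demo for the Linly ChatFlow conversational model.

Loads a Chinese LLaMA checkpoint from the Hugging Face Hub, wraps it in the
repo's LmGeneration helper, and serves a single-turn chat interface.
"""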
import os
import argparse

import torch
import gradio as gr
from transformers import LlamaForCausalLM, LlamaTokenizer
from huggingface_hub import hf_hub_download

from utils import load_hyperparam, load_model
from models.tokenize import Tokenizer
from models.llama import *
from generate import LmGeneration

# Launch CUDA kernels synchronously so device-side errors surface at the call site.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Populated at startup by init_args() and init_model().
args = None
lm_generation = None
def init_args():
    """Assemble generation hyperparameters; all values are hard-coded for this demo."""
    global args
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    args = parser.parse_args()

    args.load_model_path = 'Linly-AI/ChatFlow-13B'
    # Alternative checkpoints:
    # args.load_model_path = 'Linly-AI/ChatFlow-7B'
    # args.load_model_path = './model_file/chatllama_7b.bin'
    # args.config_path = './config/llama_7b.json'
    # args.load_model_path = './model_file/chatflow_13b.bin'
    args.config_path = './config/llama_13b_config.json'
    args.spm_model_path = './model_file/tokenizer.model'

    # Generation settings.
    args.batch_size = 1
    args.seq_length = 1024
    args.world_size = 1
    args.use_int8 = True
    args.top_p = 0
    args.repetition_penalty_range = 1024
    args.repetition_penalty_slope = 0
    args.repetition_penalty = 1.15
    args = load_hyperparam(args)

    # The SentencePiece tokenizer could also be loaded directly:
    # args.tokenizer = Tokenizer(model_path=args.spm_model_path)
    args.tokenizer = LlamaTokenizer.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", trust_remote_code=True)
    args.vocab_size = args.tokenizer.sp_model.vocab_size()
def init_model():
    """Load the causal LM in fp16 and wrap it in the generation helper."""
    global lm_generation
    # Legacy path: build the model locally from the config and a downloaded checkpoint.
    # torch.set_default_tensor_type(torch.HalfTensor)
    # model = LLaMa(args)
    # torch.set_default_tensor_type(torch.FloatTensor)
    # args.load_model_path = hf_hub_download(repo_id=args.load_model_path, filename='chatflow_13b.bin')
    # model = load_model(model, args.load_model_path)
    # model.eval()
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model.to(device)
    model = LlamaForCausalLM.from_pretrained(
        "Linly-AI/Chinese-LLaMA-2-7B-hf",
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    print(model)
    # Peak GPU memory in GiB after loading.
    print(torch.cuda.max_memory_allocated() / 1024 ** 3)
    lm_generation = LmGeneration(model, args.tokenizer)
def chat(prompt, top_k, temperature):
    """Generate one response for `prompt` using per-request sampling settings."""
    args.top_k = int(top_k)
    args.temperature = temperature
    response = lm_generation.generate(args, [prompt])
    print('log:', response[0])
    return response[0]
if __name__ == '__main__':
    init_args()
    init_model()
    demo = gr.Interface(
        fn=chat,
        inputs=[
            "text",
            gr.Slider(1, 60, value=10, step=1, label="top_k"),
            gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="temperature"),
        ],
        outputs="text",
    )
    demo.launch()
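# Example client call (a minimal sketch): assumes `gradio_client` is installed
# and the demo is serving on Gradio's default local address. "/predict" is the
# default endpoint name that gr.Interface registers.
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")
#   print(client.predict("Hello!", 10, 1.0, api_name="/predict"))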