import argparse

import gradio as gr
from openai import OpenAI
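
# Gradio chat frontend for an OpenAI-compatible serving endpoint (for
# example, one started with `vllm serve`); this script only relays chat
# messages to the endpoint and streams the model's reply back into the UI.
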
parser = argparse.ArgumentParser(
    description='Chatbot Interface with Customizable Parameters')
parser.add_argument('--model-url',
                    type=str,
                    default='http://134.28.190.100:8000/v1',
                    help='Model URL')
parser.add_argument('-m',
                    '--model',
                    type=str,
                    default='TheBloke/Mistral-7B-Instruct-v0.2-AWQ',
                    help='Model name for the chatbot')
parser.add_argument('--temp',
                    type=float,
                    default=0.8,
                    help='Temperature for text generation')
parser.add_argument('--stop-token-ids',
                    type=str,
                    default='',
                    help='Comma-separated stop token IDs')
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8001)

args = parser.parse_args()

# Point the OpenAI client at the serving endpoint. Local OpenAI-compatible
# servers typically ignore the API key, but the client library requires a
# non-empty value.
openai_api_key = "EMPTY"
openai_api_base = args.model_url

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)
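
# Optional sanity check (assumes the server exposes the /v1/models endpoint,
# as vLLM does); uncomment to verify connectivity before launching the UI:
# print([m.id for m in client.models.list().data])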


def predict(message, history):
    # Convert Gradio's (user, assistant) history pairs into the OpenAI
    # chat-completions message format, then append the new user message.
    history_openai_format = []
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({
            "role": "assistant",
            "content": assistant
        })
    history_openai_format.append({"role": "user", "content": message})

    # Request a streaming completion so tokens can be shown as they arrive.
    stream = client.chat.completions.create(
        model=args.model,
        messages=history_openai_format,
        temperature=args.temp,
        stream=True,
        extra_body={
            'repetition_penalty': 1,
            'stop_token_ids': [
                int(token_id.strip())
                for token_id in args.stop_token_ids.split(',')
                if token_id.strip()
            ] if args.stop_token_ids else []
        })

    # Yield the accumulated text so Gradio re-renders the growing response.
    partial_message = ""
    for chunk in stream:
        partial_message += (chunk.choices[0].delta.content or "")
        yield partial_message


# Build the UI: an upload column beside the chat column.
with gr.Blocks(title="MethodAI 0.15", theme=gr.themes.Soft()) as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # Placeholder control: no upload handler is wired up yet.
            gr.UploadButton("Click to upload PDFs", file_types=[".pdf"])
        with gr.Column(scale=4):
            gr.ChatInterface(predict)

# Queue the top-level app: older Gradio releases need queuing for generator
# (streaming) handlers, and it is harmless on newer ones where it is the
# default.
demo.queue().launch(server_name=args.host, server_port=args.port, share=True)
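
# Example invocation (hypothetical filename; assumes an OpenAI-compatible
# server is already running at --model-url with the same model loaded):
#
#   python gradio_openai_chatbot.py \
#       --model-url http://134.28.190.100:8000/v1 \
#       -m TheBloke/Mistral-7B-Instruct-v0.2-AWQ \
#       --temp 0.8 --port 8001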