"""
Modified by xsj.

This script is adapted from the OpenAI-style API demo for the ChatGLM3-6B model
(https://platform.openai.com/docs/api-reference/chat) and repurposed to serve a
Mistral-7B-based NL2TL (natural language to temporal logic) translator.
It is designed to be run as a web server using FastAPI and uvicorn, making the
model accessible through the OpenAI client.

Key Components and Features:
- Model and Tokenizer Setup: configures the model and tokenizer paths and loads them.
- FastAPI Configuration: sets up a FastAPI application with CORS middleware for handling cross-origin requests.
- API Endpoints:
    - "/v1/models": lists the available models; the id "chatglm3-6b" is retained from the original demo.
    - "/v1/chat/completions": processes chat completion requests; in this modified version every request is
      answered with a single, non-streaming translation (see the client example below).
    - "/v1/embeddings": embeds a list of text inputs; requires an embedding model to be loaded at startup
      (a client sketch appears at the end of this docstring).
- Token Limit Caution: in the OpenAI API, 'max_tokens' is equivalent to HuggingFace's 'max_new_tokens', not
  'max_length'. For instance, setting 'max_tokens' to 8192 for a 6B model would result in an error, because the
  model cannot output that many new tokens once the history and prompt tokens are accounted for.
- Stream Handling and Custom Functions: helpers for streaming responses and custom function-call detection are
  kept from the original demo, although the chat endpoint currently returns non-streaming responses only.
- Pydantic Models: defines structured models for requests and responses, enhancing API documentation and type safety.
- Main Execution: initializes the model and tokenizer, and starts the FastAPI app on the designated host and port.
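
Example (client side): a minimal sketch, assuming the server is reachable at
http://localhost:8001/v1 and the `openai` Python package (v1+) is installed;
the prompt below is only a placeholder:

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8001/v1", api_key="EMPTY")
    completion = client.chat.completions.create(
        model="chatglm3-6b",  # the id reported by /v1/models
        messages=[{"role": "user", "content": "Go to the kitchen, then open the door."}],
    )
    print(completion.choices[0].message.content)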

Note:
    This script does not include the setup for special tokens or multi-GPU support by default.
    Users need to configure their own special tokens and can enable multi-GPU support per the provided instructions.
    The embedding model is only supported on a single GPU.
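
Embeddings example (client side): likewise a sketch; it assumes an embedding model has been loaded at
startup (see run() below), and the model name is a placeholder (this server does not check it):

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8001/v1", api_key="EMPTY")
    resp = client.embeddings.create(model="embedding-model", input=["hello world"])
    print(len(resp.data[0].embedding))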
""" |
|
|
|
import os |
|
import time |
|
import tiktoken |
|
import torch |
|
import uvicorn |
|
|
|
from fastapi import FastAPI, HTTPException, Response |
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
|
from contextlib import asynccontextmanager |
|
from typing import List, Literal, Optional, Union |
|
from loguru import logger |
|
from pydantic import BaseModel, Field |
|
from transformers import AutoTokenizer, AutoModel |
|
|
|
from sentence_transformers import SentenceTransformer |
|
|
|
from sse_starlette.sse import EventSourceResponse |
|
|
|
|
|
|
|
from NL2HLTLTranslator.mistral7b.prediction import Mistral_NL2TL_translator as NL2TL_translator |
|
|
|
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000 |
|
|
|
output_dir = os.path.join(os.path.dirname(__file__),"../") |
|
tuned_model_name="mistral7b_quat8" |
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    yield
    # Free GPU memory when the application shuts down.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class FunctionCallResponse(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system", "function"]
    content: Optional[str] = None
    name: Optional[str] = None
    function_call: Optional[FunctionCallResponse] = None


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None
    function_call: Optional[FunctionCallResponse] = None


class EmbeddingRequest(BaseModel):
    input: List[str]
    model: str


class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class EmbeddingResponse(BaseModel):
    data: list
    model: str
    object: str
    usage: CompletionUsage


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    tools: Optional[Union[dict, List[dict]]] = None
    repetition_penalty: Optional[float] = 1.1


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length", "function_call"]


class ChatCompletionResponseStreamChoice(BaseModel):
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length", "function_call"]]
    index: int


class ChatCompletionResponse(BaseModel):
    model: str
    id: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    usage: Optional[UsageInfo] = None


@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest):
    # The embedding model must be loaded at startup (see run() below); otherwise
    # report the endpoint as unavailable instead of failing with a NameError.
    if embedding_model is None:
        raise HTTPException(status_code=501, detail="No embedding model is loaded on this server")

    embeddings = [embedding_model.encode(text) for text in request.input]
    embeddings = [embedding.tolist() for embedding in embeddings]

    def num_tokens_from_string(string: str) -> int:
        """
        Returns the number of tokens in a text string,
        using the cl100k_base tokenizer.
        """
        encoding = tiktoken.get_encoding('cl100k_base')
        num_tokens = len(encoding.encode(string))
        return num_tokens

    response = {
        "data": [
            {
                "object": "embedding",
                "embedding": embedding,
                "index": index
            }
            for index, embedding in enumerate(embeddings)
        ],
        "model": request.model,
        "object": "list",
        "usage": CompletionUsage(
            # Rough whitespace-based count for prompt tokens; the total uses the
            # cl100k_base tokenizer, so the two figures are not directly comparable.
            prompt_tokens=sum(len(text.split()) for text in request.input),
            completion_tokens=0,
            total_tokens=sum(num_tokens_from_string(text) for text in request.input),
        )
    }
    return response


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    model_card = ModelCard(
        id="chatglm3-6b"
    )
    return ModelList(
        data=[model_card]
    )


count = 0


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer, LLM

    if len(request.messages) < 1 or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")

    gen_params = dict(
        messages=request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens or 1024,
        echo=False,
        stream=request.stream,
        repetition_penalty=request.repetition_penalty,
        tools=request.tools,
    )
    logger.debug(f"==== request ====\n{gen_params}")

    # The streaming and tool-call branches of the original demo are not implemented here:
    # every request is handled as a single, non-streaming translation, and only the
    # content of the first message is passed to the NL2TL translator.
    response = LLM.translate(gen_params['messages'][0].content)

    # Token usage is not tracked for the translator, so the usage info stays at zero.
    usage = UsageInfo()

    function_call = None
    message = ChatMessage(
        role="assistant",
        content=response,
        function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
    )

    logger.debug(f"==== message ====\n{message}")

    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=message,
        finish_reason='stop',
    )

    return ChatCompletionResponse(
        model=request.model,
        id="",
        choices=[choice_data],
        object="chat.completion",
        usage=usage
    )


async def parse_output_text(model_id: str, value: str):
    """
    Directly stream the text content of `value` as chat-completion chunks.

    Kept from the original demo; it is not currently wired to an endpoint.

    :param model_id:
    :param value:
    :return:
    """
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant", content=value),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))

    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))
    yield '[DONE]'


def contains_custom_function(value: str) -> bool:
    """
    Determine whether the output contains a custom function call, based on a special function-name prefix.

    For example, the functions defined in "tools_using_demo/tool_register.py" are all named "get_xxx",
    i.e. they start with "get_".

    [Note] This is not a rigorous check and is only meant as a reference.

    :param value:
    :return:
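
    A quick illustration (the inputs below are made up):

        >>> contains_custom_function("get_weather(location='Paris')")
        True
        >>> contains_custom_function("hello")
        False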
    """
    return value and 'get_' in value


def run(output_dir="path/to/model_weight", tuned_model_name="llama2_13b__mid_asciiaug1", CUDA_device='0', quat=True):
    # Note: CUDA_device is currently unused; select the GPU via CUDA_VISIBLE_DEVICES if needed.
    global LLM, model, tokenizer
    LLM = NL2TL_translator(output_dir=output_dir, tuned_model_name=tuned_model_name, quat=quat)

    tokenizer = LLM.tokenizer
    model = LLM.model

    # If /v1/embeddings is needed, load a sentence-embedding model here and assign it to the
    # module-level `embedding_model`, e.g. (path left to the user):
    #   global embedding_model
    #   embedding_model = SentenceTransformer("path/to/embedding_model", device="cuda")

    uvicorn.run(app, host='0.0.0.0', port=8001, workers=1)


if __name__ == "__main__":
    run(output_dir=output_dir, tuned_model_name=tuned_model_name)