|
import asyncio
import json
import time
from typing import List, Literal

from fastapi import FastAPI
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel

from gradio_client import Client
|
|
|
app = FastAPI() |
|
client = Client("AWeirdDev/mistral-7b-instruct-v0.2") |
|
|
|
class Message(BaseModel):
    """One chat message in the OpenAI-style conversation payload."""

    # Speaker of the message; mirrors the standard OpenAI chat roles.
    role: Literal["user", "assistant", "system"]

    # Plain-text content of the message.
    content: str
|
|
|
class Payload(BaseModel):
    """Request body for ``/chat/completions`` (OpenAI-compatible subset)."""

    # When true, respond with a Server-Sent-Events stream of chunks.
    stream: bool = False

    # Only one backing model is exposed by this proxy.
    model: Literal["mistral-7b-instruct-v0.2"] = "mistral-7b-instruct-v0.2"

    # Conversation history, oldest message first.
    messages: List[Message]

    # Sampling temperature forwarded verbatim to the Space.
    temperature: float = 0.9

    # Repetition penalty forwarded verbatim to the Space.
    frequency_penalty: float = 1.2

    # Nucleus-sampling cutoff forwarded verbatim to the Space.
    top_p: float = 0.9
|
|
|
async def stream(iter):
    """Asynchronously drain a blocking (synchronous) iterator.

    Each ``next()`` call runs in a worker thread via ``asyncio.to_thread``
    so the event loop is never blocked while waiting for the next item.

    NOTE(review): the parameter name shadows the ``iter`` builtin; kept for
    backward compatibility with any keyword callers.
    """
    # Sentinel-based termination: passing a default to next() keeps
    # StopIteration from crossing the thread/coroutine boundary, which is
    # fragile under asyncio (can surface as RuntimeError).
    _done = object()
    while True:
        value = await asyncio.to_thread(next, iter, _done)
        if value is _done:
            break
        yield value
|
|
|
def make_chunk_obj(i, delta, fr, model="mistral-7b-instruct-v0.2"):
    """Build one OpenAI-style ``chat.completion.chunk`` object.

    Args:
        i: Choice index of this chunk.
        delta: Newly generated text since the previous chunk.
        fr: ``finish_reason`` — ``None`` while streaming, ``"stop"`` at end.
        model: Model name reported in the chunk. Generalized from the
            previously hard-coded constant; the default preserves the old
            behavior for existing callers.
    """
    return {
        # Nanosecond timestamp serves as a unique-enough response id.
        "id": str(time.time_ns()),
        "object": "chat.completion.chunk",
        "created": round(time.time()),
        "model": model,
        "system_fingerprint": "wtf",
        "choices": [
            {
                "index": i,
                "delta": {
                    "content": delta
                },
                "finish_reason": fr
            }
        ]
    }
|
|
|
@app.get('/')
async def index():
    """Landing route advertising the backing Hugging Face Space URL."""
    body = {
        "message": "hello",
        "url": "https://aweirddev-mistral-7b-instruct-v0-2-leicht.hf.space",
    }
    return JSONResponse(body)
|
|
|
@app.post('/chat/completions')
async def c_cmp(payload: Payload):
    """OpenAI-compatible chat-completions endpoint backed by the Gradio Space.

    Returns a single JSON completion when ``payload.stream`` is false;
    otherwise streams ``chat.completion.chunk`` objects as Server-Sent
    Events.
    """
    if not payload.stream:
        # Blocking predict() returns the full completion text in one call.
        return JSONResponse(
            {
                "id": str(time.time_ns()),
                "object": "chat.completion",
                "created": round(time.time()),
                "model": payload.model,
                "system_fingerprint": "wtf",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": client.predict(
                                payload.model_dump()['messages'],
                                payload.temperature,
                                4096,  # max new tokens — hard-coded Space limit
                                payload.top_p,
                                payload.frequency_penalty,
                                api_name="/chat"
                            )
                        }
                    }
                ]
            }
        )

    def streamer():
        text = ""
        result = client.submit(
            payload.model_dump()['messages'],
            payload.temperature,
            4096,  # max new tokens — hard-coded Space limit
            payload.top_p,
            payload.frequency_penalty,
            api_name="/chat"
        )
        # Fix: the index must be initialized before the loop — the original
        # referenced the loop variable after the loop, raising NameError
        # whenever the job yielded no items at all.
        idx = 0
        for idx, item in enumerate(result):
            # The Space yields the full text so far; emit only the new suffix.
            delta = item[len(text):]
            # Fix: SSE events must be terminated by a blank line ("\n\n");
            # without it every chunk concatenates into one unparseable line.
            yield "data: " + json.dumps(make_chunk_obj(idx, delta, None)) + "\n\n"
            text = item

        yield "data: " + json.dumps(make_chunk_obj(idx, "", "stop")) + "\n\n"
        # NOTE(review): OpenAI clients conventionally expect "[DONE]"; the
        # original "[END]" sentinel is kept to avoid changing the observed
        # wire protocol.
        yield "data: [END]\n\n"

    # text/event-stream is required for SSE clients to parse the chunks.
    return StreamingResponse(streamer(), media_type="text/event-stream")
|
|