File size: 3,401 Bytes
5bb4c9e d80b380 5bb4c9e d80b380 500acbd 5bb4c9e d80b380 cb9be9f d80b380 5bb4c9e 500acbd 5bb4c9e 43d6f9a 500acbd d80b380 cb9be9f d80b380 5bb4c9e d80b380 500acbd d80b380 cb9be9f 500acbd 5bb4c9e d80b380 5bb4c9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import asyncio
import json
import time
from typing import List, Literal

from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

from gradio_client import Client
# FastAPI application exposing an OpenAI-style chat-completions facade.
app = FastAPI()
# Gradio client for the hosted Mistral-7B-Instruct-v0.2 Space, created once at
# import time (NOTE(review): presumably performs a network handshake — verify).
client = Client("AWeirdDev/mistral-7b-instruct-v0.2")
class Message(BaseModel):
    """One chat message, mirroring the OpenAI chat-completions schema."""
    # Who produced the message.
    role: Literal["user", "assistant", "system"]
    # The message text.
    content: str
class Payload(BaseModel):
    """Request body for POST /chat/completions (OpenAI-compatible subset)."""
    # When True the response is streamed as server-sent events.
    stream: bool = False
    # Only one backing model is supported.
    model: Literal["mistral-7b-instruct-v0.2"] = "mistral-7b-instruct-v0.2"
    # Full conversation history, oldest first.
    messages: List[Message]
    # Sampling temperature forwarded to the Space's slider.
    temperature: float = 0.9
    # Forwarded as the Space's 'Repetition penalty' slider (1.0-2.0).
    frequency_penalty: float = 1.2
    # Nucleus-sampling cutoff forwarded to the Space.
    top_p: float = 0.9
async def stream(iter):
    """Asynchronously yield items from a blocking (synchronous) iterator.

    Each ``next()`` call is run in a worker thread via ``asyncio.to_thread``
    so the event loop is never blocked while the iterator waits for data.

    Args:
        iter: Any synchronous iterator (parameter name kept for backward
            compatibility even though it shadows the builtin).

    Yields:
        The iterator's items, in order.
    """
    # Use a sentinel instead of catching StopIteration: per PEP 479 a
    # StopIteration escaping the awaited coroutine surfaces as RuntimeError,
    # so the original `except StopIteration` could never fire.
    sentinel = object()
    while True:
        value = await asyncio.to_thread(next, iter, sentinel)
        if value is sentinel:
            break
        yield value
def make_chunk_obj(i, delta, fr):
    """Build one OpenAI-style ``chat.completion.chunk`` dict.

    Args:
        i: Choice index for this chunk.
        delta: New text fragment carried by the chunk.
        fr: Finish reason ("stop" on the final chunk, otherwise None).

    Returns:
        A dict shaped like an OpenAI streaming response chunk.
    """
    choice = {
        "index": i,
        "delta": {"content": delta},
        "finish_reason": fr,
    }
    return {
        "id": str(time.time_ns()),
        "object": "chat.completion.chunk",
        "created": round(time.time()),
        "model": "mistral-7b-instruct-v0.2",
        "system_fingerprint": "wtf",
        "choices": [choice],
    }
@app.get('/')
async def index():
    """Health-check endpoint that also points at the backing HF Space."""
    body = {
        "message": "hello",
        "url": "https://aweirddev-mistral-7b-instruct-v0-2-leicht.hf.space",
    }
    return JSONResponse(body)
@app.post('/chat/completions')
async def c_cmp(payload: Payload):
    """OpenAI-compatible chat completions endpoint.

    Proxies the request to the gradio Space. Returns a single JSON response
    when ``payload.stream`` is false, otherwise a server-sent-event stream of
    ``chat.completion.chunk`` objects.
    """
    # The Space expects the whole message history as one JSON string.
    # NOTE: payload.messages is a plain list, which has no model_dump_json();
    # the original call raised AttributeError. Serialize each model instead.
    messages_json = json.dumps([m.model_dump() for m in payload.messages])

    if not payload.stream:
        content = client.predict(
            messages_json,
            payload.temperature,
            4096,  # max new tokens
            payload.top_p,
            payload.frequency_penalty,
            api_name="/chat",
        )
        return JSONResponse(
            {
                "id": str(time.time_ns()),
                "object": "chat.completion",
                "created": round(time.time()),
                "model": payload.model,
                "system_fingerprint": "wtf",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": content,
                        },
                    }
                ],
            }
        )

    def streamer():
        text = ""
        i = 0  # initialized so the final chunk works even if nothing streams
        result = client.submit(
            messages_json,
            payload.temperature,        # float (0.0-1.0) 'Temperature' slider
            4096,                       # float (0-1048) 'Max new tokens' slider
            payload.top_p,              # float (0.0-1) 'Top-p (nucleus sampling)' slider
            payload.frequency_penalty,  # float (1.0-2.0) 'Repetition penalty' slider
            api_name="/chat",
        )
        for i, item in enumerate(result):
            # gradio streams the cumulative text; forward only the new suffix.
            delta = item[len(text):]
            # SSE events must be terminated by a blank line ("\n\n") or
            # clients cannot split the stream into events.
            yield "data: " + json.dumps(make_chunk_obj(i, delta, None)) + "\n\n"
            text = item
        yield "data: " + json.dumps(make_chunk_obj(i, "", "stop")) + "\n\n"
        yield "data: [END]\n\n"

    return StreamingResponse(streamer(), media_type="text/event-stream")
|