AWeirdDev committed on
Commit
f4fee37
·
verified ·
1 Parent(s): 919478a

use hf api instead

Browse files
Files changed (1) hide show
  1. app.py +50 -4
app.py CHANGED
@@ -6,13 +6,17 @@ from fastapi import FastAPI
6
  from fastapi.responses import StreamingResponse, JSONResponse
7
  from pydantic import BaseModel
8
 
9
- from gradio_client import Client
 
 
 
 
10
 
11
  app = FastAPI()
12
  client = Client("AWeirdDev/mistral-7b-instruct-v0.2")
13
 
14
  class Message(BaseModel):
15
- role: Literal["user", "assistant", "system"]
16
  content: str
17
 
18
  class Payload(BaseModel):
@@ -31,6 +35,17 @@ async def stream(iter):
31
  except StopIteration:
32
  break
33
 
 
 
 
 
 
 
 
 
 
 
 
34
  def make_chunk_obj(i, delta, fr):
35
  return {
36
  "id": str(time.time_ns()),
@@ -49,6 +64,37 @@ def make_chunk_obj(i, delta, fr):
49
  ]
50
  }
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  @app.get('/')
53
  async def index():
54
  return JSONResponse({ "message": "hello", "url": "https://aweirddev-mistral-7b-instruct-v0-2-leicht.hf.space" })
@@ -68,7 +114,7 @@ async def c_cmp(payload: Payload):
68
  "index": 0,
69
  "message": {
70
  "role": "assistant",
71
- "content": client.predict(
72
  payload.model_dump()['messages'],
73
  payload.temperature,
74
  4096,
@@ -85,7 +131,7 @@ async def c_cmp(payload: Payload):
85
 
86
  def streamer():
87
  text = ""
88
- result = client.submit(
89
  payload.model_dump()['messages'],
90
  payload.temperature, # float (numeric value between 0.0 and 1.0) in 'Temperature' Slider component
91
  4096, # float (numeric value between 0 and 1048) in 'Max new tokens' Slider component
 
6
  from fastapi.responses import StreamingResponse, JSONResponse
7
  from pydantic import BaseModel
8
 
9
from huggingface_hub import InferenceClient

# Single shared client for the Hugging Face Inference API; replaces the
# old gradio_client-based Space call.
client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.2"
)

app = FastAPI()
# FIX: removed the leftover `client = Client("AWeirdDev/mistral-7b-instruct-v0.2")`
# line. `from gradio_client import Client` was deleted in this commit, so that
# line raised NameError at import time — and even if the import had remained,
# it silently overwrote the InferenceClient configured above.
17
 
18
class Message(BaseModel):
    # One chat turn in the OpenAI-style payload.
    # Role is restricted to user/assistant — the "system" role accepted by
    # the previous version is no longer allowed.
    role: Literal["user", "assistant"]
    content: str
21
 
22
  class Payload(BaseModel):
 
35
  except StopIteration:
36
  break
37
 
38
def format_prompt(messages: list) -> str:
    """Render chat messages into the Mistral-instruct prompt format.

    `messages` is a list of dicts with "role" and "content" keys — callers
    pass `payload.model_dump()['messages']`, not `Message` instances, so the
    previous `List[Message]` annotation was misleading (the body indexes with
    `message['role']`, which would raise TypeError on a pydantic model).

    User turns are wrapped in `[INST] ... [/INST]`; any other role is
    treated as an assistant turn and closed with `</s>`.
    """
    prompt = "<s>"

    for message in messages:
        if message['role'] == 'user':
            prompt += f"[INST] {message['content']} [/INST]"
        else:
            prompt += f" {message['content']}</s> "

    return prompt
48
+
49
  def make_chunk_obj(i, delta, fr):
50
  return {
51
  "id": str(time.time_ns()),
 
64
  ]
65
  }
66
 
67
def generate(
    messages,
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
):
    """Stream generated text for `messages` via the HF Inference API.

    Formats the messages with `format_prompt` and yields token strings one
    at a time. The literal end-of-sequence token "</s>" is yielded as an
    empty string (rather than skipped) so the chunk cadence of the
    streaming response is preserved.
    """
    # The inference API rejects temperature == 0; clamp to a small positive value.
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=None,
    )

    formatted_prompt = format_prompt(messages)

    # details=True yields token objects exposing `.token.text`.
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )

    for response in stream:
        t = response.token.text
        yield t if t != "</s>" else ""
97
+
98
@app.get('/')
async def index():
    """Health/landing endpoint pointing at the hosted Space."""
    body = {
        "message": "hello",
        "url": "https://aweirddev-mistral-7b-instruct-v0-2-leicht.hf.space",
    }
    return JSONResponse(body)
 
114
  "index": 0,
115
  "message": {
116
  "role": "assistant",
117
+ "content": generate(
118
  payload.model_dump()['messages'],
119
  payload.temperature,
120
  4096,
 
131
 
132
  def streamer():
133
  text = ""
134
+ result = generate(
135
  payload.model_dump()['messages'],
136
  payload.temperature, # float (numeric value between 0.0 and 1.0) in 'Temperature' Slider component
137
  4096, # float (numeric value between 0 and 1048) in 'Max new tokens' Slider component