asimsultan committed on
Commit
540f1c5
·
1 Parent(s): 5898430

Updated docker file

Browse files
Files changed (1) hide show
  1. app.py +11 -25
app.py CHANGED
@@ -1,10 +1,10 @@
1
  from fastapi import FastAPI, Request
2
- from fastapi.responses import StreamingResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from pydantic import BaseModel
5
  from huggingface_hub import InferenceClient
6
  import os
7
 
 
8
  HF_TOKEN = os.getenv("HF_TOKEN")
9
  MODEL_ID = "google/gemma-2b-it"
10
  client = InferenceClient(token=HF_TOKEN)
@@ -23,29 +23,15 @@ app.add_middleware(
23
  class ChatRequest(BaseModel):
24
  message: str
25
 
26
- # @app.post("/chat")
27
- # async def chat(req: ChatRequest):
28
- # try:
29
- # messages = [{"role": "user", "content": req.message}]
30
- # response = client.chat_completion(
31
- # model=MODEL_ID,
32
- # messages=messages,
33
- # temperature=0.7,
34
- # )
35
- # return {"response": response.choices[0].message.content}
36
- # except Exception as e:
37
- # return {"error": str(e)}
38
-
39
-
40
- async def chat_endpoint(data: ChatRequest):
41
- def stream():
42
- for chunk in client.text_generation(
43
  model=MODEL_ID,
44
- prompt=data.message,
45
- stream=True,
46
- max_new_tokens=512,
47
  temperature=0.7,
48
- ):
49
- yield chunk
50
-
51
- return StreamingResponse(stream(), media_type="text/plain")
 
1
  from fastapi import FastAPI, Request
 
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
  from huggingface_hub import InferenceClient
5
  import os
6
 
7
+
8
  HF_TOKEN = os.getenv("HF_TOKEN")
9
  MODEL_ID = "google/gemma-2b-it"
10
  client = InferenceClient(token=HF_TOKEN)
 
23
class ChatRequest(BaseModel):
    """Request payload for the POST /chat endpoint."""

    # The user's message; forwarded verbatim as a single user turn to the model.
    message: str
25
 
26
@app.post("/chat")
async def chat(req: ChatRequest):
    """Run a single-turn chat completion against MODEL_ID via the HF Inference API.

    Returns ``{"response": <model text>}`` on success, or — best-effort —
    ``{"error": <message>}`` with a 200 status if anything goes wrong.
    """
    try:
        completion = client.chat_completion(
            model=MODEL_ID,
            # Single user turn built straight from the request body.
            messages=[{"role": "user", "content": req.message}],
            temperature=0.7,
        )
        return {"response": completion.choices[0].message.content}
    except Exception as err:
        # Deliberate catch-all: surface the failure in the payload rather
        # than returning a 5xx to the caller.
        return {"error": str(err)}