moriire commited on
Commit
0bb3f90
·
verified ·
1 Parent(s): da0aaeb

Update app/llm.py

Browse files
Files changed (1) hide show
  1. app/llm.py +12 -5
app/llm.py CHANGED
@@ -12,11 +12,12 @@ from fastapi import APIRouter
12
  from app.users import current_active_user
13
 
14
 
15
- from transformers import AutoModelForCausalLM
16
-
17
- model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder")
18
- model.to_bettertransformer()
19
 
 
 
20
 
21
  class GenModel(BaseModel):
22
  question: str
@@ -35,6 +36,7 @@ class ChatModel(BaseModel):
35
  mirostat_mode: int=2
36
  mirostat_tau: float=4.0
37
  mirostat_eta: float=1.1
 
38
  llm_chat = llama_cpp.Llama.from_pretrained(
39
  repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
40
  filename="*q4_0.gguf",
@@ -57,6 +59,7 @@ llm_generate = llama_cpp.Llama.from_pretrained(
57
  mirostat_eta=1.1,
58
  #chat_format="llama-2"
59
  )
 
60
  # Logger setup
61
  logging.basicConfig(level=logging.INFO)
62
  logger = logging.getLogger(__name__)
@@ -82,7 +85,11 @@ def health():
82
  # Chat Completion API
83
  @llm_router.post("/chat/", tags=["llm"])
84
  async def chat(chatm:ChatModel):#, user: schemas.BaseUser = fastapi.Depends(current_active_user)):
85
-
 
 
 
 
86
  """
87
  #chatm.system = chatm.system.format("")#user.email)
88
  try:
 
12
  from app.users import current_active_user
13
 
14
 
15
+ #from transformers import AutoModelForCausalLM
16
+ from transformers import AutoTokenizer, pipeline
17
+ from optimum.onnxruntime import ORTModelForQuestionAnswering
 
18
 
19
+ model = ORTModelForQuestionAnswering.from_pretrained("optimum/roberta-base-squad2")
20
+ tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
21
 
22
  class GenModel(BaseModel):
23
  question: str
 
36
  mirostat_mode: int=2
37
  mirostat_tau: float=4.0
38
  mirostat_eta: float=1.1
39
+
40
  llm_chat = llama_cpp.Llama.from_pretrained(
41
  repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
42
  filename="*q4_0.gguf",
 
59
  mirostat_eta=1.1,
60
  #chat_format="llama-2"
61
  )
62
+
63
  # Logger setup
64
  logging.basicConfig(level=logging.INFO)
65
  logger = logging.getLogger(__name__)
 
85
  # Chat Completion API
86
  @llm_router.post("/chat/", tags=["llm"])
87
  async def chat(chatm:ChatModel):#, user: schemas.BaseUser = fastapi.Depends(current_active_user)):
88
+ onnx_qa = pipeline("question-answering", model=model, tokenizer=tokenizer)
89
+ question = "What's my name?"
90
+ context = "My name is Philipp and I live in Nuremberg."
91
+ pred = onnx_qa(question, context)
92
+ return pred
93
  """
94
  #chatm.system = chatm.system.format("")#user.email)
95
  try: