mgbam committed · verified
Commit 0181a1f · 1 Parent(s): b412ca5

Create inference.py

Files changed (1)
  1. inference.py +53 -0
inference.py ADDED
@@ -0,0 +1,53 @@
+ # inference.py
+
+ from typing import List, Dict, Generator, Optional
+ from hf_client import get_inference_client
+
+ def chat_completion(
+     model_id: str,
+     messages: List[Dict[str, str]],
+     provider: Optional[str] = None,
+     max_tokens: int = 4096
+ ) -> str:
+     """
+     Send a chat completion request to the appropriate inference provider.
+
+     Args:
+         model_id: The model identifier to use.
+         messages: A list of OpenAI-style {'role': '...', 'content': '...'} messages.
+         provider: Optional override for the provider; uses the model default if None.
+         max_tokens: Maximum number of tokens to generate.
+
+     Returns:
+         The assistant's response content.
+     """
+     client = get_inference_client(model_id, provider or "auto")
+     response = client.chat.completions.create(
+         model=model_id,
+         messages=messages,
+         max_tokens=max_tokens
+     )
+     return response.choices[0].message.content
+
+
+ def stream_chat_completion(
+     model_id: str,
+     messages: List[Dict[str, str]],
+     provider: Optional[str] = None,
+     max_tokens: int = 4096
+ ) -> Generator[str, None, None]:
+     """
+     Generator for streaming chat completions.
+     Yields partial message chunks as strings.
+     """
+     client = get_inference_client(model_id, provider or "auto")
+     stream = client.chat.completions.create(
+         model=model_id,
+         messages=messages,
+         max_tokens=max_tokens,
+         stream=True
+     )
+     for chunk in stream:
+         delta = getattr(chunk.choices[0].delta, "content", None)
+         if delta:
+             yield delta
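
For reference, a minimal usage sketch of the two helpers this commit adds. The model ID below is illustrative, and get_inference_client from hf_client is assumed to return an OpenAI-compatible client, as the calls above imply; this example is not part of the commit itself.

# example_usage.py (hypothetical, for illustration only)
from inference import chat_completion, stream_chat_completion

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Explain what an inference provider is in one sentence."},
]

# Blocking call: returns the full assistant reply as a single string.
reply = chat_completion("meta-llama/Llama-3.1-8B-Instruct", messages)
print(reply)

# Streaming call: prints partial chunks as they arrive.
for chunk in stream_chat_completion("meta-llama/Llama-3.1-8B-Instruct", messages):
    print(chunk, end="", flush=True)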