huggingface model support draft
src/know_lang_bot/chat_bot/__main__.py
CHANGED
@@ -19,7 +19,6 @@ async def test_chat_processing():
     )
 
     print(f"Answer: {result.answer}")
-    print(f"References: {result.references_md}")
 
 if __name__ == "__main__":
     asyncio.run(test_chat_processing())
src/know_lang_bot/chat_bot/chat_graph.py
CHANGED
@@ -12,8 +12,11 @@ from pydantic_ai import Agent
 import logfire
 from pprint import pformat
 from enum import Enum
+from rich.console import Console
+from know_lang_bot.utils.model_provider import create_pydantic_model
 
 LOG = FancyLogger(__name__)
+console = Console()
 
 class ChatStatus(str, Enum):
     """Enum for tracking chat progress status"""
@@ -128,7 +131,10 @@ Example Output: "Where is the configuration file or configuration settings store
     async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> RetrieveContextNode:
         # Create an agent for question polishing
         polish_agent = Agent(
-
+            create_pydantic_model(
+                model_provider=ctx.deps.config.llm.model_provider,
+                model_name=ctx.deps.config.llm.model_name
+            ),
             system_prompt=self.system_prompt
         )
         prompt = f"""Original question: "{ctx.state.original_question}"
@@ -203,7 +209,10 @@ Remember: Your primary goal is answering the user's specific question, not expla
 
     async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> End[ChatResult]:
         answer_agent = Agent(
-
+            create_pydantic_model(
+                model_provider=ctx.deps.config.llm.model_provider,
+                model_name=ctx.deps.config.llm.model_name
+            ),
             system_prompt=self.system_prompt
         )
 
@@ -267,6 +276,8 @@ async def process_chat(
         )
     except Exception as e:
         LOG.error(f"Error processing chat in graph: {e}")
+        console.print_exception()
+
         result = ChatResult(
             answer="I encountered an error processing your question. Please try again."
         )
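For context on the pattern above: the agents now resolve their model from the app config at construction time via create_pydantic_model, instead of a hard-coded model argument. A minimal sketch of the same call outside the graph, assuming an app config whose llm section carries a provider and a model name (the concrete values below are illustrative only, not taken from the repository's config):

    from pydantic_ai import Agent

    from know_lang_bot.core.types import ModelProvider
    from know_lang_bot.utils.model_provider import create_pydantic_model

    # Illustrative values; in chat_graph.py these come from ctx.deps.config.llm
    model = create_pydantic_model(
        model_provider=ModelProvider.HUGGINGFACE,
        model_name="microsoft/phi-2",  # example checkpoint, not the project default
    )
    polish_agent = Agent(model, system_prompt="Rephrase the user's question for code retrieval.")
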
src/know_lang_bot/models/huggingface.py
ADDED
@@ -0,0 +1,193 @@
+from __future__ import annotations
+
+import asyncio
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from typing import AsyncIterator, Optional
+from contextlib import asynccontextmanager
+
+from pydantic_ai.models import (
+    AgentModel,
+    ModelMessage,
+    ModelResponse,
+    StreamedResponse,
+    Usage,
+    Model,
+    check_allow_model_requests
+)
+from pydantic_ai.tools import ToolDefinition
+from pydantic_ai.messages import (
+    ModelResponsePart, TextPart, SystemPromptPart,
+    UserPromptPart, ToolReturnPart, ModelResponseStreamEvent
+)
+from pydantic_ai.settings import ModelSettings
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+import torch
+
+
+@dataclass(init=False)
+class HuggingFaceModel(Model):
+    """A model that uses HuggingFace models locally.
+
+    For MVP, this implements basic functionality without streaming or advanced features.
+    """
+
+    model_name: str
+    model: AutoModelForCausalLM = field(repr=False)
+    tokenizer: AutoTokenizer = field(repr=False)
+    device: str = field(default="cuda" if torch.cuda.is_available() else "cpu")
+
+    def __init__(
+        self,
+        model_name: str,
+        *,
+        device: Optional[str] = None,
+        max_new_tokens: int = 512,
+    ):
+        """Initialize a HuggingFace model.
+
+        Args:
+            model_name: Name of the model on HuggingFace Hub
+            device: Device to run model on ('cuda' or 'cpu')
+            max_new_tokens: Maximum number of tokens to generate
+        """
+        self.model_name = model_name
+        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.max_new_tokens = max_new_tokens
+
+        # Load model and tokenizer
+        self.model = AutoModelForCausalLM.from_pretrained(model_name)
+        self.model.to(self.device)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        if not self.tokenizer.pad_token:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+    async def agent_model(
+        self,
+        *,
+        function_tools: list[ToolDefinition],
+        allow_text_result: bool,
+        result_tools: list[ToolDefinition],
+    ) -> AgentModel:
+        """Create an agent model for each step of an agent run."""
+        check_allow_model_requests()
+
+        return HuggingFaceAgentModel(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            model_name=self.model_name,
+            device=self.device,
+            max_new_tokens=self.max_new_tokens,
+            tools=function_tools if function_tools else None
+        )
+
+    def name(self) -> str:
+        return f"huggingface:{self.model_name}"
+
+@dataclass
+class HuggingFaceAgentModel(AgentModel):
+    """Implementation of AgentModel for HuggingFace models."""
+
+    model: AutoModelForCausalLM
+    tokenizer: AutoTokenizer
+    model_name: str
+    device: str
+    max_new_tokens: int
+    tools: Optional[list[ToolDefinition]] = None
+
+    def _format_messages(self, messages: list[ModelMessage]) -> str:
+        """Format messages into a prompt the model can understand."""
+        formatted = []
+        for message in messages:
+            for part in message.parts:
+                if isinstance(part, SystemPromptPart):
+                    formatted.append(f"<|system|>{part.content}")
+                elif isinstance(part, UserPromptPart):
+                    formatted.append(f"<|user|>{part.content}")
+                else:
+                    # For MVP, we'll just pass through other message types
+                    formatted.append(str(part.content))
+        formatted.append("<|assistant|>")
+        return "\n".join(formatted)
+
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, Usage]:
+        """Make a request to the model."""
+        prompt = self._format_messages(messages)
+        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=self.max_new_tokens,
+                pad_token_id=self.tokenizer.pad_token_id,
+            )
+
+        response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+        response_text = response_text[len(prompt):]  # Remove the prompt
+
+        timestamp = datetime.now(timezone.utc)
+        response = ModelResponse(
+            parts=[TextPart(response_text)],
+            model_name=self.model_name,
+            timestamp=timestamp
+        )
+
+        usage_stats = Usage(
+            requests=1,
+            request_tokens=inputs.input_ids.shape[1],
+            response_tokens=len(outputs[0]) - inputs.input_ids.shape[1],
+            total_tokens=len(outputs[0])
+        )
+
+        return response, usage_stats
+
+    @asynccontextmanager
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[StreamedResponse]:
+        """Make a streaming request to the model."""
+        prompt = self._format_messages(messages)
+        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+
+        streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
+
+        generation_kwargs = dict(
+            input_ids=inputs.input_ids,
+            max_new_tokens=self.max_new_tokens,
+            pad_token_id=self.tokenizer.pad_token_id,
+            streamer=streamer,
+        )
+
+        thread = torch.jit.fork(self.model.generate, **generation_kwargs)
+
+        try:
+            yield HuggingFaceStreamedResponse(
+                _model_name=self.model_name,
+                _streamer=streamer,
+                _timestamp=datetime.now(timezone.utc),
+            )
+        finally:
+            torch.jit.wait(thread)
+
+@dataclass
+class HuggingFaceStreamedResponse(StreamedResponse):
+    """Implementation of StreamedResponse for HuggingFace models."""
+
+    _streamer: TextIteratorStreamer
+    _timestamp: datetime
+
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        """Stream tokens from the model."""
+        for new_text in self._streamer:
+            self._usage.response_tokens += len(new_text)
+            self._usage.total_tokens = self._usage.request_tokens + self._usage.response_tokens
+
+            yield self._parts_manager.handle_text_delta(
+                vendor_part_id='content',
+                content=new_text
+            )
+
+    def timestamp(self) -> datetime:
+        return self._timestamp
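A quick usage sketch of the new local model wrapper, outside the agent graph. This assumes the pydantic-ai version targeted by this draft accepts a Model instance as the first Agent argument and exposes the generated text on .data; the checkpoint id is only an example:

    import asyncio

    from pydantic_ai import Agent
    from know_lang_bot.models.huggingface import HuggingFaceModel

    async def main():
        # Example Hub checkpoint; any causal-LM checkpoint should work in principle,
        # though the <|system|>/<|user|> prompt format above is not tuned per model.
        model = HuggingFaceModel("Qwen/Qwen2.5-0.5B-Instruct", max_new_tokens=128)
        agent = Agent(model, system_prompt="You are a concise assistant.")
        result = await agent.run("What does a tokenizer do?")
        print(result.data)  # assumption: .data carries the output in this pydantic-ai version

    asyncio.run(main())

One caveat worth flagging: the transformers documentation drives TextIteratorStreamer from a background threading.Thread running model.generate; whether the torch.jit.fork call in request_stream behaves equivalently is untested in this draft.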
src/know_lang_bot/summarizer/summarizer.py
CHANGED
@@ -10,6 +10,7 @@ from rich.progress import Progress
 from know_lang_bot.config import AppConfig
 from know_lang_bot.core.types import CodeChunk, ModelProvider
 from know_lang_bot.utils.fancy_log import FancyLogger
+from know_lang_bot.utils.model_provider import create_pydantic_model
 
 LOG = FancyLogger(__name__)
 
@@ -50,7 +51,10 @@ Provide a clean, concise and focused summary. Don't include unnecessary nor gene
 """
 
         self.agent = Agent(
-
+            create_pydantic_model(
+                model_provider=self.config.llm.model_provider,
+                model_name=self.config.llm.model_name
+            ),
             system_prompt=system_prompt,
             model_settings=self.config.llm.model_settings
         )
src/know_lang_bot/utils/model_provider.py
ADDED
@@ -0,0 +1,17 @@
+from pydantic_ai.models import Model, KnownModelName
+from know_lang_bot.core.types import ModelProvider
+from know_lang_bot.models.huggingface import HuggingFaceModel
+from typing import get_args
+
+def create_pydantic_model(
+    model_provider: ModelProvider,
+    model_name: str,
+) -> Model | KnownModelName:
+    model_str = f"{model_provider}:{model_name}"
+
+    if model_str in get_args(KnownModelName):
+        return model_str
+    elif model_provider == ModelProvider.HUGGINGFACE:
+        return HuggingFaceModel(model_name=model_name)
+    else:
+        raise NotImplementedError(f"Model {model_provider}:{model_name} is not supported")
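A rough sketch of how this dispatch is expected to behave. Whether a given "provider:name" string appears in pydantic-ai's KnownModelName depends on the installed version, and the OPENAI member below is hypothetical; only HUGGINGFACE is confirmed by the code above:

    from know_lang_bot.core.types import ModelProvider
    from know_lang_bot.utils.model_provider import create_pydantic_model

    # If "<provider>:<name>" matches a KnownModelName literal, the plain string is
    # returned and pydantic-ai constructs the remote provider itself.
    remote = create_pydantic_model(model_provider=ModelProvider.OPENAI, model_name="gpt-4o")  # hypothetical enum member

    # A HuggingFace provider falls through to the local wrapper, which loads the
    # checkpoint with transformers on construction.
    local = create_pydantic_model(model_provider=ModelProvider.HUGGINGFACE, model_name="microsoft/phi-2")  # example id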