gabykim committed
Commit 183e719
Parent: f1afcb2

huggingface model support draft
src/know_lang_bot/chat_bot/__main__.py CHANGED
@@ -19,7 +19,6 @@ async def test_chat_processing():
     )
 
     print(f"Answer: {result.answer}")
-    print(f"References: {result.references_md}")
 
 if __name__ == "__main__":
     asyncio.run(test_chat_processing())
src/know_lang_bot/chat_bot/chat_graph.py CHANGED
@@ -12,8 +12,11 @@ from pydantic_ai import Agent
 import logfire
 from pprint import pformat
 from enum import Enum
+from rich.console import Console
+from know_lang_bot.utils.model_provider import create_pydantic_model
 
 LOG = FancyLogger(__name__)
+console = Console()
 
 class ChatStatus(str, Enum):
     """Enum for tracking chat progress status"""
@@ -128,7 +131,10 @@ Example Output: "Where is the configuration file or configuration settings store
     async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> RetrieveContextNode:
         # Create an agent for question polishing
         polish_agent = Agent(
-            f"{ctx.deps.config.llm.model_provider}:{ctx.deps.config.llm.model_name}",
+            create_pydantic_model(
+                model_provider=ctx.deps.config.llm.model_provider,
+                model_name=ctx.deps.config.llm.model_name
+            ),
             system_prompt=self.system_prompt
         )
         prompt = f"""Original question: "{ctx.state.original_question}"
@@ -203,7 +209,10 @@ Remember: Your primary goal is answering the user's specific question, not expla
 
     async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> End[ChatResult]:
         answer_agent = Agent(
-            f"{ctx.deps.config.llm.model_provider}:{ctx.deps.config.llm.model_name}",
+            create_pydantic_model(
+                model_provider=ctx.deps.config.llm.model_provider,
+                model_name=ctx.deps.config.llm.model_name
+            ),
             system_prompt=self.system_prompt
         )
 
@@ -267,6 +276,8 @@ async def process_chat(
         )
     except Exception as e:
         LOG.error(f"Error processing chat in graph: {e}")
+        console.print_exception()
+
         result = ChatResult(
             answer="I encountered an error processing your question. Please try again."
         )
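
Both hunks make the same change: Agent is no longer given a raw f"{provider}:{model_name}" string, which pydantic-ai can only resolve for its built-in providers; construction is routed through create_pydantic_model (added below), which returns either such a string or a custom Model instance. A minimal sketch of the two forms Agent accepts (the checkpoint names and prompts here are illustrative, not part of the commit):

    from pydantic_ai import Agent
    from know_lang_bot.models.huggingface import HuggingFaceModel

    # Built-in providers resolve from a plain "provider:model" string
    agent = Agent("openai:gpt-4o", system_prompt="...")

    # Anything else needs an explicit Model instance, e.g. the new local wrapper
    agent = Agent(
        HuggingFaceModel("Qwen/Qwen2.5-0.5B-Instruct"),
        system_prompt="...",
    )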
src/know_lang_bot/models/huggingface.py ADDED
@@ -0,0 +1,193 @@
+from __future__ import annotations
+
+import threading
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from typing import AsyncIterator, Optional
+from contextlib import asynccontextmanager
+
+from pydantic_ai.models import (
+    AgentModel,
+    ModelMessage,
+    ModelResponse,
+    StreamedResponse,
+    Usage,
+    Model,
+    check_allow_model_requests
+)
+from pydantic_ai.tools import ToolDefinition
+from pydantic_ai.messages import (
+    ModelResponsePart, TextPart, SystemPromptPart,
+    UserPromptPart, ToolReturnPart, ModelResponseStreamEvent
+)
+from pydantic_ai.settings import ModelSettings
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+import torch
+
+
+@dataclass(init=False)
+class HuggingFaceModel(Model):
+    """A model that runs HuggingFace models locally.
+
+    For this MVP draft it supports plain text generation and token streaming;
+    tool calling and other advanced features are not implemented.
+    """
+
+    model_name: str
+    model: AutoModelForCausalLM = field(repr=False)
+    tokenizer: AutoTokenizer = field(repr=False)
+    device: str = field(default="cuda" if torch.cuda.is_available() else "cpu")
+    max_new_tokens: int = 512
+
+    def __init__(
+        self,
+        model_name: str,
+        *,
+        device: Optional[str] = None,
+        max_new_tokens: int = 512,
+    ):
+        """Initialize a HuggingFace model.
+
+        Args:
+            model_name: Name of the model on HuggingFace Hub
+            device: Device to run the model on ('cuda' or 'cpu')
+            max_new_tokens: Maximum number of tokens to generate
+        """
+        self.model_name = model_name
+        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.max_new_tokens = max_new_tokens
+
+        # Load model and tokenizer; fall back to the EOS token for padding
+        # since many causal LMs ship without a pad token.
+        self.model = AutoModelForCausalLM.from_pretrained(model_name)
+        self.model.to(self.device)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        if not self.tokenizer.pad_token:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+    async def agent_model(
+        self,
+        *,
+        function_tools: list[ToolDefinition],
+        allow_text_result: bool,
+        result_tools: list[ToolDefinition],
+    ) -> AgentModel:
+        """Create an agent model for each step of an agent run."""
+        check_allow_model_requests()
+
+        return HuggingFaceAgentModel(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            model_name=self.model_name,
+            device=self.device,
+            max_new_tokens=self.max_new_tokens,
+            tools=function_tools if function_tools else None,
+        )
+
+    def name(self) -> str:
+        return f"huggingface:{self.model_name}"
+
+
+@dataclass
+class HuggingFaceAgentModel(AgentModel):
+    """Implementation of AgentModel for HuggingFace models."""
+
+    model: AutoModelForCausalLM
+    tokenizer: AutoTokenizer
+    model_name: str
+    device: str
+    max_new_tokens: int
+    tools: Optional[list[ToolDefinition]] = None
+
+    def _format_messages(self, messages: list[ModelMessage]) -> str:
+        """Format messages into a prompt the model can understand."""
+        formatted = []
+        for message in messages:
+            for part in message.parts:
+                if isinstance(part, SystemPromptPart):
+                    formatted.append(f"<|system|>{part.content}")
+                elif isinstance(part, UserPromptPart):
+                    formatted.append(f"<|user|>{part.content}")
+                else:
+                    # For MVP, pass other message types through verbatim
+                    formatted.append(str(part.content))
+        formatted.append("<|assistant|>")
+        return "\n".join(formatted)
+
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, Usage]:
+        """Make a request to the model."""
+        prompt = self._format_messages(messages)
+        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+        prompt_tokens = inputs.input_ids.shape[1]
+
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=self.max_new_tokens,
+                pad_token_id=self.tokenizer.pad_token_id,
+            )
+
+        # Decode only the newly generated tokens; slicing the decoded string
+        # by len(prompt) can misalign once special tokens are stripped.
+        response_text = self.tokenizer.decode(
+            outputs[0][prompt_tokens:], skip_special_tokens=True
+        )
+
+        timestamp = datetime.now(timezone.utc)
+        response = ModelResponse(
+            parts=[TextPart(response_text)],
+            model_name=self.model_name,
+            timestamp=timestamp,
+        )
+
+        usage_stats = Usage(
+            requests=1,
+            request_tokens=prompt_tokens,
+            response_tokens=len(outputs[0]) - prompt_tokens,
+            total_tokens=len(outputs[0]),
+        )
+
+        return response, usage_stats
+
+    @asynccontextmanager
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[StreamedResponse]:
+        """Make a streaming request to the model."""
+        prompt = self._format_messages(messages)
+        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+
+        # skip_prompt=True so the streamer yields only newly generated text
+        streamer = TextIteratorStreamer(
+            self.tokenizer, skip_prompt=True, skip_special_tokens=True
+        )
+
+        generation_kwargs = dict(
+            input_ids=inputs.input_ids,
+            max_new_tokens=self.max_new_tokens,
+            pad_token_id=self.tokenizer.pad_token_id,
+            streamer=streamer,
+        )
+
+        # torch.jit.fork only runs asynchronously under TorchScript; in plain
+        # Python it blocks until generation finishes. Run generate() on a
+        # worker thread instead -- the standard TextIteratorStreamer pattern.
+        thread = threading.Thread(target=self.model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        try:
+            yield HuggingFaceStreamedResponse(
+                _model_name=self.model_name,
+                _streamer=streamer,
+                _timestamp=datetime.now(timezone.utc),
+            )
+        finally:
+            thread.join()
+
+
+@dataclass
+class HuggingFaceStreamedResponse(StreamedResponse):
+    """Implementation of StreamedResponse for HuggingFace models."""
+
+    _streamer: TextIteratorStreamer
+    _timestamp: datetime
+
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        """Stream text chunks from the model as they are generated."""
+        for new_text in self._streamer:
+            # Character counts stand in for token counts here -- a rough
+            # proxy until per-chunk tokenization is wired in.
+            self._usage.response_tokens = (self._usage.response_tokens or 0) + len(new_text)
+            self._usage.total_tokens = (self._usage.request_tokens or 0) + self._usage.response_tokens
+
+            yield self._parts_manager.handle_text_delta(
+                vendor_part_id='content',
+                content=new_text,
+            )
+
+    def timestamp(self) -> datetime:
+        return self._timestamp
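
For a quick smoke test of the new class outside the chat graph, a minimal sketch like this should work with the pydantic-ai version this draft targets, where run() returns a result whose .data holds the text (the checkpoint name and prompt are illustrative, and constructing HuggingFaceModel downloads and loads the weights):

    import asyncio

    from pydantic_ai import Agent
    from know_lang_bot.models.huggingface import HuggingFaceModel

    async def main():
        # Any small causal-LM checkpoint will do for a smoke test
        model = HuggingFaceModel("Qwen/Qwen2.5-0.5B-Instruct", max_new_tokens=128)
        agent = Agent(model, system_prompt="You are a concise assistant.")

        result = await agent.run("In one sentence, what is a tokenizer?")
        print(result.data)

    asyncio.run(main())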
src/know_lang_bot/summarizer/summarizer.py CHANGED
@@ -10,6 +10,7 @@ from rich.progress import Progress
 from know_lang_bot.config import AppConfig
 from know_lang_bot.core.types import CodeChunk, ModelProvider
 from know_lang_bot.utils.fancy_log import FancyLogger
+from know_lang_bot.utils.model_provider import create_pydantic_model
 
 LOG = FancyLogger(__name__)
 
@@ -50,7 +51,10 @@ Provide a clean, concise and focused summary. Don't include unnecessary nor gene
 """
 
         self.agent = Agent(
-            f"{self.config.llm.model_provider}:{self.config.llm.model_name}",
+            create_pydantic_model(
+                model_provider=self.config.llm.model_provider,
+                model_name=self.config.llm.model_name
+            ),
             system_prompt=system_prompt,
             model_settings=self.config.llm.model_settings
         )
src/know_lang_bot/utils/model_provider.py ADDED
@@ -0,0 +1,17 @@
+from pydantic_ai.models import Model, KnownModelName
+from know_lang_bot.core.types import ModelProvider
+from know_lang_bot.models.huggingface import HuggingFaceModel
+from typing import get_args
+
+def create_pydantic_model(
+    model_provider: ModelProvider,
+    model_name: str,
+) -> Model | KnownModelName:
+    """Map a (provider, name) pair onto something pydantic-ai's Agent accepts."""
+    model_str = f"{model_provider}:{model_name}"
+
+    if model_str in get_args(KnownModelName):
+        # Providers pydantic-ai supports natively pass through as strings
+        return model_str
+    elif model_provider == ModelProvider.HUGGINGFACE:
+        # Local HuggingFace checkpoints get the custom Model implementation
+        return HuggingFaceModel(model_name=model_name)
+    else:
+        raise NotImplementedError(f"Model {model_provider}:{model_name} is not supported")
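
A quick illustration of the dispatch (assuming ModelProvider also defines an OPENAI member and is a str-valued enum so it formats as "openai"; checkpoint names are examples, and note the HuggingFace branch loads model weights on construction):

    from know_lang_bot.core.types import ModelProvider
    from know_lang_bot.utils.model_provider import create_pydantic_model

    # Natively supported provider: passed through as a "provider:model" string
    spec = create_pydantic_model(model_provider=ModelProvider.OPENAI, model_name="gpt-4o")
    print(spec)  # -> "openai:gpt-4o"

    # HuggingFace provider: returns the custom local Model implementation
    model = create_pydantic_model(
        model_provider=ModelProvider.HUGGINGFACE,
        model_name="Qwen/Qwen2.5-0.5B-Instruct",
    )
    print(model.name())  # -> "huggingface:Qwen/Qwen2.5-0.5B-Instruct"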