from __future__ import annotations

from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from threading import Thread
from typing import AsyncIterator, Optional

import torch
from pydantic_ai.messages import (
    ModelResponseStreamEvent,
    SystemPromptPart,
    TextPart,
    UserPromptPart,
)
from pydantic_ai.models import (
    AgentModel,
    Model,
    ModelMessage,
    ModelResponse,
    StreamedResponse,
    Usage,
    check_allow_model_requests,
)
from pydantic_ai.settings import ModelSettings
from pydantic_ai.tools import ToolDefinition
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


@dataclass(init=False)
class HuggingFaceModel(Model):
    """A model that runs HuggingFace models locally.

    For the MVP this implements basic text generation without tool calling
    or other advanced features.
    """

    model_name: str
    model: AutoModelForCausalLM = field(repr=False)
    tokenizer: AutoTokenizer = field(repr=False)
    device: str = field(default="cuda" if torch.cuda.is_available() else "cpu")
    max_new_tokens: int = 512

    def __init__(
        self,
        model_name: str,
        *,
        device: Optional[str] = None,
        max_new_tokens: int = 512,
    ):
        """Initialize a HuggingFace model.

        Args:
            model_name: Name of the model on HuggingFace Hub
            device: Device to run the model on ('cuda' or 'cpu')
            max_new_tokens: Maximum number of tokens to generate
        """
        self.model_name = model_name
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.max_new_tokens = max_new_tokens

        # Load the model weights and move them to the target device
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.model.to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Some causal LMs (e.g. GPT-2 style checkpoints) ship without a pad
        # token; fall back to the EOS token so generation does not fail.
        if not self.tokenizer.pad_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        """Create an agent model for each step of an agent run."""
        check_allow_model_requests()
        return HuggingFaceAgentModel(
            model=self.model,
            tokenizer=self.tokenizer,
            model_name=self.model_name,
            device=self.device,
            max_new_tokens=self.max_new_tokens,
            tools=function_tools if function_tools else None,
        )

    def name(self) -> str:
        return f"huggingface:{self.model_name}"


@dataclass
class HuggingFaceAgentModel(AgentModel):
    """Implementation of AgentModel for HuggingFace models."""

    model: AutoModelForCausalLM
    tokenizer: AutoTokenizer
    model_name: str
    device: str
    max_new_tokens: int
    tools: Optional[list[ToolDefinition]] = None

    def _format_messages(self, messages: list[ModelMessage]) -> str:
        """Format messages into a prompt the model can understand."""
        formatted = []
        for message in messages:
            for part in message.parts:
                if isinstance(part, SystemPromptPart):
                    formatted.append(f"<|system|>{part.content}")
                elif isinstance(part, UserPromptPart):
                    formatted.append(f"<|user|>{part.content}")
                else:
                    # For the MVP, pass other message types through verbatim
                    formatted.append(str(part.content))
        # A trailing assistant tag cues the model to produce its reply
        formatted.append("<|assistant|>")
        return "\n".join(formatted)

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, Usage]:
        """Make a request to the model."""
        prompt = self._format_messages(messages)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        # Decode only the newly generated tokens; slicing the decoded string
        # by len(prompt) is unreliable because skip_special_tokens=True can
        # change the length of the prompt portion.
        prompt_length = inputs.input_ids.shape[1]
        response_text = self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )

        timestamp = datetime.now(timezone.utc)
        response = ModelResponse(
            parts=[TextPart(response_text)],
            model_name=self.model_name,
            timestamp=timestamp,
        )
        usage_stats = Usage(
            requests=1,
            request_tokens=prompt_length,
            response_tokens=len(outputs[0]) - prompt_length,
            total_tokens=len(outputs[0]),
        )
        return response, usage_stats
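
    # Hypothetical direct call, for illustration only (assumes pydantic_ai's
    # ModelRequest container for user messages):
    #
    #   request = ModelRequest(parts=[UserPromptPart(content="Hello")])
    #   response, usage = await agent_model.request([request], None)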

    @asynccontextmanager
    async def request_stream(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streaming request to the model."""
        prompt = self._format_messages(messages)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        # skip_prompt=True keeps the echoed prompt out of the streamed text
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=self.max_new_tokens,
            pad_token_id=self.tokenizer.pad_token_id,
            streamer=streamer,
        )

        # Run generation in a background thread, the pattern documented for
        # TextIteratorStreamer, so tokens can be consumed while generate()
        # runs; torch.jit.fork only executes asynchronously inside TorchScript.
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        try:
            yield HuggingFaceStreamedResponse(
                _model_name=self.model_name,
                _streamer=streamer,
                _timestamp=datetime.now(timezone.utc),
            )
        finally:
            thread.join()
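
    # Hypothetical streaming usage via pydantic_ai (assumes an Agent API of
    # the same vintage as this module):
    #
    #   async with agent.run_stream("Hello") as stream:
    #       async for chunk in stream.stream_text():
    #           print(chunk, end="")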


@dataclass
class HuggingFaceStreamedResponse(StreamedResponse):
    """Implementation of StreamedResponse for HuggingFace models."""

    _streamer: TextIteratorStreamer
    _timestamp: datetime

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Stream text chunks from the model as they are generated."""
        for new_text in self._streamer:
            # Rough usage accounting: character count stands in for token
            # count here, since decoded chunks don't map 1:1 onto tokens.
            self._usage.response_tokens += len(new_text)
            self._usage.total_tokens = self._usage.request_tokens + self._usage.response_tokens
            yield self._parts_manager.handle_text_delta(
                vendor_part_id='content',
                content=new_text,
            )

    def timestamp(self) -> datetime:
        return self._timestamp
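

if __name__ == "__main__":
    # Minimal usage sketch: wire the model into a pydantic_ai Agent.
    # "gpt2" is just a small placeholder checkpoint, and this assumes a
    # pydantic_ai version whose Agent accepts a Model instance directly
    # and exposes the final output via `result.data`.
    from pydantic_ai import Agent

    model = HuggingFaceModel("gpt2", max_new_tokens=64)
    agent = Agent(model)
    result = agent.run_sync("Explain what a tokenizer does in one sentence.")
    print(result.data)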