from __future__ import annotations

import asyncio
from dataclasses import dataclass, field
from datetime import datetime, timezone
from threading import Thread
from typing import AsyncIterator, Optional
from contextlib import asynccontextmanager

from pydantic_ai.models import (
    AgentModel,
    ModelMessage,
    ModelResponse,
    StreamedResponse,
    Usage,
    Model,
    check_allow_model_requests,
)
from pydantic_ai.tools import ToolDefinition
from pydantic_ai.messages import (
    ModelResponsePart, TextPart, SystemPromptPart,
    UserPromptPart, ToolReturnPart, ModelResponseStreamEvent,
)
from pydantic_ai.settings import ModelSettings
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch


class HuggingFaceModel(Model):
    """A model that runs HuggingFace models locally.

    For the MVP, this implements basic text generation without tool calling
    or other advanced features.
    """

    model_name: str
    model: AutoModelForCausalLM
    tokenizer: AutoTokenizer
    device: str
    max_new_tokens: int

    def __init__(
        self,
        model_name: str,
        *,
        device: Optional[str] = None,
        max_new_tokens: int = 512,
    ):
        """Initialize a HuggingFace model.

        Args:
            model_name: Name of the model on the HuggingFace Hub.
            device: Device to run the model on ('cuda' or 'cpu').
            max_new_tokens: Maximum number of tokens to generate.
        """
        self.model_name = model_name
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.max_new_tokens = max_new_tokens

        # Load model and tokenizer (downloads weights on first use)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.model.to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Some models have no pad token; fall back to the EOS token
        if not self.tokenizer.pad_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        """Create an agent model for each step of an agent run."""
        check_allow_model_requests()
        return HuggingFaceAgentModel(
            model=self.model,
            tokenizer=self.tokenizer,
            model_name=self.model_name,
            device=self.device,
            max_new_tokens=self.max_new_tokens,
            tools=function_tools if function_tools else None,
        )

    def name(self) -> str:
        return f"huggingface:{self.model_name}"


@dataclass
class HuggingFaceAgentModel(AgentModel):
    """Implementation of AgentModel for HuggingFace models."""

    model: AutoModelForCausalLM
    tokenizer: AutoTokenizer
    model_name: str
    device: str
    max_new_tokens: int
    tools: Optional[list[ToolDefinition]] = None

    def _format_messages(self, messages: list[ModelMessage]) -> str:
        """Format messages into a prompt the model can understand."""
        formatted = []
        for message in messages:
            for part in message.parts:
                if isinstance(part, SystemPromptPart):
                    formatted.append(f"<|system|>{part.content}")
                elif isinstance(part, UserPromptPart):
                    formatted.append(f"<|user|>{part.content}")
                else:
                    # For the MVP, just pass other message types through as plain text
                    formatted.append(str(part.content))
        formatted.append("<|assistant|>")
        return "\n".join(formatted)
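
    # For a system prompt plus one user message, the prompt above renders
    # roughly as (illustrative; this simple markup is a generic convention,
    # not the tokenizer's own chat template):
    #
    #     <|system|>You are a concise assistant.
    #     <|user|>What is the capital of France?
    #     <|assistant|>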

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, Usage]:
        """Make a request to the model."""
        prompt = self._format_messages(messages)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        # Decode only the newly generated tokens so the prompt is not echoed back;
        # slicing the decoded string by len(prompt) is unreliable once special
        # tokens are stripped.
        prompt_tokens = inputs.input_ids.shape[1]
        response_text = self.tokenizer.decode(
            outputs[0][prompt_tokens:], skip_special_tokens=True
        )

        timestamp = datetime.now(timezone.utc)
        response = ModelResponse(
            parts=[TextPart(response_text)],
            model_name=self.model_name,
            timestamp=timestamp,
        )
        usage_stats = Usage(
            requests=1,
            request_tokens=prompt_tokens,
            response_tokens=len(outputs[0]) - prompt_tokens,
            total_tokens=len(outputs[0]),
        )
        return response, usage_stats

    @asynccontextmanager
    async def request_stream(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streaming request to the model."""
        prompt = self._format_messages(messages)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        # skip_prompt=True keeps the echoed prompt out of the streamed text
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(
            input_ids=inputs.input_ids,
            max_new_tokens=self.max_new_tokens,
            pad_token_id=self.tokenizer.pad_token_id,
            streamer=streamer,
        )

        # Run generation in a background thread (the standard TextIteratorStreamer
        # pattern) so tokens can be read from the streamer while they are produced.
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        try:
            yield HuggingFaceStreamedResponse(
                _model_name=self.model_name,
                _streamer=streamer,
                _timestamp=datetime.now(timezone.utc),
            )
        finally:
            thread.join()
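
    # With the @asynccontextmanager decorator, the framework consumes this
    # roughly as follows (a sketch of the expected contract, not code from
    # this module):
    #
    #     async with agent_model.request_stream(messages, None) as stream:
    #         async for _event in stream:
    #             ...  # partial text arrives as it is generated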


@dataclass
class HuggingFaceStreamedResponse(StreamedResponse):
    """Implementation of StreamedResponse for HuggingFace models."""

    _model_name: str
    _streamer: TextIteratorStreamer
    _timestamp: datetime

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Stream tokens from the model."""
        for new_text in self._streamer:
            # Character counts are only a rough proxy for token usage in this MVP
            self._usage.response_tokens += len(new_text)
            self._usage.total_tokens = self._usage.request_tokens + self._usage.response_tokens
            yield self._parts_manager.handle_text_delta(
                vendor_part_id='content',
                content=new_text,
            )

    def timestamp(self) -> datetime:
        return self._timestamp
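

# Example usage (a minimal sketch: the model id below is only illustrative,
# and the Agent calls assume pydantic_ai's standard public API):
#
#     from pydantic_ai import Agent
#
#     model = HuggingFaceModel("Qwen/Qwen2.5-0.5B-Instruct", max_new_tokens=128)
#     agent = Agent(model, system_prompt="You are a concise assistant.")
#     result = agent.run_sync("What is the capital of France?")
#     print(result.data)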