from __future__ import annotations

import asyncio
from threading import Thread
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import AsyncIterator, Optional
from contextlib import asynccontextmanager

from pydantic_ai.models import (
    AgentModel, 
    ModelMessage, 
    ModelResponse, 
    StreamedResponse, 
    Usage, 
    Model,
    check_allow_model_requests
)
from pydantic_ai.tools import ToolDefinition
from pydantic_ai.messages import (
    ModelResponsePart, TextPart, SystemPromptPart,
    UserPromptPart, ToolReturnPart, ModelResponseStreamEvent
)
from pydantic_ai.settings import ModelSettings
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch


@dataclass(init=False)
class HuggingFaceModel(Model):
    """A model that uses HuggingFace models locally.
    
    For MVP, this implements basic text generation and token streaming; tool calling
    and other advanced features are not supported.
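
    A minimal usage sketch (the model name and the Agent wiring below are
    illustrative assumptions, not something this module pins down):

        from pydantic_ai import Agent

        model = HuggingFaceModel("Qwen/Qwen2.5-0.5B-Instruct", device="cpu")
        agent = Agent(model)
        result = agent.run_sync("What is the capital of France?")
        print(result.data)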
    """
    
    model_name: str
    model: AutoModelForCausalLM = field(repr=False)
    tokenizer: AutoTokenizer = field(repr=False)
    device: str = field(default="cuda" if torch.cuda.is_available() else "cpu")
    max_new_tokens: int = 512
    
    def __init__(
        self,
        model_name: str,
        *,
        device: Optional[str] = None,
        max_new_tokens: int = 512,
    ):
        """Initialize a HuggingFace model.
        
        Args:
            model_name: Name of the model on HuggingFace Hub
            device: Device to run model on ('cuda' or 'cpu')
            max_new_tokens: Maximum number of tokens to generate
        """
        self.model_name = model_name
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.max_new_tokens = max_new_tokens
        
        # Load model and tokenizer
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.model.to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if not self.tokenizer.pad_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        """Create an agent model for each step of an agent run."""
        check_allow_model_requests()
        
        return HuggingFaceAgentModel(
            model=self.model,
            tokenizer=self.tokenizer,
            model_name=self.model_name,
            device=self.device,
            max_new_tokens=self.max_new_tokens,
            tools=function_tools if function_tools else None
        )

    def name(self) -> str:
        return f"huggingface:{self.model_name}"

@dataclass
class HuggingFaceAgentModel(AgentModel):
    """Implementation of AgentModel for HuggingFace models."""
    
    model: AutoModelForCausalLM
    tokenizer: AutoTokenizer
    model_name: str
    device: str
    max_new_tokens: int
    tools: Optional[list[ToolDefinition]] = None
    
    def _format_messages(self, messages: list[ModelMessage]) -> str:
        """Format messages into a prompt the model can understand."""
        formatted = []
        for message in messages:
            for part in message.parts:
                if isinstance(part, SystemPromptPart):
                    formatted.append(f"<|system|>{part.content}")
                elif isinstance(part, UserPromptPart):
                    formatted.append(f"<|user|>{part.content}")
                else:
                    # For MVP, we'll just pass through other message types
                    formatted.append(str(part.content))
        formatted.append("<|assistant|>")
        return "\n".join(formatted)

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, Usage]:
        """Make a request to the model."""
        prompt = self._format_messages(messages)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        
        # Decode only the newly generated tokens so the prompt is not echoed back.
        prompt_length = inputs.input_ids.shape[1]
        response_text = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        
        timestamp = datetime.now(timezone.utc)
        response = ModelResponse(
            parts=[TextPart(response_text)],
            model_name=self.model_name,
            timestamp=timestamp
        )
        
        usage_stats = Usage(
            requests=1,
            request_tokens=prompt_length,
            response_tokens=len(outputs[0]) - prompt_length,
            total_tokens=len(outputs[0])
        )
        
        return response, usage_stats

    @asynccontextmanager
    async def request_stream(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streaming request to the model."""
        prompt = self._format_messages(messages)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        
        # skip_prompt=True so the echoed prompt is not streamed back to the caller.
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
        
        generation_kwargs = dict(
            input_ids=inputs.input_ids,
            max_new_tokens=self.max_new_tokens,
            pad_token_id=self.tokenizer.pad_token_id,
            streamer=streamer,
        )
        
        # Run generation in a background thread so tokens can be consumed from
        # the streamer as they are produced (the pattern documented for
        # TextIteratorStreamer).
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        
        try:
            yield HuggingFaceStreamedResponse(
                _model_name=self.model_name,
                _streamer=streamer,
                _timestamp=datetime.now(timezone.utc),
            )
        finally:
            thread.join()

@dataclass 
class HuggingFaceStreamedResponse(StreamedResponse):
    """Implementation of StreamedResponse for HuggingFace models."""
    
    _model_name: str
    _streamer: TextIteratorStreamer
    _timestamp: datetime

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Stream tokens from the model."""
        for new_text in self._streamer:
            # Note: this counts characters of each streamed chunk rather than
            # true token counts, so the streaming usage figures are approximate.
            self._usage.response_tokens = (self._usage.response_tokens or 0) + len(new_text)
            self._usage.total_tokens = (self._usage.request_tokens or 0) + self._usage.response_tokens
            
            yield self._parts_manager.handle_text_delta(
                vendor_part_id='content',
                content=new_text
            )

    def timestamp(self) -> datetime:
        return self._timestamp
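

if __name__ == "__main__":
    # Minimal smoke-test sketch. The model name below is an assumption (any
    # small causal LM from the Hub should work); it exercises the
    # non-streaming request path end to end.
    from pydantic_ai.messages import ModelRequest

    async def _demo() -> None:
        hf_model = HuggingFaceModel("sshleifer/tiny-gpt2", max_new_tokens=32)
        agent_model = await hf_model.agent_model(
            function_tools=[], allow_text_result=True, result_tools=[]
        )
        response, usage = await agent_model.request(
            [ModelRequest(parts=[UserPromptPart(content="Hello, who are you?")])],
            model_settings=None,
        )
        print(response.parts[0].content)
        print(usage)

    asyncio.run(_demo())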