from typing import Dict, Any
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from fastapi.responses import StreamingResponse
import uuid
import time
import json
from threading import Thread


class EndpointHandler:
    def __init__(self, path: str = "openai/gpt-oss-20b"):
        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.model.eval()

        # Determine the computation device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def openai_id(self, prefix: str) -> str:
        # Build an OpenAI-style identifier, e.g. "chatcmpl-<24 hex chars>"
        return f"{prefix}-{uuid.uuid4().hex[:24]}"

    def format_non_stream(
        self,
        model: str,
        text: str,
        prompt_length: int,
        completion_length: int,
        total_tokens: int,
    ):
        # Create an OpenAI-compatible chat.completion payload
        return {
            "id": self.openai_id("chatcmpl"),
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_length,
                "completion_tokens": completion_length,
                "total_tokens": total_tokens
            }
        }

    def format_stream(self, model: str, token: str, usage, finish_reason=None) -> bytes:
        # Create a single OpenAI-compatible chat.completion.chunk SSE event
        payload = {
            "id": self.openai_id("chatcmpl"),
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "delta": {
                    "content": token,
                    "function_call": None,
                    "refusal": None,
                    "role": None,
                    "tool_calls": None
                },
                "finish_reason": finish_reason,
                "logprobs": None
            }],
            "usage": usage
        }
        return f"data: {json.dumps(payload)}\n\n".encode("utf-8")

    def generate(self, messages, model: str):
        # Apply the chat template to the OpenAI-style message list and tokenize
        model_inputs = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to(self.device)

        full_output = self.model.generate(**model_inputs, max_new_tokens=2048)

        # Strip the prompt tokens so only the completion remains
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, full_output)
        ]
        # Note: special tokens are kept here, unlike in the streaming path
        text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0]

        input_length = model_inputs.input_ids.shape[1]  # Prompt tokens
        output_length = full_output.shape[1]            # Prompt + completion tokens
        completion_tokens = output_length - input_length

        return self.format_non_stream(model, text, input_length, completion_tokens, output_length)

    def stream(self, messages, model):
        # Apply the chat template to the OpenAI-style message list and tokenize
        model_inputs = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to(self.device)
        input_len = model_inputs.input_ids.shape[1]

        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(
            **model_inputs, streamer=streamer, max_new_tokens=2048
        )

        # Run generation in a background thread so tokens can be yielded as they arrive
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        completion_tokens = 0
        for token in streamer:
            # Count tokens in each decoded chunk for the usage report
            token_ids = self.tokenizer.encode(token, add_special_tokens=False)
            completion_tokens += len(token_ids)
            yield self.format_stream(model, token, None)

        # Final chunk with stop reason and token counts
        yield self.format_stream(model, "", {
            "prompt_tokens": input_len,
            "completion_tokens": completion_tokens,
            "total_tokens": input_len + completion_tokens
        }, finish_reason="stop")

        # OpenAI-style SSE terminator
        yield b"data: [DONE]\n\n"

    def __call__(self, data: Dict[str, Any]):
        messages = data.get("messages")
        model = data.get("model")
        stream = data.get("stream", False)

        if not stream:
            return self.generate(messages, model)
        return StreamingResponse(
            self.stream(messages, model),
            media_type="text/event-stream"
        )
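

# ---------------------------------------------------------------------------
# Local smoke test: an illustrative sketch only, not part of the handler
# contract. The model path and example messages below are assumptions; in a
# hosted deployment the runtime instantiates EndpointHandler and calls it
# with the request payload directly.
if __name__ == "__main__":
    handler = EndpointHandler("openai/gpt-oss-20b")

    # Non-streaming request: returns an OpenAI-style chat.completion dict.
    response = handler({
        "model": "gpt-oss-20b",
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "stream": False,
    })
    print(json.dumps(response, indent=2))

    # Streaming request: __call__ would wrap this generator in a
    # StreamingResponse; outside FastAPI we can iterate it directly.
    for chunk in handler.stream(
        [{"role": "user", "content": "Count to three."}], "gpt-oss-20b"
    ):
        print(chunk.decode("utf-8"), end="")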