from sentence_transformers import CrossEncoder as _CE

import math
import types
from typing import cast

import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto.configuration_auto import AutoConfig
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.models.gemma3.modeling_gemma3 import (
    Gemma3ForCausalLM,
    Gemma3ForConditionalGeneration,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
from transformers.tokenization_utils_base import BatchEncoding
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

# pyright: reportUnknownMemberType=false
# pyright: reportUnknownVariableType=false

MODEL_PATH = "zeroentropy/ze-rerank-large-v0.3.0"
PER_DEVICE_BATCH_SIZE_TOKENS = 15_000


def format_pointwise_datapoints(
    tokenizer: PreTrainedTokenizerFast,
    query_documents: list[tuple[str, str]],
) -> BatchEncoding:
    # Format each (query, document) pair as a chat prompt: the query becomes the
    # system message and the document becomes the user message.
    input_texts: list[str] = []
    for query, document in query_documents:
        system_prompt = f"""
{query}
""".strip()
        user_message = f"""
{document}
""".strip()
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]
        input_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        assert isinstance(input_text, str)
        input_texts.append(input_text)

    batch_inputs = tokenizer(
        input_texts,
        padding=True,
        return_tensors="pt",
    )
    return batch_inputs


def load_model(
    device: torch.device | None = None,
) -> tuple[
    PreTrainedTokenizerFast,
    LlamaForCausalLM
    | Gemma3ForConditionalGeneration
    | Gemma3ForCausalLM
    | Qwen3ForCausalLM,
]:
    if device is None:
        device = torch.device("cpu")

    config = AutoConfig.from_pretrained(MODEL_PATH)
    assert isinstance(config, PretrainedConfig)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype="auto",
        quantization_config=None,
        device_map={"": device},
    )
    if config.model_type == "llama":
        model.config.attn_implementation = "flash_attention_2"
    print(f"Model Type: {config.model_type}")
    assert isinstance(
        model,
        LlamaForCausalLM
        | Gemma3ForConditionalGeneration
        | Gemma3ForCausalLM
        | Qwen3ForCausalLM,
    )

    tokenizer = cast(
        AutoTokenizer,
        AutoTokenizer.from_pretrained(
            MODEL_PATH,
            padding_side="right",
        ),
    )
    assert isinstance(tokenizer, PreTrainedTokenizerFast)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return tokenizer, model


def predict(self, query_documents: list[tuple[str, str]]) -> list[float]:
    # Lazily load the model and tokenizer the first time predict() is called.
    if not hasattr(self, "inner_model"):
        self.inner_tokenizer, self.inner_model = load_model(torch.device("cuda"))
        self.inner_model.gradient_checkpointing_enable()
        self.inner_model.eval()
        self.inner_yes_token_id = self.inner_tokenizer.encode(
            "Yes", add_special_tokens=False
        )[0]
        print("patched")
    model = self.inner_model
    tokenizer = self.inner_tokenizer

    # Truncate overly long queries and documents.
    query_documents = [
        (query[:2_000], document[:10_000]) for query, document in query_documents
    ]

    # Sort by descending total length so that similarly sized pairs share a batch.
    permutation = list(range(len(query_documents)))
    permutation.sort(
        key=lambda i: -len(query_documents[i][0]) - len(query_documents[i][1])
    )
    query_documents = [query_documents[i] for i in permutation]

    device = torch.device("cuda")

    # Group the datapoints into batches bounded by a per-device token budget.
    max_length = 0
    batches: list[list[tuple[str, str]]] = []
    for query, document in query_documents:
        if (
            len(batches) == 0
            or (len(batches[-1]) + 1) * max(max_length, len(query) + len(document))
            > PER_DEVICE_BATCH_SIZE_TOKENS
        ):
            batches.append([])
            max_length = 0
        batches[-1].append((query, document))
        max_length = max(max_length, 20 + len(query) + len(document))

    # Run inference over each document batch.
    all_logits: list[float] = []
    for batch in batches:
        batch_inputs = format_pointwise_datapoints(
            tokenizer,
            batch,
        )
        batch_inputs = batch_inputs.to(device)
        try:
            outputs = model(**batch_inputs, use_cache=False)
        except torch.OutOfMemoryError:
            print(f"GPU OOM! {torch.cuda.memory_reserved()}")
            torch.cuda.empty_cache()
            print(f"GPU After OOM Cache Clear: {torch.cuda.memory_reserved()}")
            outputs = model(**batch_inputs, use_cache=False)

        # Extract the "Yes" logit at the last non-padding position of each sequence.
        logits = cast(torch.Tensor, outputs.logits)
        attention_mask = cast(torch.Tensor, batch_inputs.attention_mask)
        last_positions = attention_mask.sum(dim=1) - 1
        batch_size = logits.shape[0]
        batch_indices = torch.arange(batch_size, device=device)
        last_logits = logits[batch_indices, last_positions]
        yes_logits = last_logits[:, self.inner_yes_token_id]
        all_logits.extend([float(logit) / 5.0 for logit in yes_logits])

    def sigmoid(x: float) -> float:
        return 1 / (1 + math.exp(-x))

    scores = [sigmoid(logit) for logit in all_logits]

    # Unsort back to the original input order.
    scores = [score for _, score in sorted(zip(permutation, scores, strict=True))]
    return scores


_CE.predict = predict

from transformers import Qwen3Config

ZEConfig = Qwen3Config
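

# --- Usage sketch (illustrative, not part of the patch) ---
# A minimal example of calling the monkey-patched predict(). It assumes that
# `CrossEncoder(MODEL_PATH)` can construct against this checkpoint and that a
# CUDA device is available, since the patched predict() ignores the wrapped
# model and lazily loads its own copy onto "cuda" on first use. The example
# query/document pairs below are hypothetical.
if __name__ == "__main__":
    reranker = _CE(MODEL_PATH)
    pairs = [
        ("what is the capital of france?", "Paris is the capital of France."),
        ("what is the capital of france?", "The Nile is a river in Africa."),
    ]
    # Scores are sigmoid-squashed "Yes" logits, returned in the original order
    # of `pairs` even though batching internally sorts by length.
    scores = reranker.predict(pairs)
    print(scores)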