import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, ModernBertConfig
from typing import Dict, Any
import yaml
import os
from models import ModernBertForSentiment

class SentimentInference:
    def __init__(self, config_path: str = "config.yaml"):
        """Load configuration and initialize model and tokenizer from local checkpoint or Hugging Face Hub."""
        print(f"--- Debug: SentimentInference __init__ received config_path: {config_path} ---") # Add this
        with open(config_path, 'r') as f:
            config_data = yaml.safe_load(f)
        print(f"--- Debug: SentimentInference loaded config_data: {config_data} ---") # Add this
        
        model_yaml_cfg = config_data.get('model', {})
        inference_yaml_cfg = config_data.get('inference', {})
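
        # Expected config.yaml shape (a sketch: keys are inferred from how this
        # class reads the config, not from a canonical schema; values are
        # illustrative placeholders):
        #
        # model:
        #   name_or_path: "answerdotai/ModernBERT-base"
        #   tokenizer_name_or_path: "answerdotai/ModernBERT-base"
        #   base_model_for_config: "answerdotai/ModernBERT-base"
        #   num_labels: 1
        #   pooling_strategy: "mean"
        #   num_weighted_layers: 4
        #   max_length: 512
        # inference:
        #   model_path: "checkpoints/best_model.pt"  # optional local .pt file
        #   max_length: 512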
        
        model_hf_repo_id = model_yaml_cfg.get('name_or_path')
        tokenizer_hf_repo_id = model_yaml_cfg.get('tokenizer_name_or_path', model_hf_repo_id)
        local_model_weights_path = inference_yaml_cfg.get('model_path') # Path for local .pt file

        print(f"--- Debug: model_hf_repo_id: {model_hf_repo_id} ---") # Add this
        print(f"--- Debug: local_model_weights_path: {local_model_weights_path} ---") # Add this

        self.max_length = inference_yaml_cfg.get('max_length', model_yaml_cfg.get('max_length', 512))

        # --- Tokenizer Loading (always from Hub for now, or could be made conditional) ---
        if not tokenizer_hf_repo_id and not model_hf_repo_id:
            raise ValueError("Either model.tokenizer_name_or_path or model.name_or_path (as fallback for tokenizer) must be specified in config.yaml")
        effective_tokenizer_repo_id = tokenizer_hf_repo_id or model_hf_repo_id
        print(f"[INFERENCE_LOG] Loading tokenizer from: {effective_tokenizer_repo_id}") # Logging
        self.tokenizer = AutoTokenizer.from_pretrained(effective_tokenizer_repo_id)

        # --- Model Loading ---
        # Determine whether to load from a local .pt file or from the Hugging Face Hub.
        load_from_local_pt = False
        if local_model_weights_path and os.path.isfile(local_model_weights_path):
            print(f"[INFERENCE_LOG] Found local model weights path: {local_model_weights_path}") # Logging
            print(f"--- Debug: Found local model weights path: {local_model_weights_path} ---") # Add this
            load_from_local_pt = True
        elif not model_hf_repo_id:
            raise ValueError("No local model_path found and model.name_or_path (for Hub) is not specified in config.yaml")

        print(f"[INFERENCE_LOG] load_from_local_pt: {load_from_local_pt}") # Logging
        print(f"--- Debug: load_from_local_pt is: {load_from_local_pt} ---") # Add this

        if load_from_local_pt:
            print("[INFERENCE_LOG] Attempting to load model from LOCAL .pt checkpoint...") # Logging
            print("--- Debug: Entering LOCAL .pt loading path ---") # Add this
            # The base BERT config must still be loaded, usually from a Hub ID (e.g., the original base model).
            # This base_model_for_config_id is crucial for building the correct ModernBertForSentiment structure.
            base_model_for_config_id = model_yaml_cfg.get('base_model_for_config', model_yaml_cfg.get('name_or_path'))
            if not base_model_for_config_id:
                 raise ValueError("model.base_model_for_config or model.name_or_path must be specified in config.yaml when loading local .pt for ModernBertForSentiment structure.")
            
            print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: base_model_for_config_id: {base_model_for_config_id}") # Logging

            # Head-specific settings come from config.yaml (via model_yaml_cfg).
            model_config = ModernBertConfig.from_pretrained(
                base_model_for_config_id,
                num_labels=model_yaml_cfg.get('num_labels', 1),
                pooling_strategy=model_yaml_cfg.get('pooling_strategy', 'mean'),
                num_weighted_layers=model_yaml_cfg.get('num_weighted_layers', 4)
            )
            print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Loaded ModernBertConfig: {model_config.to_diff_dict()}")

            print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Initializing ModernBertForSentiment with this config.") # Logging
            self.model = ModernBertForSentiment(config=model_config)
            
            print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Loading weights from checkpoint: {local_model_weights_path}") # Logging
            checkpoint = torch.load(local_model_weights_path, map_location=torch.device('cpu'))
            
            # Support both wrapped checkpoints and raw state_dicts.
            if isinstance(checkpoint, dict):
                state_dict_to_load = checkpoint.get('model_state_dict', checkpoint.get('state_dict', checkpoint))
            else:
                state_dict_to_load = checkpoint
            if not isinstance(state_dict_to_load, dict):
                raise TypeError(f"Loaded checkpoint from {local_model_weights_path} is not a dict and does not contain 'model_state_dict' or 'state_dict'.")

            # Log first few keys for debugging
            first_few_keys = list(state_dict_to_load.keys())[:5]
            print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: First few keys from checkpoint state_dict: {first_few_keys}") # Logging

            self.model.load_state_dict(state_dict_to_load)
            print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Weights loaded successfully into ModernBertForSentiment from {local_model_weights_path}.") # Logging
        else:
            # Load from Hugging Face Hub
            print(f"[INFERENCE_LOG] Attempting to load model from HUGGING_FACE_HUB: {model_hf_repo_id}") # Logging
            
            # By default we use the config packaged with the model on the Hub, but we
            # add/override num_labels, pooling_strategy, and num_weighted_layers from our
            # local config.yaml, since these may be specific to our fine-tuning and
            # absent from the Hub's default config.json.
            hub_config_overrides = {
                "num_labels": model_yaml_cfg.get('num_labels', 1),
                "pooling_strategy": model_yaml_cfg.get('pooling_strategy', 'mean'),
                "num_weighted_layers": model_yaml_cfg.get('num_weighted_layers', 6) # Default to 6 now
            }
            print(f"[INFERENCE_LOG] HUB_LOAD: Overrides for Hub config: {hub_config_overrides}") # Logging

            try:
                # Using ModernBertForSentiment.from_pretrained directly.
                # This assumes the config.json on the Hub for 'model_hf_repo_id' is compatible
                # or that from_pretrained can correctly initialize ModernBertForSentiment with it.
                self.model = ModernBertForSentiment.from_pretrained(
                    model_hf_repo_id,
                    **hub_config_overrides
                )
                print(f"[INFERENCE_LOG] HUB_LOAD: Model ModernBertForSentiment loaded successfully from {model_hf_repo_id}.") # Logging
            except Exception as e:
                print(f"[INFERENCE_LOG] HUB_LOAD: Error loading ModernBertForSentiment from {model_hf_repo_id}: {e}") # Logging
                print(f"[INFERENCE_LOG] HUB_LOAD: Falling back to AutoModelForSequenceClassification for {model_hf_repo_id}.") # Logging
                # Fallback: Try with AutoModelForSequenceClassification if ModernBertForSentiment fails
                # This might happen if the Hub model isn't strictly saved as a ModernBertForSentiment type
                # or if its config.json doesn't have _custom_class set, etc.
                self.model = AutoModelForSequenceClassification.from_pretrained(
                    model_hf_repo_id,
                    **hub_config_overrides
                )
                print(f"[INFERENCE_LOG] HUB_LOAD: AutoModelForSequenceClassification loaded for {model_hf_repo_id}.") # Logging

        self.model.eval()
        
    def predict(self, text: str) -> Dict[str, Any]:
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=self.max_length, padding=True)
        with torch.no_grad():
            outputs = self.model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
        logits = outputs.get("logits") # Use .get for safety
        if logits is None:
            raise ValueError("Model output did not contain 'logits'. Check model's forward pass.")
        prob = torch.sigmoid(logits).item()  # assumes a single-logit binary head (num_labels=1)
        return {"sentiment": "positive" if prob > 0.5 else "negative", "confidence": prob}