File size: 3,643 Bytes
23a1ec8
 
 
 
 
 
8cc1ee4
23a1ec8
 
 
 
 
 
 
 
 
 
8cc1ee4
 
 
 
 
 
 
 
 
 
 
 
23a1ec8
8cc1ee4
23a1ec8
 
8cc1ee4
23a1ec8
b8298df
23a1ec8
8cc1ee4
 
 
23a1ec8
 
8cc1ee4
23a1ec8
 
 
8cc1ee4
 
 
 
 
 
 
 
 
 
 
 
 
23a1ec8
 
 
8cc1ee4
 
23a1ec8
8cc1ee4
 
 
23a1ec8
 
 
 
 
8cc1ee4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23a1ec8
 
 
8cc1ee4
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from search import SemanticSearch, GoogleSearch, Document
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from transformers.utils import is_flash_attn_2_available
import yaml
import torch
import os  # Added for environment variables
import nltk

def load_configs(config_file: str) -> dict:
    """Load a YAML configuration file and return its parsed contents.

    Args:
        config_file: Path to the YAML file to read.

    Returns:
        The mapping parsed by ``yaml.safe_load``; an empty dict when the file
        contains no YAML document.
    """
    # Explicit UTF-8 so parsing does not depend on the platform's locale encoding.
    with open(config_file, "r", encoding="utf-8") as f:
        configs = yaml.safe_load(f)
    # safe_load returns None for an empty document; honour the `dict` return type.
    return configs if configs is not None else {}

class RAGModel:
    """Retrieval-augmented generation wrapper around a Hugging Face causal LM.

    Loads the model and tokenizer named by ``configs["model"]["generation_model"]``
    and answers a question using a caller-supplied list of retrieved passages.
    """

    def __init__(self, configs: dict) -> None:
        """Load the generation model and its tokenizer.

        Args:
            configs: Parsed configuration mapping. Reads
                ``configs["model"]["generation_model"]`` (required) and, when the
                ``HUGGINGFACE_TOKEN`` environment variable is unset,
                ``configs["model"]["hf_token"]``.

        Raises:
            ValueError: If no Hugging Face token is available from either source.
            KeyError: If ``configs["model"]`` or its ``generation_model`` entry
                is missing.
        """
        self.configs = configs

        # The environment variable takes precedence over the config entry.
        self.hf_token = os.getenv("HUGGINGFACE_TOKEN") or configs["model"].get("hf_token")
        if not self.hf_token:
            raise ValueError(
                "Missing Hugging Face token! Set either:\n"
                "1. HUGGINGFACE_TOKEN environment variable\n"
                "2. model.hf_token in the YAML config"
            )

        # Hub repo id or local path of the causal LM used for answering.
        model_url = configs["model"]["generation_model"]

        self.model = AutoModelForCausalLM.from_pretrained(
            model_url,
            token=self.hf_token,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            attn_implementation="sdpa",
            device_map="auto",  # let accelerate place the weights on available devices
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_url, token=self.hf_token)

    @staticmethod
    def _format_context(topk_items: list[str]) -> str:
        """Render retrieved passages as a Markdown bullet list, one item per line."""
        return "\n".join(f"- {item}" for item in topk_items)

    @classmethod
    def _build_user_prompt(cls, query: str, topk_items: list[str]) -> str:
        """Build the plain-text user message, before any chat template is applied."""
        # Built line by line so no source indentation leaks into the prompt.
        lines = [
            "You are an AI search assistant. Use this context to answer:",
            "Context:",
            cls._format_context(topk_items),
            "",
            f"Question: {query}",
            "",
            "Answer in Wikipedia-style format with these requirements:",
            "- Detailed technical explanations",
            "- Historical context where relevant",
            "- Numerical data when available",
            "- Markdown formatting for structure",
        ]
        return "\n".join(lines)

    def create_prompt(self, query, topk_items: list[str]):
        """Return the chat-formatted prompt string for ``query`` and its context.

        Args:
            query: The user's question.
            topk_items: Retrieved context passages, rendered as a bullet list.

        Returns:
            The untokenized prompt produced by the tokenizer's chat template,
            ending with the assistant generation prompt.
        """
        dialog_template = [
            {"role": "user", "content": self._build_user_prompt(query, topk_items)}
        ]
        return self.tokenizer.apply_chat_template(
            conversation=dialog_template,
            tokenize=False,
            add_generation_prompt=True,
        )

    def answer_query(self, query: str, topk_items: list[str]) -> str:
        """Generate an answer to ``query`` grounded in ``topk_items``.

        Args:
            query: The user's question.
            topk_items: Retrieved context passages.

        Returns:
            The decoded model continuation only (the prompt is not included),
            with special tokens removed and surrounding whitespace stripped.
        """
        prompt = self.create_prompt(query, topk_items)
        # The chat template already inserted the model's special tokens (e.g. BOS);
        # tokenizing with defaults would add them a second time.
        inputs = self.tokenizer(
            prompt, return_tensors="pt", add_special_tokens=False
        ).to(self.model.device)
        prompt_len = inputs["input_ids"].shape[-1]

        output = self.model.generate(
            **inputs,
            temperature=0.7,
            max_new_tokens=1024,
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.1,
        )

        # For decoder-only models generate() returns prompt + continuation;
        # keep only the newly generated tokens so the prompt is not echoed back.
        new_tokens = output[0][prompt_len:]
        text = self.tokenizer.decode(
            new_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        return text.strip()

if __name__ == "__main__":
    # Smoke test: load the config, build the model, and answer one question.
    # Token validation is left to RAGModel, which accepts the HUGGINGFACE_TOKEN
    # environment variable or an hf_token entry under "model" in the config and
    # raises a descriptive error when neither is set. An env-only check here
    # would wrongly reject a config-supplied token.
    configs = load_configs("rag.configs.yml")
    rag = RAGModel(configs)
    print(rag.answer_query("What's the height of Burj Khalifa?", ["Burj Khalifa is 828 meters tall"]))