from search import SemanticSearch, GoogleSearch, Document
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from transformers.utils import is_flash_attn_2_available
import yaml
import torch
import os  # Needed to read the token from environment variables
import nltk


def load_configs(config_file: str) -> dict:
    """Load the YAML configuration file into a dict."""
    with open(config_file, "r") as f:
        configs = yaml.safe_load(f)
    return configs


class RAGModel:
    """Generation half of the RAG pipeline: builds a prompt from retrieved
    context and answers the query with a Hugging Face causal LM."""

    def __init__(self, configs) -> None:
        self.configs = configs

        # 1. Get the Hugging Face token (critical fix)
        self.hf_token = os.getenv("HUGGINGFACE_TOKEN") or configs["model"].get("hf_token")
        if not self.hf_token:
            raise ValueError(
                "Missing Hugging Face token! Set either:\n"
                "1. the HUGGINGFACE_TOKEN environment variable, or\n"
                "2. hf_token under model: in rag.configs.yml"
            )

        # 2. Fix the model URL key (typo correction: "genration_model" -> "generation_model")
        model_url = configs["model"]["generation_model"]

        # 3. Load model and tokenizer with authentication; prefer FlashAttention 2 when available
        attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
        self.model = AutoModelForCausalLM.from_pretrained(
            model_url,
            token=self.hf_token,          # Added authentication
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            attn_implementation=attn_implementation,
            device_map="auto",            # Better device handling
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_url,
            token=self.hf_token,          # Added authentication
        )

    def create_prompt(self, query: str, topk_items: list[str]) -> str:
        # Format the retrieved passages as a bulleted context block
        context = "\n".join(f"- {item}" for item in topk_items)

        # Improved prompt template
        base_prompt = f"""You are an AI search assistant. Use this context to answer:

Context:
{context}

Question: {query}

Answer in Wikipedia-style format with these requirements:
- Detailed technical explanations
- Historical context where relevant
- Numerical data when available
- Markdown formatting for structure
"""
        dialog_template = [{"role": "user", "content": base_prompt}]

        # 4. Fix typo in apply_chat_template ("feneration" -> "generation")
        prompt = self.tokenizer.apply_chat_template(
            conversation=dialog_template,
            tokenize=False,
            add_generation_prompt=True,
        )
        return prompt

    def answer_query(self, query: str, topk_items: list[str]) -> str:
        prompt = self.create_prompt(query, topk_items)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Improved generation parameters
        output = self.model.generate(
            **inputs,
            temperature=0.7,
            max_new_tokens=1024,
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.1,
        )

        # Decode only the newly generated tokens so the prompt is not echoed back
        generated_ids = output[0][inputs["input_ids"].shape[-1]:]
        text = self.tokenizer.decode(
            generated_ids,
            skip_special_tokens=True,          # Better than a manual replace
            clean_up_tokenization_spaces=True,
        )
        return text


if __name__ == "__main__":
    # Quick end-to-end test with authentication
    configs = load_configs("rag.configs.yml")

    # Temporary token check
    if "HUGGINGFACE_TOKEN" not in os.environ:
        raise RuntimeError("Set the HUGGINGFACE_TOKEN environment variable first!")

    rag = RAGModel(configs)
    print(rag.answer_query("What's the height of Burj Khalifa?",
                           ["Burj Khalifa is 828 meters tall"]))
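

# A minimal sketch of the rag.configs.yml layout this script expects, inferred
# from the keys read above (configs["model"]["generation_model"] and the
# optional configs["model"]["hf_token"]); the model id shown is only a
# placeholder, not necessarily the project's actual choice:
#
#   model:
#     generation_model: "google/gemma-2b-it"   # any Hugging Face causal-LM repo id
#     hf_token: ""                             # optional; the env variable takes priority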