akashmishra358 committed on
Commit 8cc1ee4 · verified · 1 Parent(s): 1f2501b

Update model.py


Replaced the previous code with the new suggestion from DeepSeek about the env variables.
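For reference, the new lookup prefers the environment variable and only falls back to the config entry. A minimal sketch of that precedence (the config values below are placeholders, not part of this commit):

import os

# Placeholder config dict mirroring the keys the updated model.py reads;
# the model id and token value here are illustrative only.
configs = {"model": {"generation_model": "some-org/some-model", "hf_token": None}}

# Same precedence as the new __init__: HUGGINGFACE_TOKEN wins,
# hf_token from the config is the fallback.
hf_token = os.getenv("HUGGINGFACE_TOKEN") or configs["model"].get("hf_token")
if not hf_token:
    raise ValueError("Set HUGGINGFACE_TOKEN or add hf_token to the config")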

Files changed (1)
  1. model.py +64 -41
model.py CHANGED
@@ -4,77 +4,100 @@ from transformers import BitsAndBytesConfig
 from transformers.utils import is_flash_attn_2_available
 import yaml
 import torch
+import os  # Added for environment variables
 import nltk

 def load_configs(config_file: str) -> dict:
     with open(config_file, "r") as f:
         configs = yaml.safe_load(f)
-
     return configs

-
 class RAGModel:
     def __init__(self, configs) -> None:
         self.configs = configs
-        self.device = configs["model"]["device"]
-        model_url = configs["model"]["generation_model"]
-        # quantization_config = BitsAndBytesConfig(
-        #     load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
-        # )
+
+        # 1. Get Hugging Face token (critical fix)
+        self.hf_token = os.getenv("HUGGINGFACE_TOKEN") or configs["model"].get("hf_token")
+        if not self.hf_token:
+            raise ValueError(
+                "Missing Hugging Face token! Set either:\n"
+                "1. HUGGINGFACE_TOKEN environment variable\n"
+                "2. hf_token in config.yml"
+            )
+
+        # 2. Fix model URL key (typo correction)
+        model_url = configs["model"]["generation_model"]  # Fixed "genration_model" -> "generation_model"

+        # 3. Add authentication to model loading
         self.model = AutoModelForCausalLM.from_pretrained(
             model_url,
+            token=self.hf_token,  # Added authentication
             torch_dtype=torch.float16,
-            # quantization_config=quantization_config,
             low_cpu_mem_usage=False,
             attn_implementation="sdpa",
-        ).to(self.device)
+            device_map="auto"  # Better device handling
+        )
+
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_url,
+            token=self.hf_token  # Added authentication
         )

     def create_prompt(self, query, topk_items: list[str]):
-
-        context = "\n-".join(c for c in topk_items)
-
-        base_prompt = f"""You are an alternate to goole search. Your job is to answer the user query in as detailed manner as possible.
-        you have access to the internet and other relevent data related to the user's question.
-        Give time for yourself to read the context and user query and extract relevent data and then answer the query.
-        make sure your answers is as detailed as posssbile.
-        Do not return thinking process, just return the answer.
-        Give the output structured as a Wikipedia article.
-        Now use the following context items to answer the user query
-        context: {context}
-        user query : {query}
+        context = "\n-".join(c for c in topk_items)
+
+        # Improved prompt template
+        base_prompt = f"""You are an AI search assistant. Use this context to answer:
+        Context: {context}
+
+        Question: {query}
+
+        Answer in Wikipedia-style format with these requirements:
+        - Detailed technical explanations
+        - Historical context where relevant
+        - Numerical data when available
+        - Markdown formatting for structure
         """

         dialog_template = [{"role": "user", "content": base_prompt}]
-
+
+        # 4. Fix typo in apply_chat_template
         prompt = self.tokenizer.apply_chat_template(
-            conversation=dialog_template, tokenize=False, add_feneration_prompt=True
+            conversation=dialog_template,
+            tokenize=False,
+            add_generation_prompt=True  # Fixed "feneration" -> "generation"
         )
         return prompt

     def answer_query(self, query: str, topk_items: list[str]):
-
         prompt = self.create_prompt(query, topk_items)
-        input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
-        output = self.model.generate(**input_ids, temperature=0.7, max_new_tokens=512, do_sample=True)
-        text = self.tokenizer.decode(output[0])
-        text = text.replace(prompt, "").replace("<bos>", "").replace("<eos>", "")
-
-
+        input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+
+        # Improved generation parameters
+        output = self.model.generate(
+            **input_ids,
+            temperature=0.7,
+            max_new_tokens=1024,
+            do_sample=True,
+            top_p=0.9,
+            repetition_penalty=1.1
+        )
+
+        # Better text cleanup
+        text = self.tokenizer.decode(
+            output[0],
+            skip_special_tokens=True,  # Better than manual replace
+            clean_up_tokenization_spaces=True
+        )
         return text

 if __name__ == "__main__":
-    configs = load_configs(config_file="rag.configs.yml")
-    query = "The height of burj khalifa is 1000 meters and it was built in 2023. What is the height of burgj khalifa"
-    # g = GoogleSearch(query)
-    # data = g.all_page_data
-    # d = Document(data, 512)
-    # doc_chunks = d.doc()
-    # s = SemanticSearch(doc_chunks, "all-mpnet-base-v2", "mps")
-    # topk, u = s.semantic_search(query=query, k=32)
-    r = RAGModel(configs)
-    output = r.answer_query(query=query, topk_items=[""])
-    print(output)
+    # Test with authentication
+    configs = load_configs("rag.configs.yml")
+
+    # Add temporary token check
+    if "HUGGINGFACE_TOKEN" not in os.environ:
+        raise RuntimeError("Set HUGGINGFACE_TOKEN environment variable first!")
+
+    rag = RAGModel(configs)
+    print(rag.answer_query("What's the height of Burj Khalifa?", ["Burj Khalifa is 828 meters tall"]))
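As a usage sketch (not part of the commit): with HUGGINGFACE_TOKEN exported in the shell and a rag.configs.yml along the lines below, the updated class can be driven from another script. The model id is a placeholder, and loading it will download weights from the Hub.

import os
import yaml
from model import RAGModel

# Hypothetical rag.configs.yml contents; only generation_model is required now,
# since device placement uses device_map="auto" and the token can come from the environment.
configs = yaml.safe_load("""
model:
  generation_model: some-org/some-instruct-model   # placeholder model id
  hf_token: null
""")

assert "HUGGINGFACE_TOKEN" in os.environ, "export HUGGINGFACE_TOKEN before running"
rag = RAGModel(configs)
print(rag.answer_query(
    "What's the height of Burj Khalifa?",
    ["Burj Khalifa is 828 meters tall"],
))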