akashmishra358 committed on
Commit fabdf09 · verified · 1 Parent(s): 9a85b0b

Update model.py


Back to the old changes, with just the word "generation" corrected.
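The one-word fix the message refers to is the add_generation_prompt keyword of tokenizer.apply_chat_template (previously misspelled "feneration"). A minimal sketch of the corrected call; the checkpoint and message below are placeholders for illustration, not part of this commit:

from transformers import AutoTokenizer

# Placeholder checkpoint for illustration only; the real model id comes from rag.configs.yml.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")

dialog_template = [{"role": "user", "content": "What is the height of Burj Khalifa?"}]

# add_generation_prompt=True appends the assistant-turn marker so the model
# starts generating the answer instead of continuing the user turn.
prompt = tokenizer.apply_chat_template(
    conversation=dialog_template, tokenize=False, add_generation_prompt=True
)
print(prompt)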

Files changed (1)
  1. model.py +41 -64
model.py CHANGED
@@ -4,100 +4,77 @@ from transformers import BitsAndBytesConfig
 from transformers.utils import is_flash_attn_2_available
 import yaml
 import torch
-import os  # Added for environment variables
 import nltk
 
 def load_configs(config_file: str) -> dict:
     with open(config_file, "r") as f:
         configs = yaml.safe_load(f)
+
     return configs
 
+
 class RAGModel:
     def __init__(self, configs) -> None:
         self.configs = configs
-
-        # 1. Get Hugging Face token (critical fix)
-        self.hf_token = os.getenv("HUGGINGFACE_TOKEN") or configs["model"].get("hf_token")
-        if not self.hf_token:
-            raise ValueError(
-                "Missing Hugging Face token! Set either:\n"
-                "1. HUGGINGFACE_TOKEN environment variable\n"
-                "2. hf_token in config.yml"
-            )
-
-        # 2. Fix model URL key (typo correction)
-        model_url = configs["model"]["generation_model"]  # Fixed "genration_model" -> "generation_model"
+        self.device = configs["model"]["device"]
+        model_url = configs["model"]["generation_model"]
+        # quantization_config = BitsAndBytesConfig(
+        #     load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
+        # )
 
-        # 3. Add authentication to model loading
         self.model = AutoModelForCausalLM.from_pretrained(
             model_url,
-            token=self.hf_token,  # Added authentication
             torch_dtype=torch.float16,
+            # quantization_config=quantization_config,
             low_cpu_mem_usage=False,
             attn_implementation="sdpa",
-            device_map="auto"  # Better device handling
-        )
-
+        ).to(self.device)
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_url,
-            token=self.hf_token  # Added authentication
         )
 
     def create_prompt(self, query, topk_items: list[str]):
-        context = "\n-".join(c for c in topk_items)
-
-        # Improved prompt template
-        base_prompt = f"""You are an AI search assistant. Use this context to answer:
-        Context: {context}
-
-        Question: {query}
-
-        Answer in Wikipedia-style format with these requirements:
-        - Detailed technical explanations
-        - Historical context where relevant
-        - Numerical data when available
-        - Markdown formatting for structure
+
+        context = "\n-".join(c for c in topk_items)
+
+        base_prompt = f"""You are an alternative to Google Search. Your job is to answer the user query in as detailed a manner as possible.
+        You have access to the internet and other relevant data related to the user's question.
+        Give yourself time to read the context and the user query, extract the relevant data, and then answer the query.
+        Make sure your answer is as detailed as possible.
+        Do not return the thinking process, just return the answer.
+        Give the output structured as a Wikipedia article.
+        Now use the following context items to answer the user query.
+        context: {context}
+        user query: {query}
         """
 
         dialog_template = [{"role": "user", "content": base_prompt}]
-
-        # 4. Fix typo in apply_chat_template
+
         prompt = self.tokenizer.apply_chat_template(
-            conversation=dialog_template,
-            tokenize=False,
-            add_generation_prompt=True  # Fixed "feneration" -> "generation"
+            conversation=dialog_template, tokenize=False, add_generation_prompt=True
         )
         return prompt
 
     def answer_query(self, query: str, topk_items: list[str]):
+
         prompt = self.create_prompt(query, topk_items)
-        input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
-
-        # Improved generation parameters
-        output = self.model.generate(
-            **input_ids,
-            temperature=0.7,
-            max_new_tokens=1024,
-            do_sample=True,
-            top_p=0.9,
-            repetition_penalty=1.1
-        )
-
-        # Better text cleanup
-        text = self.tokenizer.decode(
-            output[0],
-            skip_special_tokens=True,  # Better than manual replace
-            clean_up_tokenization_spaces=True
-        )
+        input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+        output = self.model.generate(**input_ids, temperature=0.7, max_new_tokens=512, do_sample=True)
+        text = self.tokenizer.decode(output[0])
+        text = text.replace(prompt, "").replace("<bos>", "").replace("<eos>", "")
+
+
         return text
 
 if __name__ == "__main__":
-    # Test with authentication
-    configs = load_configs("rag.configs.yml")
-
-    # Add temporary token check
-    if "HUGGINGFACE_TOKEN" not in os.environ:
-        raise RuntimeError("Set HUGGINGFACE_TOKEN environment variable first!")
-
-    rag = RAGModel(configs)
-    print(rag.answer_query("What's the height of Burj Khalifa?", ["Burj Khalifa is 828 meters tall"]))
+    configs = load_configs(config_file="rag.configs.yml")
+    query = "The height of Burj Khalifa is 1000 meters and it was built in 2023. What is the height of Burj Khalifa?"
+    # g = GoogleSearch(query)
+    # data = g.all_page_data
+    # d = Document(data, 512)
+    # doc_chunks = d.doc()
+    # s = SemanticSearch(doc_chunks, "all-mpnet-base-v2", "mps")
+    # topk, u = s.semantic_search(query=query, k=32)
+    r = RAGModel(configs)
+    output = r.answer_query(query=query, topk_items=[""])
+    print(output)
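
For reference, the restored __init__ only reads configs["model"]["device"] and configs["model"]["generation_model"] out of rag.configs.yml. A minimal sketch of a config that would satisfy it; the device and checkpoint values below are assumptions for illustration, not part of this commit:

import yaml

# Hypothetical rag.configs.yml contents; only the keys are dictated by the code above,
# the values ("mps", "google/gemma-2b-it") are assumed examples.
example_yaml = """
model:
  device: mps  # or "cuda" / "cpu", depending on hardware
  generation_model: google/gemma-2b-it  # any chat-tuned causal LM repo id
"""

configs = yaml.safe_load(example_yaml)
assert configs["model"]["device"] == "mps"
assert configs["model"]["generation_model"] == "google/gemma-2b-it"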