Update model.py
Back to the old changes, with just the word "generation" corrected
model.py
CHANGED
@@ -4,100 +4,77 @@ from transformers import BitsAndBytesConfig
 from transformers.utils import is_flash_attn_2_available
 import yaml
 import torch
-import os  # Added for environment variables
 import nltk

 def load_configs(config_file: str) -> dict:
     with open(config_file, "r") as f:
         configs = yaml.safe_load(f)
+
     return configs

+
 class RAGModel:
     def __init__(self, configs) -> None:
         self.configs = configs
-
-
-
-
-
-                "Missing Hugging Face token! Set either:\n"
-                "1. HUGGINGFACE_TOKEN environment variable\n"
-                "2. hf_token in config.yml"
-            )
-
-        # 2. Fix model URL key (typo correction)
-        model_url = configs["model"]["generation_model"]  # Fixed "genration_model" -> "generation_model"
+        self.device = configs["model"]["device"]
+        model_url = configs["model"]["generation_model"]
+        # quantization_config = BitsAndBytesConfig(
+        #     load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
+        # )

-        # 3. Add authentication to model loading
         self.model = AutoModelForCausalLM.from_pretrained(
             model_url,
-            token=self.hf_token,  # Added authentication
             torch_dtype=torch.float16,
+            # quantization_config=quantization_config,
             low_cpu_mem_usage=False,
             attn_implementation="sdpa",
-
-        )
-
+        ).to(self.device)
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_url,
-            token=self.hf_token  # Added authentication
         )

     def create_prompt(self, query, topk_items: list[str]):
-
-
-
-        base_prompt = f"""You are an
-
-
-
-
-
-
-
-
-        - Markdown formatting for structure
+
+        context = "\n-".join(c for c in topk_items)
+
+        base_prompt = f"""You are an alternative to Google search. Your job is to answer the user query in as detailed a manner as possible.
+        You have access to the internet and other relevant data related to the user's question.
+        Give yourself time to read the context and the user query, extract the relevant data, and then answer the query.
+        Make sure your answer is as detailed as possible.
+        Do not return the thinking process, just return the answer.
+        Give the output structured as a Wikipedia article.
+        Now use the following context items to answer the user query
+        context: {context}
+        user query: {query}
         """

         dialog_template = [{"role": "user", "content": base_prompt}]
-
-        # 4. Fix typo in apply_chat_template
+
         prompt = self.tokenizer.apply_chat_template(
-            conversation=dialog_template,
-            tokenize=False,
-            add_generation_prompt=True  # Fixed "feneration" -> "generation"
+            conversation=dialog_template, tokenize=False, add_generation_prompt=True
         )
         return prompt

     def answer_query(self, query: str, topk_items: list[str]):
+
         prompt = self.create_prompt(query, topk_items)
-        input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.
-
-
-
-
-
-            max_new_tokens=1024,
-            do_sample=True,
-            top_p=0.9,
-            repetition_penalty=1.1
-        )
-
-        # Better text cleanup
-        text = self.tokenizer.decode(
-            output[0],
-            skip_special_tokens=True,  # Better than manual replace
-            clean_up_tokenization_spaces=True
-        )
+        input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+        output = self.model.generate(**input_ids, temperature=0.7, max_new_tokens=512, do_sample=True)
+        text = self.tokenizer.decode(output[0])
+        text = text.replace(prompt, "").replace("<bos>", "").replace("<eos>", "")
+
+
         return text

 if __name__ == "__main__":
-
-
-
-    #
-
-
-
-
-
+    configs = load_configs(config_file="rag.configs.yml")
+    query = "The height of burj khalifa is 1000 meters and it was built in 2023. What is the height of burj khalifa"
+    # g = GoogleSearch(query)
+    # data = g.all_page_data
+    # d = Document(data, 512)
+    # doc_chunks = d.doc()
+    # s = SemanticSearch(doc_chunks, "all-mpnet-base-v2", "mps")
+    # topk, u = s.semantic_search(query=query, k=32)
+    r = RAGModel(configs)
+    output = r.answer_query(query=query, topk_items=[""])
+    print(output)