akashmishra358 committed
Commit 23a1ec8 · verified · 1 Parent(s): c67c488

Upload 7 files

Files changed (7)
  1. README.md +3 -8
  2. __init__.py +1 -0
  3. app.py +60 -0
  4. model.py +80 -0
  5. rag.configs.yml +7 -0
  6. requirements.txt +178 -0
  7. search.py +148 -0
README.md CHANGED
@@ -1,13 +1,8 @@
  ---
- title: NBPlatina
- emoji: 🐢
- colorFrom: indigo
- colorTo: red
  sdk: streamlit
- sdk_version: 1.31.1
+ sdk_version: 1.33.0
  app_file: app.py
- pinned: false
- license: mit
+ license: mit
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Run nltk.download("punkt") once before starting the app so the sentence tokenizer is available.
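The punkt requirement comes from search.py, which uses nltk.tokenize.sent_tokenize to chunk scraped pages. A minimal setup sketch (illustrative, not part of this commit):

import nltk

# Download the "punkt" sentence tokenizer used by search.py's chunking step.
nltk.download("punkt", quiet=True)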
__init__.py ADDED
@@ -0,0 +1 @@
+ import search
app.py ADDED
@@ -0,0 +1,60 @@
+ from search import SemanticSearch, GoogleSearch, Document
+ import streamlit as st
+ from model import RAGModel, load_configs
+
+
+ def run_on_start():
+     # Load the configs and the RAG model once per session.
+     if "configs" not in st.session_state:
+         st.session_state.configs = load_configs(config_file="rag.configs.yml")
+     if "model" not in st.session_state:
+         st.session_state.model = RAGModel(st.session_state.configs)
+
+ run_on_start()
+
+
+ def search(query):
+     # Scrape Google results for the query and chunk the pages into a document.
+     g = GoogleSearch(query)
+     data = g.all_page_data
+     d = Document(data, min_char_len=st.session_state.configs["document"]["min_char_length"])
+     st.session_state.doc = d.doc()
+
+
+ st.title("Search Here Instead of Google")
+
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+ if "doc" not in st.session_state:
+     st.session_state.doc = None
+
+ if "refresh" not in st.session_state:
+     st.session_state.refresh = True
+
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+
+
+ if prompt := st.chat_input("Search here instead of Google"):
+     st.chat_message("user").markdown(prompt)
+     st.session_state.messages.append({"role": "user", "content": prompt})
+
+     configs = st.session_state.configs
+     if st.session_state.refresh:
+         st.session_state.refresh = False
+         search(prompt)
+
+     s = SemanticSearch(
+         st.session_state.doc,
+         configs["model"]["embeding_model"],
+         configs["model"]["device"],
+     )
+     topk, u = s.semantic_search(query=prompt, k=32)
+     output = st.session_state.model.answer_query(query=prompt, topk_items=topk)
+     response = output
+     with st.chat_message("assistant"):
+         st.markdown(response)
+
+     st.session_state.messages.append({"role": "assistant", "content": response})
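app.py caches the configs and the RAGModel in st.session_state so they are built once per session rather than on every Streamlit rerun. A stripped-down sketch of that pattern (get_or_create is a hypothetical helper, not part of this commit):

import streamlit as st

def get_or_create(key, factory):
    # Hypothetical helper: build an object once per session and reuse it on reruns.
    if key not in st.session_state:
        st.session_state[key] = factory()
    return st.session_state[key]

# Equivalent to run_on_start() above, assuming load_configs and RAGModel are importable:
# configs = get_or_create("configs", lambda: load_configs("rag.configs.yml"))
# model = get_or_create("model", lambda: RAGModel(configs))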
model.py ADDED
@@ -0,0 +1,80 @@
+ from search import SemanticSearch, GoogleSearch, Document
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from transformers import BitsAndBytesConfig
+ from transformers.utils import is_flash_attn_2_available
+ import yaml
+ import torch
+ import nltk
+
+
+ def load_configs(config_file: str) -> dict:
+     with open(config_file, "r") as f:
+         configs = yaml.safe_load(f)
+     return configs
+
+
+ class RAGModel:
+     def __init__(self, configs) -> None:
+         self.configs = configs
+         self.device = configs["model"]["device"]
+         model_url = configs["model"]["genration_model"]
+         # quantization_config = BitsAndBytesConfig(
+         #     load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
+         # )
+
+         self.model = AutoModelForCausalLM.from_pretrained(
+             model_url,
+             torch_dtype=torch.float16,
+             # quantization_config=quantization_config,
+             low_cpu_mem_usage=False,
+             attn_implementation="sdpa",
+         ).to(self.device)
+         self.tokenizer = AutoTokenizer.from_pretrained(model_url)
+
+     def create_prompt(self, query, topk_items: list[str]):
+         context = "\n-".join(c for c in topk_items)
+
+         base_prompt = f"""You are an alternative to Google search. Your job is to answer the user query in as detailed a manner as possible.
+         You have access to the internet and other relevant data related to the user's question.
+         Take time to read the context and the user query, extract the relevant data, and then answer the query.
+         Make sure your answer is as detailed as possible.
+         Do not return your thinking process, just return the answer.
+         Give the output structured as a Wikipedia article.
+         Now use the following context items to answer the user query.
+         context: {context}
+         user query: {query}
+         """
+
+         dialog_template = [{"role": "user", "content": base_prompt}]
+
+         prompt = self.tokenizer.apply_chat_template(
+             conversation=dialog_template, tokenize=False, add_generation_prompt=True
+         )
+         return prompt
+
+     def answer_query(self, query: str, topk_items: list[str]):
+         prompt = self.create_prompt(query, topk_items)
+         input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+         output = self.model.generate(**input_ids, temperature=0.7, max_new_tokens=512, do_sample=True)
+         text = self.tokenizer.decode(output[0])
+         text = text.replace(prompt, "").replace("<bos>", "").replace("<eos>", "")
+         return text
+
+
+ if __name__ == "__main__":
+     configs = load_configs(config_file="rag.configs.yml")
+     query = "The height of burj khalifa is 1000 meters and it was built in 2023. What is the height of burj khalifa?"
+     # g = GoogleSearch(query)
+     # data = g.all_page_data
+     # d = Document(data, 512)
+     # doc_chunks = d.doc()
+     # s = SemanticSearch(doc_chunks, "all-mpnet-base-v2", "mps")
+     # topk, u = s.semantic_search(query=query, k=32)
+     r = RAGModel(configs)
+     output = r.answer_query(query=query, topk_items=[""])
+     print(output)
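The BitsAndBytesConfig import and the commented-out quantization_config point to a 4-bit loading path that is currently disabled. If it were re-enabled, it would look roughly like this (a sketch based on the commented lines above, not active in this commit):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Sketch: the 4-bit load that the commented-out code above describes.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
)
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b-it",                    # genration_model from rag.configs.yml
    quantization_config=quantization_config,
    low_cpu_mem_usage=True,
    attn_implementation="sdpa",
)
# With a quantization config, bitsandbytes places the weights on the GPU itself,
# so the explicit .to(self.device) call used in the float16 path is not needed.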
rag.configs.yml ADDED
@@ -0,0 +1,7 @@
+ document:
+   min_char_length: 512
+
+ model:
+   embeding_model: all-mpnet-base-v2
+   genration_model: google/gemma-7b-it
+   device: cuda
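load_configs in model.py reads this file with yaml.safe_load, so the values are available as a nested dict (illustrative):

import yaml

with open("rag.configs.yml") as f:
    configs = yaml.safe_load(f)

configs["document"]["min_char_length"]  # 512
configs["model"]["embeding_model"]      # "all-mpnet-base-v2"
configs["model"]["genration_model"]     # "google/gemma-7b-it"
configs["model"]["device"]              # "cuda"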
requirements.txt ADDED
@@ -0,0 +1,178 @@
+ accelerate==0.29.2
+ albumentations==1.4.3
+ altair==5.3.0
+ attrs==23.2.0
+ beautifulsoup4==4.12.3
+ bitsandbytes==0.43.1
+ blinker==1.7.0
+ cachetools==5.3.3
+ certifi==2024.2.2
+ charset-normalizer==2.0.4
+ click==8.1.7
+ colorama==0.4.6
+ contourpy==1.2.1
+ cycler==0.12.1
+ filelock==3.13.1
+ fonttools==4.50.0
+ fsspec==2024.3.1
+ gitdb==4.0.11
+ GitPython==3.1.43
+ gmpy2==2.1.2
+ huggingface-hub==0.22.2
+ idna==3.4
+ imageio==2.34.0
+ importlib_resources==6.4.0
+ Jinja2==3.1.3
+ joblib==1.3.2
+ jsonschema==4.21.1
+ jsonschema-specifications==2023.12.1
+ kiwisolver==1.4.5
+ lazy_loader==0.4
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.3
+ matplotlib==3.8.4
+ mdurl==0.1.2
+ mkl-fft==1.3.8
+ mkl-random==1.2.4
+ mkl-service==2.4.0
+ mpmath==1.3.0
+ networkx==3.1
+ nltk==3.8.1
+ numpy==1.26.4
+ opencv-python-headless==4.9.0.80
+ packaging==24.0
+ pandas==2.2.2
+ pillow==10.2.0
+ pip==23.3.1
+ protobuf==4.25.3
+ psutil==5.9.8
+ pyarrow==16.0.0
+ pydeck==0.8.1b0
+ Pygments==2.17.2
+ pyparsing==3.1.2
+ python-dateutil==2.9.0.post0
+ pytz==2024.1
+ PyYAML==6.0.1
+ referencing==0.34.0
+ regex==2024.4.16
+ requests==2.31.0
+ rich==13.7.1
+ rpds-py==0.18.0
+ safetensors==0.4.3
+ scikit-image==0.22.0
+ scikit-learn==1.4.1.post1
+ scipy==1.13.0
+ sentence-transformers==2.7.0
+ setuptools==68.2.2
+ six==1.16.0
+ smmap==5.0.1
+ soupsieve==2.5
+ streamlit==1.33.0
+ sympy==1.12
+ tenacity==8.2.3
+ threadpoolctl==3.4.0
+ tifffile==2024.2.12
+ tokenizers==0.15.2
+ toml==0.10.2
+ toolz==0.12.1
+ torch==2.2.2
+ torchaudio==2.2.2
+ torchvision==0.17.2
+ tornado==6.4
+ tqdm==4.66.2
+ transformers==4.39.3
+ typing_extensions==4.9.0
+ tzdata==2024.1
+ urllib3==2.1.0
+ watchdog==4.0.0
+ wheel==0.41.2
+ zipp==3.18.1
+ accelerate==0.29.2
+ albumentations==1.4.3
+ altair==5.3.0
+ attrs==23.2.0
+ beautifulsoup4==4.12.3
+ bitsandbytes==0.43.1
+ blinker==1.7.0
+ cachetools==5.3.3
+ certifi==2024.2.2
+ charset-normalizer==2.0.4
+ click==8.1.7
+ colorama==0.4.6
+ contourpy==1.2.1
+ cycler==0.12.1
+ filelock==3.13.1
+ fonttools==4.50.0
+ fsspec==2024.3.1
+ gitdb==4.0.11
+ GitPython==3.1.43
+ gmpy2==2.1.2
+ huggingface-hub==0.22.2
+ idna==3.4
+ imageio==2.34.0
+ importlib_resources==6.4.0
+ Jinja2==3.1.3
+ joblib==1.3.2
+ jsonschema==4.21.1
+ jsonschema-specifications==2023.12.1
+ kiwisolver==1.4.5
+ lazy_loader==0.4
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.3
+ matplotlib==3.8.4
+ mdurl==0.1.2
+ mkl-fft==1.3.8
+ mkl-random==1.2.4
+ mkl-service==2.4.0
+ mpmath==1.3.0
+ networkx==3.1
+ nltk==3.8.1
+ numpy==1.26.4
+ opencv-python-headless==4.9.0.80
+ packaging==24.0
+ pandas==2.2.2
+ pillow==10.2.0
+ pip==23.3.1
+ protobuf==4.25.3
+ psutil==5.9.8
+ pyarrow==16.0.0
+ pydeck==0.8.1b0
+ Pygments==2.17.2
+ pyparsing==3.1.2
+ python-dateutil==2.9.0.post0
+ pytz==2024.1
+ PyYAML==6.0.1
+ referencing==0.34.0
+ regex==2024.4.16
+ requests==2.31.0
+ rich==13.7.1
+ rpds-py==0.18.0
+ safetensors==0.4.3
+ scikit-image==0.22.0
+ scikit-learn==1.4.1.post1
+ scipy==1.13.0
+ sentence-transformers==2.7.0
+ setuptools==68.2.2
+ six==1.16.0
+ smmap==5.0.1
+ soupsieve==2.5
+ streamlit==1.33.0
+ sympy==1.12
+ tenacity==8.2.3
+ threadpoolctl==3.4.0
+ tifffile==2024.2.12
+ tokenizers==0.15.2
+ toml==0.10.2
+ toolz==0.12.1
+ torch==2.2.2
+ torchaudio==2.2.2
+ torchvision==0.17.2
+ tornado==6.4
+ tqdm==4.66.2
+ transformers==4.39.3
+ typing_extensions==4.9.0
+ tzdata==2024.1
+ urllib3==2.1.0
+ watchdog==4.0.0
+ wheel==0.41.2
+ zipp==3.18.1
search.py ADDED
@@ -0,0 +1,148 @@
+ from bs4 import BeautifulSoup
+ import urllib.parse
+ import requests
+ import nltk
+ import torch
+ from typing import Union
+ from sentence_transformers import SentenceTransformer, util
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+
+
+ class GoogleSearch:
+     def __init__(self, query: str) -> None:
+         self.query = query
+         escaped_query = urllib.parse.quote_plus(query)
+         self.URL = f"https://www.google.com/search?q={escaped_query}"
+
+         self.headers = {
+             "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/72.0.3538.102 Safari/537.36"
+         }
+         self.links = self.get_initial_links()
+         self.all_page_data = self.all_pages()
+
+     def clean_urls(self, anchors: list[str]) -> list[str]:
+         # Keep only hrefs that point to external result pages and strip the "url=" prefix.
+         links: list[str] = []
+         for a in anchors:
+             links.append(
+                 list(filter(lambda l: l.startswith("url=http"), a["href"].split("&")))
+             )
+
+         links = [
+             link.split("url=")[-1]
+             for sublist in links
+             for link in sublist
+             if len(link) > 0
+         ]
+         return links
+
+     def read_url_page(self, url: str) -> str:
+         response = requests.get(url, headers=self.headers)
+         response.raise_for_status()
+         soup = BeautifulSoup(response.text, "html.parser")
+         return soup.get_text(strip=True)
+
+     def get_initial_links(self) -> list[str]:
+         """
+         Scrape Google for the query with a keyword-based search.
+         """
+         print("Searching Google...")
+         response = requests.get(self.URL, headers=self.headers)
+         soup = BeautifulSoup(response.text, "html.parser")
+         anchors = soup.find_all("a", href=True)
+         return self.clean_urls(anchors)
+
+     def all_pages(self) -> list[tuple[str, str]]:
+         # Fetch result pages concurrently; skip pages that raise an HTTP error.
+         data: list[tuple[str, str]] = []
+         with ThreadPoolExecutor(max_workers=4) as executor:
+             future_to_url = {
+                 executor.submit(self.read_url_page, url): url for url in self.links
+             }
+             for future in as_completed(future_to_url):
+                 url = future_to_url[future]
+                 try:
+                     output = future.result()
+                     data.append((url, output))
+                 except requests.exceptions.HTTPError as e:
+                     print(e)
+
+         # for url in self.links:
+         #     try:
+         #         data.append((url, self.read_url_page(url)))
+         #     except requests.exceptions.HTTPError as e:
+         #         print(e)
+
+         return data
+
+
+ class Document:
+
+     def __init__(self, data: list[tuple[str, str]], min_char_len: int) -> None:
+         """
+         data : list[tuple[str, str]]
+             url and page data
+         """
+         self.data = data
+         self.min_char_len = min_char_len
+
+     def make_min_len_chunk(self):
+         raise NotImplementedError
+
+     def chunk_page(
+         self,
+         page_text: str,
+     ) -> list[str]:
+         # Accumulate sentences until a chunk exceeds min_char_len characters.
+         min_len_chunks: list[str] = []
+         chunk_text = nltk.tokenize.sent_tokenize(page_text)
+         sentence: str = ""
+         for sent in chunk_text:
+             if len(sentence) > self.min_char_len:
+                 min_len_chunks.append(sentence)
+                 sentence = ""
+             else:
+                 sentence += sent
+         return min_len_chunks
+
+     def doc(self) -> tuple[list[str], list[str]]:
+         print("Creating Document...")
+         chunked_data: list[str] = []
+         urls: list[str] = []
+         for url, dataitem in self.data:
+             data = self.chunk_page(dataitem)
+             chunked_data.append(data)
+             urls.append(url)
+
+         chunked_data = [chunk for sublist in chunked_data for chunk in sublist]
+         return chunked_data, urls
+
+
+ class SemanticSearch:
+     def __init__(
+         self, doc_chunks: tuple[list, list], model_path: str, device: str
+     ) -> None:
+
+         self.doc_chunks, self.urls = doc_chunks
+         self.st = SentenceTransformer(
+             model_path,
+             device=device,
+         )
+
+     def semantic_search(self, query: str, k: int = 10):
+         # Embed the query and all chunks, score with dot product, and return the top-k chunks.
+         print("Searching Top k in document...")
+         query_embeding = self.get_embeding(query)
+         doc_embeding = self.get_embeding(self.doc_chunks)
+         scores = util.dot_score(a=query_embeding, b=doc_embeding)[0]
+
+         top_k = torch.topk(scores, k=k)[1].cpu().tolist()
+         return [self.doc_chunks[i] for i in top_k], self.urls
+
+     def get_embeding(self, text: Union[list[str], str]):
+         en = self.st.encode(text)
+         return en
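Taken together, these pieces form the retrieval-plus-generation flow that app.py drives; the commented-out block in model.py's __main__ sketches the same path. An end-to-end sketch (the query string is a placeholder):

from search import GoogleSearch, Document, SemanticSearch
from model import RAGModel, load_configs

configs = load_configs("rag.configs.yml")
query = "What is retrieval-augmented generation?"  # placeholder query

g = GoogleSearch(query)                             # scrape Google result pages
d = Document(g.all_page_data, min_char_len=configs["document"]["min_char_length"])
doc_chunks = d.doc()                                # (chunks, urls)

s = SemanticSearch(doc_chunks, configs["model"]["embeding_model"], configs["model"]["device"])
topk, urls = s.semantic_search(query=query, k=32)   # top 32 chunks by dot-product score

r = RAGModel(configs)
print(r.answer_query(query=query, topk_items=topk))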