jyo01 commited on
Commit
2afa6ec
·
verified ·
1 Parent(s): 6b18036

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +225 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import json
3
+ import base64
4
+ import requests
5
+ import torch
6
+ import nest_asyncio
7
+ from fastapi import HTTPException
8
+ from pydantic import BaseModel
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
10
+ from sentence_transformers import SentenceTransformer, models
11
+ import gradio as gr
12
+
13
+ # Apply nest_asyncio to allow async operations in the notebook/Spaces
14
+ nest_asyncio.apply()
15
+
16
+ ############################################
17
+ # Configuration
18
+ ############################################
19
+
20
+ # Replace with your actual tokens
21
+ HF_TOKEN = "YOUR_HF_TOKEN"
22
+ GITHUB_TOKEN = "YOUR_GITHUB_TOKEN"
23
+
24
+ ############################################
25
+ # GitHub API Functions
26
+ ############################################
27
+
28
+ def extract_repo_info(github_url: str):
29
+ pattern = r"github\.com/([^/]+)/([^/]+)"
30
+ match = re.search(pattern, github_url)
31
+ if match:
32
+ owner = match.group(1)
33
+ repo = match.group(2).replace('.git', '')
34
+ return owner, repo
35
+ else:
36
+ raise ValueError("Invalid GitHub URL provided.")
37
+
38
+ def get_repo_metadata(owner: str, repo: str):
39
+ headers = {'Authorization': f'token {GITHUB_TOKEN}'}
40
+ repo_url = f"https://api.github.com/repos/{owner}/{repo}"
41
+ response = requests.get(repo_url, headers=headers)
42
+ return response.json()
43
+
44
+ def get_repo_tree(owner: str, repo: str, branch: str):
45
+ headers = {'Authorization': f'token {GITHUB_TOKEN}'}
46
+ tree_url = f"https://api.github.com/repos/{owner}/{repo}/git/trees/{branch}?recursive=1"
47
+ response = requests.get(tree_url, headers=headers)
48
+ return response.json()
49
+
50
+ def get_file_content(owner: str, repo: str, file_path: str):
51
+ headers = {'Authorization': f'token {GITHUB_TOKEN}'}
52
+ content_url = f"https://api.github.com/repos/{owner}/{repo}/contents/{file_path}"
53
+ response = requests.get(content_url, headers=headers)
54
+ data = response.json()
55
+ if 'content' in data:
56
+ return base64.b64decode(data['content']).decode('utf-8')
57
+ else:
58
+ return None
59
+
60
+ ############################################
61
+ # Embedding Functions
62
+ ############################################
63
+
64
+ def preprocess_text(text: str) -> str:
65
+ cleaned_text = text.strip()
66
+ cleaned_text = re.sub(r'\s+', ' ', cleaned_text)
67
+ return cleaned_text
68
+
69
+ def load_embedding_model(model_name: str = 'huggingface/CodeBERTa-small-v1') -> SentenceTransformer:
70
+ transformer_model = models.Transformer(model_name)
71
+ pooling_model = models.Pooling(transformer_model.get_word_embedding_dimension(),
72
+ pooling_mode_mean_tokens=True)
73
+ model = SentenceTransformer(modules=[transformer_model, pooling_model])
74
+ return model
75
+
76
+ def generate_embedding(text: str, model_name: str = 'huggingface/CodeBERTa-small-v1') -> list:
77
+ processed_text = preprocess_text(text)
78
+ model = load_embedding_model(model_name)
79
+ embedding = model.encode(processed_text)
80
+ return embedding
81
+
82
+ ############################################
83
+ # LLM Integration Functions
84
+ ############################################
85
+
86
+ def is_detailed_query(query: str) -> bool:
87
+ keywords = ["detail", "detailed", "thorough", "in depth", "comprehensive", "extensive"]
88
+ return any(keyword in query.lower() for keyword in keywords)
89
+
90
+ def generate_prompt(query: str, context_snippets: list) -> str:
91
+ context = "\n\n".join(context_snippets)
92
+ if is_detailed_query(query):
93
+ instruction = "Provide an extremely detailed and thorough explanation of at least 500 words."
94
+ else:
95
+ instruction = "Answer concisely."
96
+
97
+ prompt = (
98
+ f"Below is some context from a GitHub repository:\n\n"
99
+ f"{context}\n\n"
100
+ f"Based on the above, {instruction}\n{query}\n"
101
+ f"Answer:"
102
+ )
103
+ return prompt
104
+
105
+ def get_llm_response(prompt: str, model_name: str = "meta-llama/Llama-2-7b-chat-hf", max_new_tokens: int = None) -> str:
106
+ if max_new_tokens is None:
107
+ max_new_tokens = 1024 if is_detailed_query(prompt) else 256
108
+
109
+ torch.cuda.empty_cache()
110
+
111
+ # Load tokenizer and model with authentication using the 'token' parameter.
112
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, token=HF_TOKEN)
113
+ model = AutoModelForCausalLM.from_pretrained(
114
+ model_name,
115
+ device_map="auto",
116
+ use_safetensors=False,
117
+ trust_remote_code=True,
118
+ torch_dtype=torch.float16,
119
+ token=HF_TOKEN
120
+ )
121
+
122
+ text_gen = pipeline("text-generation", model=model, tokenizer=tokenizer)
123
+ outputs = text_gen(prompt, max_new_tokens=max_new_tokens, do_sample=True, temperature=0.7)
124
+ full_response = outputs[0]['generated_text']
125
+
126
+ marker = "Answer:"
127
+ if marker in full_response:
128
+ answer = full_response.split(marker, 1)[1].strip()
129
+ else:
130
+ answer = full_response.strip()
131
+
132
+ return answer
133
+
134
+ ############################################
135
+ # Gradio Interface Functions
136
+ ############################################
137
+
138
+ def load_repo_contents(github_url: str):
139
+ try:
140
+ owner, repo = extract_repo_info(github_url)
141
+ except Exception as e:
142
+ return f"Error: {str(e)}"
143
+ repo_data = get_repo_metadata(owner, repo)
144
+ default_branch = repo_data.get("default_branch", "main")
145
+ tree_data = get_repo_tree(owner, repo, default_branch)
146
+ if "tree" not in tree_data:
147
+ return "Error: Could not fetch repository tree."
148
+ file_list = [item["path"] for item in tree_data["tree"] if item["type"] == "blob"]
149
+ return file_list
150
+
151
+ def get_file_content_for_choice(github_url: str, file_choice: int):
152
+ try:
153
+ owner, repo = extract_repo_info(github_url)
154
+ except Exception as e:
155
+ return str(e)
156
+ repo_data = get_repo_metadata(owner, repo)
157
+ default_branch = repo_data.get("default_branch", "main")
158
+ tree_data = get_repo_tree(owner, repo, default_branch)
159
+ if "tree" not in tree_data:
160
+ return "Error: Could not fetch repository tree."
161
+ file_list = [item["path"] for item in tree_data["tree"] if item["type"] == "blob"]
162
+ if file_choice < 1 or file_choice > len(file_list):
163
+ return "Error: Invalid file choice."
164
+ selected_file = file_list[file_choice - 1]
165
+ content = get_file_content(owner, repo, selected_file)
166
+ return content, selected_file
167
+
168
+ def chat_with_file(github_url: str, file_choice: int, user_query: str):
169
+ result = get_file_content_for_choice(github_url, file_choice)
170
+ if isinstance(result, str):
171
+ return result # Error message
172
+ file_content, selected_file = result
173
+ preprocessed = preprocess_text(file_content)
174
+ context_snippet = preprocessed[:1000] # use first 1000 characters as context
175
+ prompt = generate_prompt(user_query, [context_snippet])
176
+ llm_response = get_llm_response(prompt)
177
+ return f"File: {selected_file}\n\nLLM Response:\n{llm_response}"
178
+
179
+ ############################################
180
+ # Gradio Interface Setup
181
+ ############################################
182
+
183
+ with gr.Blocks() as demo:
184
+ gr.Markdown("# RepoChat - Chat with Repository Files")
185
+
186
+ with gr.Row():
187
+ with gr.Column(scale=1):
188
+ gr.Markdown("### Repository Information")
189
+ github_url_input = gr.Textbox(label="GitHub Repository URL", placeholder="https://github.com/username/repository")
190
+ load_repo_btn = gr.Button("Load Repository Contents")
191
+ file_dropdown = gr.Dropdown(label="Select a File", interactive=True)
192
+ repo_content_output = gr.Textbox(label="File Content", interactive=False, lines=10)
193
+ with gr.Column(scale=2):
194
+ gr.Markdown("### Chat Interface")
195
+ chat_query_input = gr.Textbox(label="Your Query", placeholder="Type your query here")
196
+ chat_output = gr.Textbox(label="Chatbot Response", interactive=False, lines=10)
197
+ chat_btn = gr.Button("Send Query")
198
+
199
+ # When clicking "Load Repository Contents", update file dropdown
200
+ def update_file_dropdown(github_url):
201
+ files = load_repo_contents(github_url)
202
+ return files
203
+
204
+ load_repo_btn.click(fn=update_file_dropdown, inputs=[github_url_input], outputs=[file_dropdown])
205
+
206
+ # When file selection changes, update file content display
207
+ def update_repo_content(github_url, file_choice):
208
+ if not file_choice:
209
+ return "No file selected."
210
+ try:
211
+ file_index = int(file_choice)
212
+ except:
213
+ file_index = 1
214
+ content, _ = get_file_content_for_choice(github_url, file_index)
215
+ return content
216
+
217
+ file_dropdown.change(fn=update_repo_content, inputs=[github_url_input, file_dropdown], outputs=[repo_content_output])
218
+
219
+ # When sending a chat query, process it
220
+ def process_chat(github_url, file_choice, chat_query):
221
+ return chat_with_file(github_url, int(file_choice), chat_query)
222
+
223
+ chat_btn.click(fn=process_chat, inputs=[github_url_input, file_dropdown, chat_query_input], outputs=[chat_output])
224
+
225
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ nest_asyncio
4
+ requests
5
+ torch
6
+ transformers
7
+ sentence-transformers
8
+ gradio