frankjosh commited on
Commit
c60740f
·
verified ·
1 Parent(s): a2df113

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -69
app.py CHANGED
@@ -1,8 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- """app.py
3
- Enhanced Repository Recommender System using Streamlit and CodeT5-small
4
- """
5
-
6
  import warnings
7
  warnings.filterwarnings('ignore')
8
 
@@ -15,6 +10,9 @@ import torch
15
  from tqdm import tqdm
16
  from datasets import load_dataset
17
  from datetime import datetime
 
 
 
18
 
19
  # Configure GPU if available
20
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -25,94 +23,163 @@ if 'history' not in st.session_state:
25
  if 'feedback' not in st.session_state:
26
  st.session_state.feedback = {}
27
 
28
- # Step 1: Load Dataset and Precompute Embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
@st.cache_resource
def load_data_and_model():
    """Fetch the dataset, build the 'text' column, and load CodeT5-small.

    Returns:
        (data, tokenizer, model): DataFrame with a combined 'text' field,
        plus the Hugging Face tokenizer and model (on GPU when available,
        in eval mode). Any failure is surfaced via st.error and halts the app.
    """
    try:
        # Download and materialise the dataset as a DataFrame.
        raw = load_dataset("frankjosh/filtered_dataset")
        data = pd.DataFrame(raw['train'])

        # Fail fast if the expected text columns are absent.
        for column in ['docstring', 'summary']:
            if column not in data.columns:
                st.error(f"Missing required column: {column}")
                st.stop()

        # Concatenate both text fields; embeddings are computed over this.
        data['text'] = data['docstring'].fillna('') + ' ' + data['summary'].fillna('')
    except Exception as e:
        st.error(f"Error loading dataset: {str(e)}")
        st.stop()

    # Load CodeT5-small model and tokenizer.
    checkpoint = "Salesforce/codet5-small"
    try:
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = AutoModel.from_pretrained(checkpoint)
        if torch.cuda.is_available():
            model = model.to('cuda')
        model.eval()  # inference only — disable dropout etc.
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.stop()

    return data, tokenizer, model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
# Single-text embedding helper.
@st.cache_data
def generate_embedding(_model, _tokenizer, text):
    """Return a 1-D numpy embedding for `text` (mean-pooled encoder states).

    NOTE(review): the underscore-prefixed parameters are presumably named so
    Streamlit's cache skips hashing the model/tokenizer — confirm intent.
    """
    encoded = _tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        encoded = {name: tensor.to('cuda') for name, tensor in encoded.items()}
    with torch.no_grad():
        pooled = _model.encoder(**encoded).last_hidden_state.mean(dim=1).squeeze()
    if use_gpu:
        pooled = pooled.cpu()
    return pooled.numpy()
80
-
81
# Precompute embeddings for the whole dataset, one text at a time.
def precompute_embeddings(data, model, tokenizer):
    """Add an 'embedding' column (one vector per row) and return `data`."""
    data['embedding'] = [
        generate_embedding(model, tokenizer, text)
        for text in tqdm(data['text'], desc="Generating embeddings")
    ]
    return data
89
 
90
- # Generate a concise case study brief from repository data
91
- def generate_case_study(repo_data):
92
- template = f"""
93
- **Project Overview**: {repo_data['summary'][:50]}...
 
 
 
 
 
 
 
 
 
 
94
 
95
- **Key Features**:
96
- - Repository contains production-ready {repo_data['path'].split('/')[-1]} implementation
97
- - {repo_data['docstring'][:50]}...
 
 
98
 
99
- **Potential Applications**: This repository can be utilized for projects requiring {' '.join(repo_data['summary'].split()[:5])}...
 
100
 
101
- **Implementation Complexity**: {'Medium' if len(repo_data['docstring']) > 500 else 'Low'}
 
102
 
103
- **Integration Potential**: {'High' if 'api' in repo_data['text'].lower() or 'interface' in repo_data['text'].lower() else 'Medium'}
104
- """
105
- return template[:150] + "..."
106
 
107
# Persist a like/dislike vote in Streamlit session state.
def save_feedback(repo_id, feedback_type):
    """Increment the 'likes' or 'dislikes' counter for `repo_id`."""
    counters = st.session_state.feedback.setdefault(repo_id, {'likes': 0, 'dislikes': 0})
    counters[feedback_type] += 1
112
 
113
- # Load resources
114
- data, tokenizer, model = load_data_and_model()
115
- data = precompute_embeddings(data, model, tokenizer)
116
 
117
  # Main App Interface
118
  st.title("Enhanced Repository Recommender System 🚀")
 
 
 
 
 
 
1
  import warnings
2
  warnings.filterwarnings('ignore')
3
 
 
10
  from tqdm import tqdm
11
  from datasets import load_dataset
12
  from datetime import datetime
13
+ from typing import List, Dict, Any
14
+ from torch.utils.data import DataLoader, Dataset
15
+ from functools import partial
16
 
17
  # Configure GPU if available
18
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
23
  if 'feedback' not in st.session_state:
24
  st.session_state.feedback = {}
25
 
26
# Define subset size
SUBSET_SIZE = 1000  # Starting with 1000 items for quick testing

class TextDataset(Dataset):
    """Map-style dataset yielding one tokenized text per index.

    Each item is tokenized lazily with fixed-length padding so every tensor
    in a batch has the same shape.
    """

    def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
        # tokenizer: HF-style callable — assumed, TODO confirm interface.
        self._items = texts
        self._tokenize = tokenizer
        self._max_len = max_length

    def __len__(self):
        return len(self._items)

    def __getitem__(self, idx):
        encoded = self._tokenize(
            self._items[idx],
            padding='max_length',
            truncation=True,
            max_length=self._max_len,
            return_tensors="pt",
        )
        return encoded
46
+
47
@st.cache_resource
def load_data_and_model():
    """Load a random dataset subset plus CodeT5-small, cached by Streamlit.

    Returns:
        (data, tokenizer, model): DataFrame of SUBSET_SIZE sampled rows with
        a combined 'text' column, the tokenizer, and the model in eval mode
        (moved to `device` when CUDA is present). Errors stop the app.
    """
    try:
        # Fetch the dataset and keep a reproducible random subset.
        frame = pd.DataFrame(load_dataset("frankjosh/filtered_dataset")['train'])
        sample_size = min(SUBSET_SIZE, len(frame))
        frame = frame.sample(n=sample_size, random_state=42).reset_index(drop=True)

        # Combine the two text fields used for embedding.
        frame['text'] = frame['docstring'].fillna('') + ' ' + frame['summary'].fillna('')

        # Load model and tokenizer.
        checkpoint = "Salesforce/codet5-small"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = AutoModel.from_pretrained(checkpoint)

        if torch.cuda.is_available():
            model = model.to(device)

        model.eval()
        return frame, tokenizer, model

    except Exception as e:
        st.error(f"Error in initialization: {str(e)}")
        st.stop()
75
 
76
def collate_fn(batch, pad_token_id):
    """Pad a list of tokenizer outputs to a common length and stack them.

    Args:
        batch: list of dicts holding 'input_ids'/'attention_mask' tensors of
            shape (1, seq_len), as produced by an HF tokenizer with
            return_tensors="pt".
        pad_token_id: id used to right-pad shorter sequences.

    Returns:
        dict with 'input_ids' and 'attention_mask', each (batch, max_len).
    """
    max_length = max(inputs['input_ids'].shape[1] for inputs in batch)
    input_ids = []
    attention_mask = []

    for inputs in batch:
        pad_width = max_length - inputs['input_ids'].shape[1]
        # squeeze(0) drops only the leading batch dim; a bare squeeze() would
        # also collapse a length-1 sequence to a 0-d tensor, which then
        # breaks F.pad/torch.stack.
        input_ids.append(torch.nn.functional.pad(
            inputs['input_ids'].squeeze(0),
            (0, pad_width),
            value=pad_token_id
        ))
        attention_mask.append(torch.nn.functional.pad(
            inputs['attention_mask'].squeeze(0),
            (0, pad_width),
            value=0
        ))

    return {
        'input_ids': torch.stack(input_ids),
        'attention_mask': torch.stack(attention_mask)
    }
97
 
98
def generate_embeddings_batch(model, batch, device):
    """Mean-pool encoder hidden states for one batch and return them on CPU.

    Args:
        model: object exposing `.encoder(**tensors)` whose output has
            `.last_hidden_state` of shape (batch, seq, hidden).
        batch: dict of tensors (input_ids, attention_mask).
        device: torch device to run the forward pass on.

    Returns:
        numpy array of shape (batch, hidden).
    """
    with torch.no_grad():
        on_device = {name: tensor.to(device) for name, tensor in batch.items()}
        hidden = model.encoder(**on_device).last_hidden_state
        pooled = hidden.mean(dim=1)
    return pooled.cpu().numpy()
105
+
106
def precompute_embeddings(data: pd.DataFrame, model, tokenizer, batch_size: int = 16):
    """Embed every row's 'text' in batches, reporting progress in the UI.

    Adds an 'embedding' column (one numpy vector per row) to `data` in place
    and returns the same DataFrame.
    """
    loader = DataLoader(
        TextDataset(data['text'].tolist(), tokenizer),
        batch_size=batch_size,
        shuffle=False,
        collate_fn=partial(collate_fn, pad_token_id=tokenizer.pad_token_id),
        num_workers=2,  # Reduced workers for smaller dataset
        pin_memory=True,
    )

    vectors = []
    n_batches = len(loader)

    # Streamlit widgets for progress + ETA feedback.
    bar = st.progress(0)
    status = st.empty()
    started = datetime.now()

    for done, batch in enumerate(loader, start=1):
        # Embed this batch and accumulate row vectors.
        vectors.extend(generate_embeddings_batch(model, batch, device))

        # Update progress bar and a naive linear ETA estimate.
        bar.progress(done / n_batches)
        elapsed = (datetime.now() - started).total_seconds()
        eta = (elapsed / done) * (n_batches - done)
        status.text(f"Processing batch {done}/{n_batches}. ETA: {int(eta)} seconds")

    bar.empty()
    status.empty()

    data['embedding'] = vectors
    return data
147
 
148
@torch.no_grad()
def generate_query_embedding(model, tokenizer, query: str) -> np.ndarray:
    """Embed a single free-text query via mean-pooled encoder states.

    Returns a 1-D numpy vector. Relies on the module-level `device`
    for tensor placement.
    """
    encoded = tokenizer(
        query,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    )
    encoded = encoded.to(device)
    pooled = model.encoder(**encoded).last_hidden_state.mean(dim=1)
    return pooled.cpu().numpy().squeeze()
162
 
163
def find_similar_repos(query_embedding: np.ndarray, data: pd.DataFrame, top_n: int = 5) -> pd.DataFrame:
    """Rank repositories by cosine similarity to the query embedding.

    Writes a 'similarity' column into `data` (kept in place to match the
    previous behavior, since downstream UI may read it) and returns the
    top_n most similar rows.

    Args:
        query_embedding: 1-D query vector (e.g. from generate_query_embedding).
        data: DataFrame whose 'embedding' column holds equal-length vectors.
        top_n: number of rows to return.
    """
    matrix = np.stack(data['embedding'].values)
    # Cosine similarity in plain numpy — a full sklearn call is overkill for a
    # single query. Guard zero norms so an all-zero embedding scores 0
    # instead of producing a divide-by-zero NaN.
    norms = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_embedding)
    dots = matrix @ query_embedding
    similarities = np.where(norms > 0, dots / np.where(norms > 0, norms, 1.0), 0.0)
    data['similarity'] = similarities
    return data.nlargest(top_n, 'similarity')
168
 
169
# Load resources (cached): dataset subset, tokenizer, and model.
data, tokenizer, model = load_data_and_model()

# Add info about subset size
st.info(f"Running with a subset of {SUBSET_SIZE} repositories for testing purposes.")

# Precompute embeddings for the subset (adds an 'embedding' column to `data`).
data = precompute_embeddings(data, model, tokenizer)

# Main App Interface
st.title("Repository Recommender System 🚀")
st.caption("Testing Version - Running on subset of data")

# Rest of your UI code remains the same...

# NOTE(review): the unchanged section below renders a second title, so the
# page shows two headers back-to-back — confirm whether the old one should go.
# Main App Interface
st.title("Enhanced Repository Recommender System 🚀")