import warnings
warnings.filterwarnings('ignore')

import streamlit as st
import pandas as pd
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel
import torch
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset  # For loading the dataset
from typing import List
from functools import partial
# Configure GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize session state
if 'history' not in st.session_state:
    st.session_state.history = []
if 'feedback' not in st.session_state:
    st.session_state.feedback = {}

# Define subset size and batch size for optimization
SUBSET_SIZE = 500  # Smaller subset for faster precomputation
BATCH_SIZE = 8     # Smaller batch size to reduce memory overhead
# Cache key resources: model, tokenizer, and precomputed embeddings
@st.cache_resource
def load_model_and_tokenizer():
    """
    Load the pre-trained model and tokenizer using Hugging Face Transformers.
    Cached with st.cache_resource so they load only once per session.
    """
    model_name = "Salesforce/codet5-small"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    model.eval()
    return tokenizer, model
@st.cache_data
def load_data():
    """
    Load and sample the dataset from Hugging Face.
    Ensures the 'text' column is created for embedding precomputation.
    """
    dataset = load_dataset("frankjosh/filtered_dataset")
    data = pd.DataFrame(dataset['train'])
    # Take a random subset of the data
    data = data.sample(n=min(SUBSET_SIZE, len(data)), random_state=42).reset_index(drop=True)
    # Create a 'text' column by combining relevant fields
    data['text'] = data['docstring'].fillna('') + ' ' + data['summary'].fillna('')
    return data
@st.cache_data
def precompute_embeddings(data: pd.DataFrame, _tokenizer, _model, batch_size=BATCH_SIZE):
    """
    Precompute embeddings for repository metadata to optimize query performance.
    The tokenizer and model arguments carry a leading underscore so that
    st.cache_data skips hashing them (they are unhashable).
    """
    class TextDataset(Dataset):
        def __init__(self, texts: List[str], tokenizer, max_length=512):
            self.texts = texts
            self.tokenizer = tokenizer
            self.max_length = max_length

        def __len__(self):
            return len(self.texts)

        def __getitem__(self, idx):
            # Tokenize without padding; collate_fn pads each batch dynamically,
            # avoiding padding every example to the full 512 tokens.
            return self.tokenizer(
                self.texts[idx],
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt"
            )
    def collate_fn(batch, pad_token_id):
        # Dynamically pad every example to the longest sequence in this batch
        max_length = max(inputs['input_ids'].shape[1] for inputs in batch)
        input_ids, attention_mask = [], []
        for inputs in batch:
            input_ids.append(torch.nn.functional.pad(
                inputs['input_ids'].squeeze(),
                (0, max_length - inputs['input_ids'].shape[1]),
                value=pad_token_id
            ))
            attention_mask.append(torch.nn.functional.pad(
                inputs['attention_mask'].squeeze(),
                (0, max_length - inputs['attention_mask'].shape[1]),
                value=0
            ))
        return {
            'input_ids': torch.stack(input_ids),
            'attention_mask': torch.stack(attention_mask)
        }
    def generate_embeddings_batch(model, batch, device):
        with torch.no_grad():
            batch = {k: v.to(device) for k, v in batch.items()}
            # CodeT5 is an encoder-decoder model; only the encoder is needed
            # for embeddings
            outputs = model.encoder(**batch)
            return outputs.last_hidden_state.mean(dim=1).cpu().numpy()
    dataset = TextDataset(data['text'].tolist(), _tokenizer)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False,
        collate_fn=partial(collate_fn, pad_token_id=_tokenizer.pad_token_id)
    )

    embeddings = []
    for batch in dataloader:
        batch_embeddings = generate_embeddings_batch(_model, batch, device)
        embeddings.extend(batch_embeddings)

    data['embedding'] = embeddings
    return data
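# Note (optional refinement, not part of the original pipeline): mean(dim=1)
# above averages over padding positions as well as real tokens. A sketch of
# attention-mask-aware mean pooling is below; if adopted, it would replace the
# plain mean in both generate_embeddings_batch and generate_query_embedding
# (pass the attention mask alongside the encoder outputs).
def masked_mean_pool(last_hidden_state: torch.Tensor,
                     attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings, ignoring padded positions."""
    mask = attention_mask.unsqueeze(-1).float()     # (batch, seq_len, 1)
    summed = (last_hidden_state * mask).sum(dim=1)  # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1e-9)        # guard against empty masks
    return summed / counts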
def generate_query_embedding(model, tokenizer, query: str) -> np.ndarray:
    """
    Generate an embedding for a user query using the pre-trained model.
    """
    inputs = tokenizer(
        query, return_tensors="pt", padding=True,
        truncation=True, max_length=512
    ).to(device)
    # no_grad is required here: .numpy() raises on tensors that track gradients
    with torch.no_grad():
        outputs = model.encoder(**inputs)
    # Squeeze the batch dimension so the result is a 1-D vector
    return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
def find_similar_repos(query_embedding: np.ndarray, data: pd.DataFrame, top_n=5) -> pd.DataFrame:
    """
    Compute cosine similarity and return the top N most similar repositories.
    """
    similarities = cosine_similarity([query_embedding], np.stack(data['embedding'].values))[0]
    data['similarity'] = similarities
    return data.nlargest(top_n, 'similarity')
def display_recommendations(recommendations: pd.DataFrame):
    """
    Display the recommended repositories in the Streamlit app interface.
    """
    st.markdown("### 🎯 Top Recommendations")
    # Enumerate for the rank; after nlargest, the DataFrame index still holds
    # the original row positions, so it cannot serve as a 1..N rank.
    for rank, (_, row) in enumerate(recommendations.iterrows(), start=1):
        st.markdown(f"### {rank}. {row['repo']}")
        st.metric("Match Score", f"{row['similarity']:.2%}")
        st.markdown(f"[View Repository]({row['url']})")
# Main workflow
st.title("Repository Recommender System 🚀")
st.caption("Find repositories based on your project description.")

# Load resources
tokenizer, model = load_model_and_tokenizer()
data = load_data()
data = precompute_embeddings(data, tokenizer, model)

# User input
user_query = st.text_area(
    "Describe your project:", height=150,
    placeholder="Example: A machine learning project for customer churn prediction..."
)
if st.button("🔍 Search Repositories"):
    if user_query.strip():
        with st.spinner("Finding relevant repositories..."):
            query_embedding = generate_query_embedding(model, tokenizer, user_query)
            recommendations = find_similar_repos(query_embedding, data)
            display_recommendations(recommendations)
        # Record the query in the session history initialized above
        st.session_state.history.append(user_query)
    else:
        st.error("Please provide a project description.")
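
# --- Optional sketch: surfacing the session state -------------------------
# 'history' and 'feedback' are initialized at the top of the script but never
# displayed. Below is one possible way to use them; the widget labels, keys,
# and layout are illustrative assumptions, not part of the original design.
if st.session_state.history:
    with st.expander("Recent searches"):
        for past_query in st.session_state.history[-5:]:
            st.write(past_query)

rating = st.radio(
    "Were these recommendations helpful?",
    ["Yes", "No"], index=None, key="feedback_rating", horizontal=True
)
if rating is not None and user_query.strip():
    st.session_state.feedback[user_query] = rating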