# Spaces: Sleeping
# -*- coding: utf-8 -*-
"""app.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1deINvEblsMkv9h0gJzuGB4uSamW0DMX5
"""
#pip install streamlit transformers gdown torch pandas numpy
# (the code below also needs scikit-learn for cosine_similarity and pyarrow for read_parquet)
import json
import os
import warnings
from datetime import datetime
from pathlib import Path

warnings.filterwarnings('ignore')

import gdown
import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.cuda
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer
# Torch device for the app: CUDA when a GPU is visible, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Per-session state that survives Streamlit reruns:
#   history  - list of past searches (dicts with query / timestamp / results)
#   feedback - repo id -> {'likes': int, 'dislikes': int}
st.session_state.setdefault('history', [])
st.session_state.setdefault('feedback', {})

# Dataset location: Google Drive file ID and the local cache path it is saved to.
DATASET_GDRIVE_ID = "1pPYlUEtIA3bi8iLVKqzF-37sHoaOhTZz"  # Replace with your actual file ID
LOCAL_DATA_DIR = "data"
DATASET_FILENAME = "filtered_dataset.parquet"
def download_from_gdrive():
    """
    Make sure the dataset parquet file is present locally, fetching it once.

    The file is stored at ``LOCAL_DATA_DIR/DATASET_FILENAME``. If it already
    exists no network access happens; otherwise it is downloaded from Google
    Drive via ``gdown`` using ``DATASET_GDRIVE_ID``.

    Returns:
        str: Path of the local dataset file.

    On a failed download (an exception, or no file afterwards) an error is
    shown and the Streamlit script is halted with ``st.stop()``.
    """
    target = os.path.join(LOCAL_DATA_DIR, DATASET_FILENAME)
    os.makedirs(LOCAL_DATA_DIR, exist_ok=True)

    # Reuse a copy saved by an earlier run.
    if os.path.exists(target):
        return target

    try:
        with st.spinner('Downloading dataset from Google Drive... This might take a few minutes...'):
            drive_url = f'https://drive.google.com/uc?id={DATASET_GDRIVE_ID}'
            gdown.download(drive_url, target, quiet=False)
            # gdown reports failures by not producing the file, so check for it.
            if not os.path.exists(target):
                st.error("Failed to download dataset")
                st.stop()
            else:
                st.success("Dataset downloaded successfully!")
    except Exception as exc:
        st.error(f"Error downloading dataset: {str(exc)}")
        st.stop()
    return target
# Step 1: Load Dataset and Precompute Embeddings
def generate_embedding(model, tokenizer, text):
    """
    Embed ``text`` by mean-pooling the encoder's last hidden state.

    Defined at module level so the search code can call it as
    ``generate_embedding(model, tokenizer, query)``.

    Args:
        model: Model exposing an ``encoder`` sub-module (CodeT5 here).
        tokenizer: Tokenizer matching ``model``.
        text: Input string; truncated to 512 tokens.

    Returns:
        numpy.ndarray: 1-D embedding vector on the CPU.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    # Send the inputs to whichever device holds the model's weights.
    model_device = next(model.parameters()).device
    inputs = {name: tensor.to(model_device) for name, tensor in inputs.items()}
    with torch.no_grad():
        outputs = model.encoder(**inputs)
    # Average over the token axis, then drop the batch axis (size 1 for one string).
    embedding = outputs.last_hidden_state.mean(dim=1).squeeze(0)
    return embedding.cpu().numpy()


def _load_model_and_tokenizer(model_name):
    """Load ``model_name``'s tokenizer and model (on GPU if available, eval mode); st.stop() on failure."""
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name)
        if torch.cuda.is_available():
            model = model.to('cuda')
        model.eval()  # inference only: disables dropout
        return tokenizer, model
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.stop()


@st.cache_resource(show_spinner=False)
def load_data_and_model():
    """
    Load the dataset and the CodeT5-small model/tokenizer, then precompute embeddings.

    Cached with ``st.cache_resource``. The download, model load and per-row
    embedding pass therefore run once per server process, not on every
    Streamlit rerun. The cached objects are shared by all sessions, so callers
    should treat the returned DataFrame as read-only.

    Returns:
        tuple: ``(data, tokenizer, model)``. ``data`` gains a ``text`` column
        (docstring + summary) and an ``embedding`` column (numpy vectors).

    Halts the script with ``st.stop()`` after reporting an error if the
    dataset or model cannot be loaded.
    """
    try:
        dataset_path = download_from_gdrive()
        data = pd.read_parquet(dataset_path)
    except Exception as e:
        st.error(f"Error loading dataset: {str(e)}")
        st.stop()

    # Fail with a readable message rather than a bare KeyError below.
    missing_columns = {'docstring', 'summary'} - set(data.columns)
    if missing_columns:
        st.error(f"Dataset is missing required columns: {', '.join(sorted(missing_columns))}")
        st.stop()

    # Combine text fields for embedding generation
    data['text'] = data['docstring'].fillna('') + ' ' + data['summary'].fillna('')

    tokenizer, model = _load_model_and_tokenizer("Salesforce/codet5-small")

    with st.spinner('Generating embeddings... This might take a few minutes on first run...'):
        data['embedding'] = data['text'].apply(lambda text: generate_embedding(model, tokenizer, text))
    return data, tokenizer, model
def _as_text(value):
    """Return ``value`` as a string, mapping None/NaN/NA (missing parquet cells) to ''."""
    if isinstance(value, str):
        return value
    if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
        return ""
    return str(value)


def generate_case_study(repo_data, max_field_chars=50):
    """
    Generate a concise markdown case-study brief for one repository.

    Args:
        repo_data: Mapping-like row (e.g. a pandas Series from ``iterrows``)
            with ``summary``, ``docstring``, ``path`` and ``text`` entries.
            Missing (None/NaN) values are treated as empty strings.
        max_field_chars: Maximum characters quoted from the summary and the
            docstring. Individual fields are shortened, rather than the whole
            brief, so every section is shown.

    Returns:
        str: Markdown with one paragraph per section.
    """
    summary = _as_text(repo_data['summary'])
    docstring = _as_text(repo_data['docstring'])
    path = _as_text(repo_data['path'])
    combined_text = _as_text(repo_data['text']).lower()

    file_name = path.split('/')[-1]
    # First five words, joined back into prose (not a list repr).
    application_hint = ' '.join(summary.split()[:5])
    complexity = 'Medium' if len(docstring) > 500 else 'Low'
    integration = 'High' if 'api' in combined_text or 'interface' in combined_text else 'Medium'

    sections = [
        f"**Project Overview**: {summary[:max_field_chars]}...",
        "**Key Features**:\n"
        f"- Repository contains production-ready {file_name} implementation\n"
        f"- {docstring[:max_field_chars]}...",
        f"**Potential Applications**: This repository can be utilized for projects requiring {application_hint}...",
        f"**Implementation Complexity**: {complexity}",
        f"**Integration Potential**: {integration}",
    ]
    # Blank lines so markdown renders each section as its own paragraph.
    return "\n\n".join(sections)
def save_feedback(repo_id, feedback_type):
    """
    Record one piece of user feedback for a repository in session state.

    Args:
        repo_id: Key identifying the repository in ``st.session_state.feedback``.
        feedback_type: Counter to increment, ``'likes'`` or ``'dislikes'``.
    """
    # Create a zeroed tally the first time this repository receives feedback.
    tally = st.session_state.feedback.setdefault(repo_id, {'likes': 0, 'dislikes': 0})
    tally[feedback_type] += 1
# Main App
# Title emoji was mis-encoded in the source; restored as a rocket.
st.title("Enhanced Repository Recommender System 🚀")
# Sidebar for History and Stats
with st.sidebar:
    st.header("📜 Search History")
    history = st.session_state.history
    if history:
        recent = history[-5:]  # Show last 5 searches
        # 1-based position (in the full history) of the oldest entry shown,
        # so each label matches the search's real number.
        first_number = len(history) - len(recent) + 1
        # Newest search first.
        for number, item in reversed(list(enumerate(recent, start=first_number))):
            with st.expander(f"Search {number}: {item['query'][:30]}..."):
                st.write(f"Time: {item['timestamp']}")
                st.write(f"Results: {len(item['results'])} repositories")
                if st.button("Rerun this search", key=f"rerun_{number}"):
                    st.session_state.rerun_query = item['query']
    else:
        st.write("No search history yet")

    st.header("📊 Usage Statistics")
    st.write(f"Total Searches: {len(history)}")
    feedback = st.session_state.feedback
    if feedback:
        total_likes = sum(counts['likes'] for counts in feedback.values())
        total_dislikes = sum(counts['dislikes'] for counts in feedback.values())
        st.write(f"Total Likes: {total_likes}")
        st.write(f"Total Dislikes: {total_dislikes}")
# Load the dataset, tokenizer and model once at startup of each script run.
def initialize_resources():
    """Return the ``(data, tokenizer, model)`` triple produced by ``load_data_and_model``."""
    frame, tok, encoder_model = load_data_and_model()
    return frame, tok, encoder_model


data, tokenizer, model = initialize_resources()
# Main interface: free-text project description.
user_query = st.text_area(
    "Describe your project:",
    height=150,
    placeholder="Example: I need a machine learning project for customer churn prediction...",
)

# Search button (wide column) and result-count selector (narrow column).
col1, col2 = st.columns([2, 1])
with col1:
    # Button emoji was mis-encoded in the source; restored as a magnifier.
    search_button = st.button("🔍 Search Repositories", type="primary")
with col2:
    top_n = st.selectbox("Number of results:", [3, 5, 10], index=1)
# Decide whether this run performs a search: either the search button was
# clicked with a non-empty query, or a sidebar "Rerun this search" button
# stored a previous query in session state (consumed here, exactly once).
active_query = st.session_state.pop('rerun_query', None)
if search_button and user_query:
    active_query = user_query

if active_query:
    if data.empty:
        st.warning("The dataset contains no repositories to search.")
    else:
        with st.spinner("Finding relevant repositories..."):
            query_embedding = generate_embedding(model, tokenizer, active_query)
            # Score every repository in one vectorised call instead of one
            # cosine_similarity call per row.
            embedding_matrix = np.vstack(data['embedding'].to_numpy())
            scores = pd.Series(
                cosine_similarity(query_embedding.reshape(1, -1), embedding_matrix)[0],
                index=data.index,
            )
            top_scores = scores.nlargest(top_n)
            # Build a new frame rather than writing a column onto `data`.
            recommendations = data.loc[top_scores.index].assign(similarity=top_scores)

        # Save to history
        st.session_state.history.append({
            'query': active_query,
            'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            'results': recommendations['repo'].tolist(),
        })
        # Persist the results: a feedback click reruns the script with
        # search_button == False, and the results (and the click) would
        # otherwise be lost.
        st.session_state.last_recommendations = recommendations

# Display the most recent recommendations (if any)
recommendations = st.session_state.get('last_recommendations')
if recommendations is not None and not recommendations.empty:
    st.markdown("### 🎯 Top Recommendations")
    for rank, (row_label, row) in enumerate(recommendations.iterrows(), start=1):
        with st.expander(f"Repository {rank}: {row['repo']}", expanded=True):
            # Repository details
            details_col, score_col = st.columns([2, 1])
            with details_col:
                st.markdown(f"**URL:** [View Repository]({row['url']})")
                st.markdown(f"**Path:** `{row['path']}`")
            with score_col:
                st.metric("Match Score", f"{row['similarity']:.2%}")

            # Feedback buttons; keyed by dataset row label so keys stay unique.
            like_col, dislike_col = st.columns(2)
            with like_col:
                if st.button("👍", key=f"like_{row_label}"):
                    save_feedback(row['repo'], 'likes')
                    st.success("Thanks for your feedback!")
            with dislike_col:
                if st.button("👎", key=f"dislike_{row_label}"):
                    save_feedback(row['repo'], 'dislikes')
                    st.success("Thanks for your feedback!")

            # Missing docstrings come back from parquet as None/NaN, and NaN is truthy.
            docstring = row['docstring']
            has_docs = isinstance(docstring, str) and bool(docstring.strip())

            # Tabs rather than nested expanders: Streamlit disallows expanders
            # inside other expanders.
            tab_labels = ["📋 Case Study Brief"] + (["📚 Documentation"] if has_docs else [])
            tabs = st.tabs(tab_labels)
            with tabs[0]:
                st.markdown(generate_case_study(row))
            if has_docs:
                with tabs[1]:
                    st.markdown(docstring)
# Footer
st.markdown("---")
# Computed here and interpolated with an f-string so the status is actually
# evaluated rather than shown as literal source text.
gpu_status = "🟢 Enabled" if torch.cuda.is_available() else "🔴 Disabled"
st.markdown(
    "Made with 🤗 using CodeT5 and Streamlit | "
    f"GPU Status: {gpu_status} | "
    "Model: CodeT5-Small"
)