Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*- | |
"""app.py | |
Enhanced Repository Recommender System using Streamlit and CodeT5-small | |
""" | |
import warnings | |
warnings.filterwarnings('ignore') | |
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
from sklearn.metrics.pairwise import cosine_similarity | |
from transformers import AutoTokenizer, AutoModel | |
import torch | |
from tqdm import tqdm | |
from datasets import load_dataset | |
from datetime import datetime | |
# Choose where tensors run: CUDA when PyTorch can see a GPU, else the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Streamlit re-executes this script on every interaction; create the
# per-session containers only on the first run so they survive reruns.
st.session_state.setdefault('history', [])
st.session_state.setdefault('feedback', {})
# Step 1: Load the dataset and the CodeT5-small model
def load_data_and_model():
    """
    Load the repository dataset and the CodeT5-small tokenizer/model.

    Embeddings are not computed here; ``precompute_embeddings`` does that.

    Returns:
        tuple: ``(data, tokenizer, model)``.
            - ``data`` is a DataFrame whose ``docstring``/``summary`` columns
              contain no NaN and which has a combined ``text`` column.
            - ``model`` is on ``device`` and in evaluation mode.

    On any failure an error is shown in the app and the run is halted with
    ``st.stop()``.
    """
    try:
        dataset = load_dataset("frankjosh/filtered_dataset")
        data = pd.DataFrame(dataset['train'])
    except Exception as e:
        st.error(f"Error loading dataset: {str(e)}")
        st.stop()

    # Every column the UI reads later. Fail fast here rather than with a
    # KeyError in the middle of rendering search results.
    required_columns = ['docstring', 'summary', 'repo', 'path', 'url']
    missing = [col for col in required_columns if col not in data.columns]
    if missing:
        st.error(f"Missing required column: {', '.join(missing)}")
        st.stop()

    # Downstream code slices, measures and renders these fields as strings,
    # so NaN must not survive in the source columns either.
    for col in ('docstring', 'summary'):
        data[col] = data[col].fillna('').astype(str)
    # Combined text is what gets embedded.
    data['text'] = data['docstring'] + ' ' + data['summary']

    model_name = "Salesforce/codet5-small"
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name)
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.stop()

    # Inference only: move to the selected device and disable dropout.
    model = model.to(device)
    model.eval()
    return data, tokenizer, model
# Define the embedding generation function
def generate_embedding(_model, _tokenizer, text):
    """
    Embed ``text`` with the model's encoder using attention-masked mean pooling.

    Args:
        _model: Model exposing ``.encoder`` and ``.device`` (for example the
            CodeT5 model returned by ``load_data_and_model``).
        _tokenizer: Matching tokenizer. Inputs are truncated to 512 tokens.
        text: A single string, or a list of strings for a batch.

    Returns:
        numpy.ndarray: The pooled encoder states with size-1 dimensions
        squeezed, i.e. shape ``(hidden,)`` for a single string.
    """
    encoded = _tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    # Follow the model's actual device instead of inferring it from whether
    # CUDA happens to be available.
    inputs = {name: tensor.to(_model.device) for name, tensor in encoded.items()}
    with torch.no_grad():
        hidden = _model.encoder(**inputs).last_hidden_state
    mask = inputs.get("attention_mask")
    if mask is None:
        pooled = hidden.mean(dim=1)
    else:
        # Average over real tokens only. Padded positions (mask == 0) would
        # otherwise dilute the embedding of shorter texts in a batch.
        weights = mask.unsqueeze(-1).to(hidden.dtype)
        pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)
    return pooled.squeeze().cpu().numpy()
# Precompute embeddings for dataset
def precompute_embeddings(data, model, tokenizer):
    """
    Embed every entry of ``data['text']`` and store the vectors in place.

    Args:
        data: DataFrame with a ``text`` column. It is modified in place.
        model: Model passed through to ``generate_embedding``.
        tokenizer: Tokenizer passed through to ``generate_embedding``.

    Returns:
        The same DataFrame, now carrying an ``embedding`` column with one
        vector per row (a tqdm bar reports progress on the console).
    """
    data['embedding'] = [
        generate_embedding(model, tokenizer, snippet)
        for snippet in tqdm(data['text'], desc="Generating embeddings")
    ]
    return data
# Generate a concise case study brief from repository data
def _text_field(repo_data, key):
    """Return ``repo_data[key]`` as a string; missing, ``None`` and NaN become ''."""
    value = repo_data.get(key, '')
    if isinstance(value, str):
        return value
    if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
        return ''
    return str(value)


def _clip(text, limit):
    """Trim ``text`` to at most ``limit`` characters, adding '...' only if it was cut."""
    text = text.strip()
    if len(text) <= limit:
        return text
    return text[:limit].rstrip() + '...'


def generate_case_study(repo_data, snippet_len=50):
    """
    Build a short Markdown brief describing one recommended repository.

    Args:
        repo_data: A row (``pd.Series`` or mapping) with ``summary``,
            ``docstring``, ``path`` and optionally ``text`` fields.
            Missing, ``None`` or NaN fields are treated as empty strings.
        snippet_len: Maximum number of characters quoted from the summary
            and the docstring.

    Returns:
        str: Markdown with overview, key features, potential applications,
        implementation complexity and integration potential sections.
    """
    import re

    summary = _text_field(repo_data, 'summary')
    docstring = _text_field(repo_data, 'docstring')
    path = _text_field(repo_data, 'path')
    text = _text_field(repo_data, 'text') or f"{docstring} {summary}"

    file_name = path.split('/')[-1]
    summary_words = summary.split()
    use_case = ' '.join(summary_words[:5]) + ('...' if len(summary_words) > 5 else '')
    complexity = 'Medium' if len(docstring) > 500 else 'Low'
    # Word-boundary match so that e.g. "rapid" or "capital" do not count as "api".
    mentions_interface = re.search(r'\b(?:api|interface)', text, re.IGNORECASE)
    integration = 'High' if mentions_interface else 'Medium'

    lines = [
        f"**Project Overview**: {_clip(summary, snippet_len)}",
        "",
        "**Key Features**:",
        f"- Repository contains production-ready {file_name} implementation",
    ]
    if docstring.strip():
        lines.append(f"- {_clip(docstring, snippet_len)}")
    lines += [
        "",
        f"**Potential Applications**: This repository can be utilized for projects requiring {use_case}",
        "",
        f"**Implementation Complexity**: {complexity}",
        "",
        f"**Integration Potential**: {integration}",
    ]
    # Fields are clipped individually; the brief as a whole is never cut, so
    # every section survives and the Markdown stays well-formed.
    return "\n".join(lines)
# Save user feedback for a repository
def save_feedback(repo_id, feedback_type):
    """
    Bump this session's ``feedback_type`` counter for ``repo_id``.

    ``feedback_type`` is expected to be ``'likes'`` or ``'dislikes'``. A zeroed
    counter pair is created the first time a repository receives feedback.
    """
    counters = st.session_state.feedback.setdefault(repo_id, {'likes': 0, 'dislikes': 0})
    counters[feedback_type] += 1
# Load resources
@st.cache_resource(show_spinner="Loading dataset, model and embeddings...")
def _load_resources():
    """
    Load the dataset, tokenizer and model, then embed every repository.

    Streamlit re-executes the whole script on each interaction. Caching here
    means the downloads and the full embedding pass run once per server
    process (until the cache is cleared) instead of on every click.
    The cached objects are shared by all sessions, so treat them as read-only.

    Returns:
        tuple: ``(data, tokenizer, model)``, with an ``embedding`` column on ``data``.
    """
    corpus, tok, encoder_model = load_data_and_model()
    corpus = precompute_embeddings(corpus, encoder_model, tok)
    return corpus, tok, encoder_model


data, tokenizer, model = _load_resources()
# Main App Interface
st.title("Enhanced Repository Recommender System π")

# Sidebar for History and Stats
with st.sidebar:
    st.header("π Search History")
    search_history = st.session_state.history
    if search_history:
        # Show the five most recent searches, newest first. Each is numbered by
        # its 1-based position in the full history, so the label always
        # matches the search it describes.
        recent_searches = list(enumerate(search_history, start=1))[-5:]
        for number, item in reversed(recent_searches):
            with st.expander(f"Search {number}: {item['query'][:30]}..."):
                st.write(f"Time: {item['timestamp']}")
                st.write(f"Results: {len(item['results'])} repositories")
                if st.button("Rerun this search", key=f"rerun_{number}"):
                    st.session_state.rerun_query = item['query']
    else:
        st.write("No search history yet")

    st.header("π Usage Statistics")
    st.write(f"Total Searches: {len(search_history)}")
    if st.session_state.feedback:
        # One row per repository, with 'likes' and 'dislikes' columns.
        feedback_df = pd.DataFrame(st.session_state.feedback).T
        st.bar_chart(feedback_df[['likes', 'dislikes']])
# Main interface
user_query = st.text_area(
    "Describe your project:",
    height=150,
    placeholder="Example: I need a machine learning project for customer churn prediction..."
)

# Search button and filters
col1, col2 = st.columns([2, 1])
with col1:
    search_button = st.button("π Search Repositories", type="primary")
with col2:
    top_n = st.selectbox("Number of results:", [3, 5, 10], index=1)

# A sidebar "Rerun this search" click stores the old query in session state.
# Consume it exactly once and treat it as a search request, so the click
# actually re-runs that query.
rerun_query = st.session_state.pop("rerun_query", None)
if rerun_query:
    user_query, search_button = rerun_query, True
# Run a search: rank every repository by cosine similarity to the query.
if search_button and user_query.strip():
    if data.empty:
        st.warning("No repositories are loaded, so there is nothing to search.")
    else:
        with st.spinner("Finding relevant repositories..."):
            query_embedding = generate_embedding(model, tokenizer, user_query)
            # Score the whole corpus with one vectorised call. The scores stay
            # in a local Series rather than a 'similarity' column on `data`,
            # so the corpus DataFrame is never modified by a search.
            corpus_matrix = np.vstack(data['embedding'].to_numpy())
            scores = pd.Series(
                cosine_similarity(query_embedding.reshape(1, -1), corpus_matrix)[0],
                index=data.index,
            )
            top_scores = scores.nlargest(top_n)
            recommendations = data.loc[top_scores.index].assign(similarity=top_scores)
            # Save to history
            st.session_state.history.append({
                'query': user_query,
                'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                'results': recommendations['repo'].tolist()
            })
            # Persist the results. A click on a feedback button reruns the
            # script with the search button no longer pressed; without this,
            # the results (and the button's click) would be lost.
            st.session_state.last_recommendations = recommendations

# Display the most recent recommendations, if any.
latest_recommendations = st.session_state.get('last_recommendations')
if latest_recommendations is not None:
    st.markdown("### π¯ Top Recommendations")
    for rank, (idx, row) in enumerate(latest_recommendations.iterrows(), start=1):
        with st.expander(f"Repository {rank}: {row['repo']}", expanded=True):
            # Repository details
            col1, col2 = st.columns([2, 1])
            with col1:
                st.markdown(f"**URL:** [View Repository]({row['url']})")
                st.markdown(f"**Path:** `{row['path']}`")
            with col2:
                st.metric("Match Score", f"{row['similarity']:.2%}")
            # Feedback buttons (keyed by dataset index so they are unique)
            feedback_col1, feedback_col2 = st.columns(2)
            with feedback_col1:
                if st.button("π", key=f"like_{idx}"):
                    save_feedback(row['repo'], 'likes')
                    st.success("Thanks for your feedback!")
            with feedback_col2:
                if st.button("π", key=f"dislike_{idx}"):
                    save_feedback(row['repo'], 'dislikes')
                    st.success("Thanks for your feedback!")
            # Streamlit does not allow an expander inside another expander,
            # so the brief and the documentation are shown as tabs.
            docstring = row['docstring']
            has_docs = isinstance(docstring, str) and bool(docstring.strip())
            tab_labels = ["π Case Study Brief"] + (["π Documentation"] if has_docs else [])
            tabs = st.tabs(tab_labels)
            with tabs[0]:
                st.markdown(generate_case_study(row))
            if has_docs:
                with tabs[1]:
                    st.markdown(docstring)
# Footer
st.markdown("---")
# Evaluated at render time; the original plain string printed this expression literally.
gpu_status = 'π’ Enabled' if torch.cuda.is_available() else 'π΄ Disabled'
st.markdown(
    f"Made with π€ using CodeT5 and Streamlit | GPU Status: {gpu_status} | Model: CodeT5-Small"
)