# Spaces status: Sleeping
# -*- coding: utf-8 -*- | |
"""app.ipynb | |
Automatically generated by Colab. | |
Original file is located at | |
https://colab.research.google.com/drive/1deINvEblsMkv9h0gJzuGB4uSamW0DMX5 | |
""" | |
# Install dependencies first (in Colab, prefix the command with "!"):
#   pip install streamlit transformers gdown torch pandas numpy scikit-learn
import json
import warnings
from datetime import datetime
from pathlib import Path

# Silence library warnings before the third-party imports below can emit any.
warnings.filterwarnings('ignore')

import gdown
import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.cuda
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer
# Choose the compute device once: CUDA when a GPU is visible, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Create the per-session containers only on a session's first run, so that
# Streamlit's reruns keep the accumulated search history and feedback tallies.
for _state_key, _empty in (("history", list), ("feedback", dict)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _empty()
# Step 1: Model loading (cached across Streamlit reruns)
@st.cache_resource
def load_model_and_tokenizer():
    """
    Load the CodeT5-small tokenizer and model, ready for inference.

    Streamlit re-executes the whole script on every widget interaction, so the
    loader is wrapped in ``st.cache_resource``. The weights are loaded once and
    the same objects are returned on later calls instead of being re-read.

    On CUDA the weights are loaded in float16 (half precision, not
    quantization). On CPU they stay float32.

    Returns:
        tuple: ``(tokenizer, model)``, with the model moved to ``device`` and
        switched to eval mode.
    """
    model_name = "Salesforce/codet5-small"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # NOTE(review): recent transformers releases require the `accelerate`
    # package when low_cpu_mem_usage=True — confirm it is installed.
    model = AutoModel.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        low_cpu_mem_usage=True,
    )
    model = model.to(device)
    # Inference mode: disables dropout.
    model.eval()
    return tokenizer, model
# Step 2: Dataset loading (cached across Streamlit reruns)
@st.cache_data
def load_data():
    """
    Load the repository dataset and add a combined ``text`` column.

    When the Colab Google Drive folder ``practice_ml`` is mounted, the dataset
    is read from (and, if missing, downloaded into) that folder. Otherwise it
    is kept under the local ``data/`` directory. ``st.cache_data`` keeps the
    loaded frame between reruns and gives each caller its own copy.

    Returns:
        pandas.DataFrame: the parquet contents plus ``text`` =
        docstring + ' ' + summary, with missing values treated as ''.

    Raises:
        FileNotFoundError: if the file is still missing after the download.
    """
    colab_path = Path("/content/drive/MyDrive/practice_ml/filtered_dataset.parquet")
    local_path = Path("data") / "filtered_dataset.parquet"
    dataset_path = colab_path if colab_path.parent.is_dir() else local_path
    if not dataset_path.exists():
        # Make sure the target directory exists before downloading into it.
        dataset_path.parent.mkdir(parents=True, exist_ok=True)
        with st.spinner('Downloading dataset... This might take a few minutes...'):
            # NOTE(review): this is a Drive *folder* link, while gdown.download
            # fetches a single file — confirm the intended file URL (or use
            # gdown.download_folder).
            url = "https://drive.google.com/drive/folders/1dphd3vDKV46GwWKW5uo-MBl0GWGyCWUs?usp=drive_link"
            gdown.download(url, str(dataset_path), quiet=False)
        if not dataset_path.exists():
            raise FileNotFoundError(
                f"Dataset not found at {dataset_path}; the download did not produce it."
            )
    data = pd.read_parquet(dataset_path)
    data['text'] = data['docstring'].fillna('') + ' ' + data['summary'].fillna('')
    return data
# Step 3: Embedding generation
def generate_embedding(_model, tokenizer, text):
    """
    Embed text with the model's encoder using attention-masked mean pooling.

    Args:
        _model: Model exposing an ``encoder`` sub-module (in this app, the
            CodeT5 model returned by ``load_model_and_tokenizer``). The
            leading underscore is Streamlit's marker for arguments its cache
            decorators should not hash.
        tokenizer: Tokenizer matching ``_model``.
        text: A string, or a list of strings embedded as one padded batch.

    Returns:
        numpy.ndarray of float32. For a single string, the batch axis is
        squeezed away, leaving a 1-D vector of the encoder's hidden size.
        For several strings, the shape is (batch, hidden).
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(device)
    with torch.no_grad():
        outputs = _model.encoder(**inputs)
    # Pool in float32 even when the model runs in float16 on GPU.
    hidden = outputs.last_hidden_state.float()
    # With padding=True, a batch contains pad positions. Weighting them out
    # means each text is averaged over its own tokens only. For a single
    # string the mask is all ones, so this equals a plain mean over tokens.
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
    return pooled.squeeze().cpu().numpy()
def _as_text(value):
    """Return ``value`` as a string, treating None / NaN / pd.NA as ''."""
    if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
        return ''
    return str(value)


def generate_case_study(repo_data, max_chars=None):
    """
    Generate a concise Markdown case-study brief from one repository record.

    Args:
        repo_data: Mapping-like record (a dict or a pandas row) providing the
            keys 'summary', 'docstring', 'path' and 'text'. Missing values
            (None/NaN) are treated as empty strings.
        max_chars: Optional length cap. When set, only whole sections that
            fit are kept and a final "..." marks the cut, so Markdown markup
            such as ``**bold**`` is never split. ``None`` returns the full brief.

    Returns:
        str: The brief, with sections separated by blank lines.
    """
    summary = _as_text(repo_data['summary'])
    docstring = _as_text(repo_data['docstring'])
    file_name = _as_text(repo_data['path']).split('/')[-1]
    searchable = _as_text(repo_data['text']).lower()

    # First five words of the summary as prose, not a list repr like "['a', 'b']".
    focus = ' '.join(summary.split()[:5])
    complexity = 'Medium' if len(docstring) > 500 else 'Low'
    integration = 'High' if 'api' in searchable or 'interface' in searchable else 'Medium'

    sections = [
        f"**Project Overview**: {summary[:50]}...",
        "**Key Features**:\n"
        f"- Repository contains production-ready {file_name} implementation\n"
        f"- {docstring[:50]}...",
        f"**Potential Applications**: This repository can be utilized for projects requiring {focus}...",
        f"**Implementation Complexity**: {complexity}",
        f"**Integration Potential**: {integration}",
    ]
    if max_chars is None:
        return "\n\n".join(sections)

    # Keep whole sections only, counting the "\n\n" separators between them.
    kept = []
    used = 0
    for section in sections:
        needed = len(section) + (2 if kept else 0)
        if used + needed > max_chars:
            break
        kept.append(section)
        used += needed
    if len(kept) == len(sections):
        return "\n\n".join(kept)
    return "\n\n".join(kept + ["..."])
def save_feedback(repo_id, feedback_type):
    """
    Add one vote for a repository to this session's feedback tally.

    Args:
        repo_id: Key identifying the repository (the app passes the repo name).
        feedback_type: Which counter to bump: 'likes' or 'dislikes'.
    """
    # First vote for this repository starts from a zeroed tally.
    tally = st.session_state.feedback.setdefault(repo_id, {'likes': 0, 'dislikes': 0})
    tally[feedback_type] += 1
# Main App
st.title("Enhanced Repository Recommender System π")
# Sidebar: recent searches (each with a rerun button) and usage statistics
with st.sidebar:
    st.header("π Search History")
    if st.session_state.history:
        recent_searches = st.session_state.history[-5:]  # Show last 5 searches
        # The slice runs oldest-to-newest, so number entries by their real
        # position in the full history (counting up), not down from the total.
        first_number = len(st.session_state.history) - len(recent_searches) + 1
        for idx, item in enumerate(recent_searches):
            with st.expander(f"Search {first_number + idx}: {item['query'][:30]}..."):
                st.write(f"Time: {item['timestamp']}")
                st.write(f"Results: {len(item['results'])} repositories")
                if st.button("Rerun this search", key=f"rerun_{idx}"):
                    # Stash the query in session state under 'rerun_query'.
                    st.session_state.rerun_query = item['query']
    else:
        st.write("No search history yet")
    st.header("π Usage Statistics")
    st.write(f"Total Searches: {len(st.session_state.history)}")
    if st.session_state.feedback:
        total_likes = sum(f['likes'] for f in st.session_state.feedback.values())
        total_dislikes = sum(f['dislikes'] for f in st.session_state.feedback.values())
        st.write(f"Total Likes: {total_likes}")
        st.write(f"Total Dislikes: {total_dislikes}")
# Load resources
def initialize_resources():
    """
    Load everything the recommender needs.

    Returns:
        tuple: ``(data, tokenizer, model)``, where the dataset comes from
        ``load_data`` (loaded first) and the pair comes from
        ``load_model_and_tokenizer``.
    """
    dataset = load_data()
    tok, mdl = load_model_and_tokenizer()
    return dataset, tok, mdl
data, tokenizer, model = initialize_resources()

# Main interface: free-text description of the project the user has in mind
user_query = st.text_area(
    "Describe your project:",
    placeholder="Example: I need a machine learning project for customer churn prediction...",
    height=150,
)

# Search trigger in the wide column, result-count picker in the narrow one
col1, col2 = st.columns([2, 1])
with col1:
    search_button = st.button("π Search Repositories", type="primary")
with col2:
    top_n = st.selectbox("Number of results:", options=[3, 5, 10], index=1)
# Pick the query to run on this pass. A "Rerun this search" click in the
# sidebar leaves its query in st.session_state.rerun_query. pop() consumes it
# so it fires only once, and it takes precedence over the search button.
pending_query = st.session_state.pop('rerun_query', None)
if pending_query is None and search_button:
    if user_query:
        pending_query = user_query
    else:
        st.warning("Please describe your project before searching.")

if pending_query:
    with st.spinner("Finding relevant repositories..."):
        # Embed the query, then score every repository in a single
        # vectorised cosine_similarity call instead of one call per row.
        query_embedding = generate_embedding(model, tokenizer, pending_query)
        embedding_matrix = np.vstack(data['embedding'].to_numpy())
        scores = cosine_similarity(np.reshape(query_embedding, (1, -1)), embedding_matrix)[0]
        # assign() returns a new frame, leaving the loaded dataset untouched.
        recommendations = data.assign(similarity=scores).nlargest(top_n, 'similarity')
    # Save to history
    st.session_state.history.append({
        'query': pending_query,
        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        'results': recommendations['repo'].tolist(),
    })
    # Persist the results. A feedback click reruns the script with
    # search_button False, and a button only reports its click if it is
    # rendered again on that rerun.
    st.session_state.last_recommendations = recommendations

# Display the latest recommendations, if any search has run in this session.
recommendations = st.session_state.get('last_recommendations')
if recommendations is not None:
    st.markdown("### π― Top Recommendations")
    # Number entries by rank. idx is the dataset index label, used only for widget keys.
    for rank, (idx, row) in enumerate(recommendations.iterrows(), start=1):
        with st.expander(f"Repository {rank}: {row['repo']}", expanded=True):
            # Repository details
            col1, col2 = st.columns([2, 1])
            with col1:
                st.markdown(f"**URL:** [View Repository]({row['url']})")
                st.markdown(f"**Path:** `{row['path']}`")
            with col2:
                st.metric("Match Score", f"{row['similarity']:.2%}")
            # Feedback buttons
            feedback_col1, feedback_col2 = st.columns(2)
            with feedback_col1:
                if st.button("π", key=f"like_{idx}"):
                    save_feedback(row['repo'], 'likes')
                    st.success("Thanks for your feedback!")
            with feedback_col2:
                if st.button("π", key=f"dislike_{idx}"):
                    save_feedback(row['repo'], 'dislikes')
                    st.success("Thanks for your feedback!")
            # Streamlit does not allow an expander inside another expander, so
            # the case study and documentation are tabs within this expander.
            docstring = row['docstring']
            # NaN/None docstrings must not render as "nan" or crash markdown.
            has_docs = isinstance(docstring, str) and docstring.strip() != ''
            tab_labels = ["π Case Study Brief"] + (["π Documentation"] if has_docs else [])
            tabs = st.tabs(tab_labels)
            with tabs[0]:
                st.markdown(generate_case_study(row))
            if has_docs:
                with tabs[1]:
                    st.markdown(docstring)
# Footer
st.markdown("---")
# Evaluated at runtime: the status below is interpolated with an f-string
# (a plain string would render the conditional expression's braces verbatim).
gpu_status = '🟢 Enabled' if torch.cuda.is_available() else '🔴 Disabled'
st.markdown(
    f"""
Made with π€ using CodeT5 and Streamlit |
GPU Status: {gpu_status} |
Model: CodeT5-Small
"""
)