frankjosh commited on
Commit
ca5a024
·
verified ·
1 Parent(s): 573ba9d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -33
app.py CHANGED
@@ -24,6 +24,7 @@ from datetime import datetime
24
  import json
25
  import torch.cuda
26
  import os
 
27
 
28
  # Configure GPU if available
29
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -35,34 +36,12 @@ if 'feedback' not in st.session_state:
35
  st.session_state.feedback = {}
36
 
37
 
38
- # Configuration
39
- DATASET_GDRIVE_ID = "1pPYlUEtIA3bi8iLVKqzF-37sHoaOhTZz" # Replace with your actual file ID
40
- LOCAL_DATA_DIR = "data"
41
- DATASET_FILENAME = "filtered_dataset.parquet"
42
-
43
- def download_from_gdrive():
44
- """
45
- Download dataset from Google Drive with proper error handling
46
- """
47
- os.makedirs(LOCAL_DATA_DIR, exist_ok=True)
48
- local_path = os.path.join(LOCAL_DATA_DIR, DATASET_FILENAME)
49
-
50
- if not os.path.exists(local_path):
51
- try:
52
- with st.spinner('Downloading dataset from Google Drive... This might take a few minutes...'):
53
- # Create direct download URL
54
- url = f'https://drive.google.com/uc?id={DATASET_GDRIVE_ID}'
55
- # Download file
56
- gdown.download(url, local_path, quiet=False)
57
- if os.path.exists(local_path):
58
- st.success("Dataset downloaded successfully!")
59
- else:
60
- st.error("Failed to download dataset")
61
- st.stop()
62
- except Exception as e:
63
- st.error(f"Error downloading dataset: {str(e)}")
64
- st.stop()
65
- return local_path
66
 
67
  # Step 1: Load Dataset and Precompute Embeddings
68
  @st.cache_resource
@@ -72,17 +51,20 @@ def load_data_and_model():
72
  """
73
  try:
74
  # Download and load dataset
75
- dataset_path = download_from_gdrive()
76
- data = pd.read_parquet(dataset_path)
77
  except Exception as e:
78
  st.error(f"Error loading dataset: {str(e)}")
79
  st.stop()
80
 
81
- # Combine text fields for embedding generation
82
- data['text'] = data['docstring'].fillna('') + ' ' + data['summary'].fillna('')
83
-
84
  # Load CodeT5-small model and tokenizer
85
  model_name = "Salesforce/codet5-small"
 
 
 
 
 
 
86
 
87
  @st.cache_resource
88
  def load_model_and_tokenizer():
 
24
  import json
25
  import torch.cuda
26
  import os
27
+ from datasets import load_dataset
28
 
29
  # Configure GPU if available
30
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
36
  st.session_state.feedback = {}
37
 
38
 
39
+ @st.cache_data
40
+ def generate_embedding(model, tokenizer, text):
41
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
42
+ with torch.no_grad():
43
+ outputs = model(**inputs)
44
+ return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  # Step 1: Load Dataset and Precompute Embeddings
47
  @st.cache_resource
 
51
  """
52
  try:
53
  # Download and load dataset
54
+ dataset = load_dataset("frankjosh/filtered_dataset")
55
+ data = pd.DataFrame(dataset['train'])
56
  except Exception as e:
57
  st.error(f"Error loading dataset: {str(e)}")
58
  st.stop()
59
 
 
 
 
60
  # Load CodeT5-small model and tokenizer
61
  model_name = "Salesforce/codet5-small"
62
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
63
+ model = AutoTokenizer.from_pretrained(model_name)
64
+
65
+ # Combine text fields for embedding generation
66
+ data['text'] = data['docstring'].fillna('') + ' ' + data['summary'].fillna('')
67
+ return data, tokenizer, model
68
 
69
  @st.cache_resource
70
  def load_model_and_tokenizer():