frankjosh committed on
Commit
5e9f512
·
verified ·
1 Parent(s): 5f5d654

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -19
app.py CHANGED
@@ -78,6 +78,8 @@ def load_data_and_model():
78
 
79
  tokenizer, model = load_model_and_tokenizer()
80
 
 
 
81
  @st.cache_data
82
  def generate_embedding(_model, _tokenizer, text):
83
  inputs = _tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
@@ -90,34 +92,26 @@ def generate_embedding(_model, _tokenizer, text):
90
  embedding = embedding.cpu()
91
  return embedding.numpy()
92
 
93
- #error handling
94
  try:
95
  query_embedding = generate_embedding(model, tokenizer, user_query)
96
  except Exception as e:
97
  st.error(f"Error generating embedding: {str(e)}")
98
  st.stop()
99
 
100
- # Precompute embeddings with GPU support
 
101
  @st.cache_data
102
- def generate_embedding(text):
103
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
104
- # Move inputs to GPU if available
105
- if torch.cuda.is_available():
106
- inputs = {k: v.to('cuda') for k, v in inputs.items()}
107
- with torch.no_grad():
108
- outputs = model.encoder(**inputs)
109
- # Move output back to CPU if needed
110
- embedding = outputs.last_hidden_state.mean(dim=1).squeeze()
111
- if torch.cuda.is_available():
112
- embedding = embedding.cpu()
113
- return embedding.numpy()
114
-
115
- # Generate embeddings with progress bar
116
  with st.spinner('Generating embeddings... This might take a few minutes on first run...'):
117
- data['embedding'] = data['text'].apply(lambda x: generate_embedding(x))
118
-
119
- return data, tokenizer, model
120
 
 
 
121
 
122
  def generate_case_study(repo_data):
123
  """
 
78
 
79
  tokenizer, model = load_model_and_tokenizer()
80
 
81
+
82
+ # Define the embedding generation function
83
  @st.cache_data
84
  def generate_embedding(_model, _tokenizer, text):
85
  inputs = _tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
 
92
  embedding = embedding.cpu()
93
  return embedding.numpy()
94
 
95
+ # Error handling for generating query embeddings
96
  try:
97
  query_embedding = generate_embedding(model, tokenizer, user_query)
98
  except Exception as e:
99
  st.error(f"Error generating embedding: {str(e)}")
100
  st.stop()
101
 
102
+ # Precompute embeddings for dataset
103
+ def precompute_embeddings(data, model, tokenizer):
104
  @st.cache_data
105
+ def generate_cached_embedding(text):
106
+ return generate_embedding(model, tokenizer, text)
107
+
108
+ # Apply embedding generation with progress bar
 
 
 
 
 
 
 
 
 
 
109
  with st.spinner('Generating embeddings... This might take a few minutes on first run...'):
110
+ data['embedding'] = data['text'].apply(lambda x: generate_cached_embedding(x))
111
+ return data
 
112
 
113
+ # Example usage:
114
+ # data = precompute_embeddings(data, model, tokenizer)
115
 
116
  def generate_case_study(repo_data):
117
  """