Update app.py
app.py CHANGED
@@ -78,6 +78,8 @@ def load_data_and_model():
 
     tokenizer, model = load_model_and_tokenizer()
 
+
+    # Define the embedding generation function
     @st.cache_data
     def generate_embedding(_model, _tokenizer, text):
         inputs = _tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
@@ -90,34 +92,26 @@ def generate_embedding(_model, _tokenizer, text):
             embedding = embedding.cpu()
         return embedding.numpy()
 
-    #
+    # Error handling for generating query embeddings
     try:
         query_embedding = generate_embedding(model, tokenizer, user_query)
     except Exception as e:
         st.error(f"Error generating embedding: {str(e)}")
         st.stop()
 
-
+    # Precompute embeddings for dataset
+    def precompute_embeddings(data, model, tokenizer):
         @st.cache_data
-    def
-
-
-
-        inputs = {k: v.to('cuda') for k, v in inputs.items()}
-        with torch.no_grad():
-            outputs = model.encoder(**inputs)
-        # Move output back to CPU if needed
-        embedding = outputs.last_hidden_state.mean(dim=1).squeeze()
-        if torch.cuda.is_available():
-            embedding = embedding.cpu()
-        return embedding.numpy()
-
-    # Generate embeddings with progress bar
+        def generate_cached_embedding(text):
+            return generate_embedding(model, tokenizer, text)
+
+        # Apply embedding generation with progress bar
         with st.spinner('Generating embeddings... This might take a few minutes on first run...'):
-            data['embedding'] = data['text'].apply(lambda x:
-
-    return data, tokenizer, model
+            data['embedding'] = data['text'].apply(lambda x: generate_cached_embedding(x))
+        return data
 
+    # Example usage:
+    # data = precompute_embeddings(data, model, tokenizer)
 
 def generate_case_study(repo_data):
     """