Update app.py
app.py CHANGED
@@ -3,12 +3,12 @@ warnings.filterwarnings('ignore')
 
 import streamlit as st
 import pandas as pd
-from datasets import load_dataset
 import numpy as np
 from sklearn.metrics.pairwise import cosine_similarity
 from transformers import AutoTokenizer, AutoModel
 import torch
 from torch.utils.data import DataLoader, Dataset
+from datasets import load_dataset  # For loading dataset
 from datetime import datetime
 from typing import List, Dict, Any
 from functools import partial
@@ -23,8 +23,9 @@ if 'history' not in st.session_state:
 if 'feedback' not in st.session_state:
     st.session_state.feedback = {}
 
-# Define subset size
-SUBSET_SIZE =
+# Define subset size and batch size for optimization
+SUBSET_SIZE = 500  # Smaller subset for faster precomputation
+BATCH_SIZE = 8  # Smaller batch size to reduce memory overhead
 
 # Caching key resources: Model, Tokenizer, and Precomputed Embeddings
 @st.cache_resource
@@ -51,9 +52,10 @@ def load_data():
     return data
 
 @st.cache_resource
-def precompute_embeddings(data: pd.DataFrame, tokenizer, model, batch_size=16):
+def precompute_embeddings(data: pd.DataFrame, _tokenizer, model, batch_size=BATCH_SIZE):
     """
     Precompute embeddings for repository metadata to optimize query performance.
+    The tokenizer is excluded from caching as it is unhashable.
     """
     class TextDataset(Dataset):
         def __init__(self, texts: List[str], tokenizer, max_length=512):
@@ -98,10 +100,10 @@ def precompute_embeddings(data: pd.DataFrame, tokenizer, model, batch_size=16):
         outputs = model.encoder(**batch)
         return outputs.last_hidden_state.mean(dim=1).cpu().numpy()
 
-    dataset = TextDataset(data['text'].tolist(), tokenizer)
+    dataset = TextDataset(data['text'].tolist(), _tokenizer)
     dataloader = DataLoader(
         dataset, batch_size=batch_size, shuffle=False,
-        collate_fn=partial(collate_fn, pad_token_id=tokenizer.pad_token_id)
+        collate_fn=partial(collate_fn, pad_token_id=_tokenizer.pad_token_id)
     )
 
     embeddings = []
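
Context for the `_tokenizer` rename: `st.cache_resource` (like `st.cache_data`) builds its cache key by hashing the function's arguments, and a leading underscore tells Streamlit to skip hashing that parameter. That is how an unhashable object such as a tokenizer can still be passed into a cached function. A minimal sketch of the convention, separate from this app's code (the model name and texts below are placeholders):

import streamlit as st
from transformers import AutoTokenizer

@st.cache_resource
def load_tokenizer(name: str):
    # Cached once per `name`.
    return AutoTokenizer.from_pretrained(name)

@st.cache_resource
def tokenize_corpus(texts: tuple, _tokenizer):
    # The leading underscore excludes `_tokenizer` from Streamlit's hashing,
    # so only `texts` contributes to the cache key.
    return [_tokenizer(t, truncation=True)["input_ids"] for t in texts]

tokenizer = load_tokenizer("bert-base-uncased")
ids = tokenize_corpus(("first repo description", "second repo description"), tokenizer)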
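
The changed hunk only rebinds `pad_token_id` into `collate_fn` via `functools.partial`; the body of `collate_fn` lies outside the diff. As a sketch of the pattern that signature implies (assuming each `TextDataset` item is a dict holding a 1-D `input_ids` tensor), the collate step pads every sequence in a batch to the batch maximum and builds a matching attention mask:

from functools import partial
from typing import Dict, List

import torch
from torch.utils.data import DataLoader

def collate_fn(batch: List[Dict[str, torch.Tensor]], pad_token_id: int) -> Dict[str, torch.Tensor]:
    # Pad 'input_ids' to the longest sequence in the batch; mask 1 = real token, 0 = padding.
    max_len = max(item["input_ids"].size(0) for item in batch)
    input_ids, attention_mask = [], []
    for item in batch:
        ids = item["input_ids"]
        pad = torch.full((max_len - ids.size(0),), pad_token_id, dtype=ids.dtype)
        input_ids.append(torch.cat([ids, pad]))
        attention_mask.append(torch.cat([torch.ones_like(ids), torch.zeros_like(pad)]))
    return {"input_ids": torch.stack(input_ids), "attention_mask": torch.stack(attention_mask)}

# Bound exactly as in the diff; `dataset` and `_tokenizer` come from precompute_embeddings:
# dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False,
#                         collate_fn=partial(collate_fn, pad_token_id=_tokenizer.pad_token_id))

Because `shuffle=False`, each precomputed embedding row stays aligned with its source row in `data`.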