frankjosh committed
Commit 1c4d662 · verified · 1 Parent(s): 0651d54

Update app.py

Files changed (1): app.py (+8, −6)
app.py CHANGED
@@ -3,12 +3,12 @@ warnings.filterwarnings('ignore')
 
 import streamlit as st
 import pandas as pd
-from datasets import load_dataset
 import numpy as np
 from sklearn.metrics.pairwise import cosine_similarity
 from transformers import AutoTokenizer, AutoModel
 import torch
 from torch.utils.data import DataLoader, Dataset
+from datasets import load_dataset  # For loading dataset
 from datetime import datetime
 from typing import List, Dict, Any
 from functools import partial
@@ -23,8 +23,9 @@ if 'history' not in st.session_state:
 if 'feedback' not in st.session_state:
     st.session_state.feedback = {}
 
-# Define subset size
-SUBSET_SIZE = 1000
+# Define subset size and batch size for optimization
+SUBSET_SIZE = 500  # Smaller subset for faster precomputation
+BATCH_SIZE = 8  # Smaller batch size to reduce memory overhead
 
 # Caching key resources: Model, Tokenizer, and Precomputed Embeddings
 @st.cache_resource
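
Note on the hunk above: load_data() itself is outside this diff, so how SUBSET_SIZE is applied is not visible here. A minimal sketch of the usual pattern, assuming a Hugging Face dataset id and split that are placeholders, not taken from app.py:

    from datasets import load_dataset

    def load_data():
        # Placeholder dataset id; the real one is not shown in this diff.
        dataset = load_dataset("user/repo-metadata", split="train")
        # Cap the number of rows that reach the embedding step.
        dataset = dataset.select(range(min(SUBSET_SIZE, len(dataset))))
        return dataset.to_pandas()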
@@ -51,9 +52,10 @@ def load_data():
     return data
 
 @st.cache_resource
-def precompute_embeddings(data: pd.DataFrame, tokenizer, model, batch_size=16):
+def precompute_embeddings(data: pd.DataFrame, _tokenizer, model, batch_size=BATCH_SIZE):
     """
     Precompute embeddings for repository metadata to optimize query performance.
+    The tokenizer is excluded from caching as it is unhashable.
     """
     class TextDataset(Dataset):
         def __init__(self, texts: List[str], tokenizer, max_length=512):
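
Note on the hunk above: the leading underscore is Streamlit's documented escape hatch. st.cache_resource (and st.cache_data) skip hashing any parameter whose name starts with "_", so renaming tokenizer to _tokenizer keeps the unhashable tokenizer object out of the cache key. A self-contained illustration (the function and argument names here are ours, not from app.py):

    import streamlit as st

    @st.cache_resource
    def expensive_setup(config_name: str, _tokenizer):
        # config_name is hashed into the cache key; _tokenizer is skipped,
        # so passing an unhashable object no longer raises UnhashableParamError.
        return {"config": config_name, "vocab_size": _tokenizer.vocab_size}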
@@ -98,10 +100,10 @@ def precompute_embeddings(data: pd.DataFrame, tokenizer, model, batch_size=16):
         outputs = model.encoder(**batch)
         return outputs.last_hidden_state.mean(dim=1).cpu().numpy()
 
-    dataset = TextDataset(data['text'].tolist(), tokenizer)
+    dataset = TextDataset(data['text'].tolist(), _tokenizer)
     dataloader = DataLoader(
         dataset, batch_size=batch_size, shuffle=False,
-        collate_fn=partial(collate_fn, pad_token_id=tokenizer.pad_token_id)
+        collate_fn=partial(collate_fn, pad_token_id=_tokenizer.pad_token_id)
     )
 
     embeddings = []
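
Note on the hunk above: collate_fn is defined elsewhere in app.py and only re-bound here via functools.partial so that the renamed _tokenizer supplies pad_token_id. A plausible sketch of such a collate function, assuming each TextDataset item is a dict of 1-D tensors ("input_ids", "attention_mask"):

    import torch

    def collate_fn(batch, pad_token_id):
        # Pad every sequence to the longest one in the batch so the
        # tensors can be stacked into (batch, max_len) model inputs.
        max_len = max(item["input_ids"].size(0) for item in batch)
        def pad(t, value):
            return torch.nn.functional.pad(t, (0, max_len - t.size(0)), value=value)
        return {
            "input_ids": torch.stack([pad(item["input_ids"], pad_token_id) for item in batch]),
            "attention_mask": torch.stack([pad(item["attention_mask"], 0) for item in batch]),
        }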
 
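Once precompute_embeddings returns its matrix, the cosine_similarity import at the top of app.py is the natural way to rank repositories against a query vector. A hedged sketch of that lookup (the function name and k are ours, not from app.py):

    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity

    def top_k_matches(query_emb: np.ndarray, corpus_embs: np.ndarray, k: int = 5):
        # corpus_embs: (n_repos, dim) matrix of precomputed embeddings;
        # query_emb: (dim,) vector embedded the same way as the corpus.
        scores = cosine_similarity(query_emb.reshape(1, -1), corpus_embs)[0]
        top = np.argsort(scores)[::-1][:k]
        return top, scores[top]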