ccm committed on
Commit
936fa72
·
verified ·
1 Parent(s): a81897d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +6 -7
main.py CHANGED
@@ -24,13 +24,11 @@ data.reset_index(inplace=True)
24
  # Create a FAISS index for fast similarity search
25
  metric = faiss.METRIC_INNER_PRODUCT
26
  vectors = numpy.stack(data["embedding"].tolist(), axis=0)
27
- gpu_index = faiss.IndexFlatL2(len(data["embedding"][0]))
28
- # res = faiss.StandardGpuResources() # use a single GPU
29
- # gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
30
- gpu_index.metric_type = metric
31
  faiss.normalize_L2(vectors)
32
- gpu_index.train(vectors)
33
- gpu_index.add(vectors)
34
 
35
  # Load the model for later use in embeddings
36
  model = sentence_transformers.SentenceTransformer("allenai-specter")
@@ -39,7 +37,7 @@ model = sentence_transformers.SentenceTransformer("allenai-specter")
39
  def search(query: str, k: int) -> tuple[str]:
40
  query = numpy.expand_dims(model.encode(query), axis=0)
41
  faiss.normalize_L2(query)
42
- D, I = gpu_index.search(query, k)
43
  top_five = data.loc[I[0]]
44
 
45
  search_results = "You are an AI assistant who delights in helping people" \
@@ -78,6 +76,7 @@ def postprocess(response: str, bypass_from_preprocessing: str) -> str:
78
  """Applies a postprocessing step to the LLM's response before the user receives it"""
79
  return response + bypass_from_preprocessing
80
 
 
81
  def predict(message: str, history: list[str]) -> str:
82
  """This function is responsible for crafting a response"""
83
 
 
24
  # Create a FAISS index for fast similarity search
25
  metric = faiss.METRIC_INNER_PRODUCT
26
  vectors = numpy.stack(data["embedding"].tolist(), axis=0)
27
+ index = faiss.IndexFlatL2(len(data["embedding"][0]))
28
+ index.metric_type = metric
 
 
29
  faiss.normalize_L2(vectors)
30
+ index.train(vectors)
31
+ index.add(vectors)
32
 
33
  # Load the model for later use in embeddings
34
  model = sentence_transformers.SentenceTransformer("allenai-specter")
 
37
  def search(query: str, k: int) -> tuple[str]:
38
  query = numpy.expand_dims(model.encode(query), axis=0)
39
  faiss.normalize_L2(query)
40
+ D, I = index.search(query, k)
41
  top_five = data.loc[I[0]]
42
 
43
  search_results = "You are an AI assistant who delights in helping people" \
 
76
  """Applies a postprocessing step to the LLM's response before the user receives it"""
77
  return response + bypass_from_preprocessing
78
 
79
+ @spaces.GPU
80
  def predict(message: str, history: list[str]) -> str:
81
  """This function is responsible for crafting a response"""
82