tcy6 commited on
Commit
95be598
·
1 Parent(s): 4eb0919

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -19,7 +19,7 @@ import json
19
  cache_dir = '/data/KB'
20
  os.makedirs(cache_dir, exist_ok=True)
21
 
22
- @spaces.GPU
23
  def weighted_mean_pooling(hidden, attention_mask):
24
  attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
25
  s = torch.sum(hidden * attention_mask_.unsqueeze(-1).float(), dim=1)
@@ -27,7 +27,7 @@ def weighted_mean_pooling(hidden, attention_mask):
27
  reps = s / d
28
  return reps
29
 
30
- @spaces.GPU
31
  @torch.no_grad()
32
  def encode(text_or_image_list):
33
  global model, tokenizer
@@ -108,7 +108,7 @@ def add_pdf_gradio(pdf_file_binary, progress=gr.Progress()):
108
 
109
  return knowledge_base_name
110
 
111
- @spaces.GPU
112
  def retrieve_gradio(knowledge_base: str, query: str, topk: int):
113
  global model, tokenizer
114
 
 
19
  cache_dir = '/data/KB'
20
  os.makedirs(cache_dir, exist_ok=True)
21
 
22
+ @spaces.GPU(duration=100)
23
  def weighted_mean_pooling(hidden, attention_mask):
24
  attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
25
  s = torch.sum(hidden * attention_mask_.unsqueeze(-1).float(), dim=1)
 
27
  reps = s / d
28
  return reps
29
 
30
+ @spaces.GPU(duration=100)
31
  @torch.no_grad()
32
  def encode(text_or_image_list):
33
  global model, tokenizer
 
108
 
109
  return knowledge_base_name
110
 
111
+ @spaces.GPU(duration=100)
112
  def retrieve_gradio(knowledge_base: str, query: str, topk: int):
113
  global model, tokenizer
114