rynmurdock commited on
Commit
5f154d0
·
verified ·
1 Parent(s): 59bcd87

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -10
app.py CHANGED
@@ -150,16 +150,17 @@ def to_wanted_embs(image_outputs, input_ids, attention_mask, cache_position=None
150
 
151
  @spaces.GPU()
152
  def generate_pali(user_emb):
153
- prompt = 'caption en'
154
- model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")
155
- # we need to get im_embs taken in here.
156
- input_len = model_inputs["input_ids"].shape[-1]
157
- input_embeds = to_wanted_embs(user_emb.squeeze()[None, None, :].repeat(1, 256, 1),
158
- model_inputs["input_ids"].to(device),
159
- model_inputs["attention_mask"].to(device))
160
-
161
- generation = pali.generate(max_new_tokens=100, do_sample=True, top_p=.94, temperature=1.2, inputs_embeds=input_embeds)
162
- decoded = processor.decode(generation[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
 
163
  return decoded
164
 
165
 
 
150
 
151
  @spaces.GPU()
152
  def generate_pali(user_emb):
153
+ with torch.no_grad():
154
+ prompt = 'caption en'
155
+ model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")
156
+ # we need to get im_embs taken in here.
157
+ input_len = model_inputs["input_ids"].shape[-1]
158
+ input_embeds = to_wanted_embs(user_emb.squeeze()[None, None, :].repeat(1, 256, 1),
159
+ model_inputs["input_ids"].to(device),
160
+ model_inputs["attention_mask"].to(device))
161
+
162
+ generation = pali.generate(max_new_tokens=100, do_sample=True, top_p=.94, temperature=1.2, inputs_embeds=input_embeds)
163
+ decoded = processor.decode(generation[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
164
  return decoded
165
 
166