Update app.py
app.py CHANGED
@@ -150,16 +150,17 @@ def to_wanted_embs(image_outputs, input_ids, attention_mask, cache_position=None
 
 @spaces.GPU()
 def generate_pali(user_emb):
-
-
-
-
-
-
-
-
-
-
+    with torch.no_grad():
+        prompt = 'caption en'
+        model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")
+        # we need to get im_embs taken in here.
+        input_len = model_inputs["input_ids"].shape[-1]
+        input_embeds = to_wanted_embs(user_emb.squeeze()[None, None, :].repeat(1, 256, 1),
+                                      model_inputs["input_ids"].to(device),
+                                      model_inputs["attention_mask"].to(device))
+
+        generation = pali.generate(max_new_tokens=100, do_sample=True, top_p=.94, temperature=1.2, inputs_embeds=input_embeds)
+        decoded = processor.decode(generation[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
     return decoded
 
 
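For context, the new generate_pali body captions an arbitrary embedding rather than a real image: the processor runs with a dummy all-zeros 224x224 image only to build the "caption en" prompt tokens, and to_wanted_embs (defined earlier in app.py, per the hunk header) splices the user embedding, repeated across the 256 image-token positions PaliGemma uses at 224 resolution, into the input embeddings before sampling a caption. A minimal usage sketch follows; it assumes pali, processor, and device are already defined in the Space, and that user_emb is a single embedding vector of whatever width to_wanted_embs expects (1152 below is only a placeholder):

import torch

user_emb = torch.randn(1152, device=device)  # placeholder width; use the Space's actual embedding size
caption = generate_pali(user_emb)            # returns the decoded caption string
print(caption)

Repeating a single vector 256 times simply fills every image-token slot with the same embedding, so the language model decodes a caption for that embedding as if it were the image PaliGemma would normally see.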