m-ric HF Staff commited on
Commit
e5a9bbf
Β·
verified Β·
1 Parent(s): 7dc5ad4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -7
app.py CHANGED
@@ -58,13 +58,10 @@ def generate_and_visualize(prompt, num_tokens=10):
58
  output_logits = model(inputs_embeds=input_embeds.requires_grad_()).logits
59
  max_logits, max_indices = torch.max(output_logits[0, -1, :], dim=-1)
60
 
61
- max_logits.backward(max_logits)
62
- try:
63
- relevance = input_embeds.grad.float().sum(-1).cpu()[0]
64
- all_relevances.append(relevance)
65
- except:
66
- all_relevances.append(0)
67
-
68
  next_token = max_indices.unsqueeze(0)
69
  generated_tokens_ids.append(next_token.item())
70
 
 
58
  output_logits = model(inputs_embeds=input_embeds.requires_grad_()).logits
59
  max_logits, max_indices = torch.max(output_logits[0, -1, :], dim=-1)
60
 
61
+ max_logits.backward(max_logits)
62
+ relevance = input_embeds.grad.float().sum(-1).cpu()[0]
63
+ all_relevances.append(relevance)
64
+
 
 
 
65
  next_token = max_indices.unsqueeze(0)
66
  generated_tokens_ids.append(next_token.item())
67