venkat-natchi committed
Commit 3249072 · verified · 1 Parent(s): 3c70c98

Update app.py

Files changed (1)
  1. app.py +10 -9
app.py CHANGED
@@ -61,15 +61,16 @@ audio_pipe = pipeline(
 
 
 def process_text(text, count):
-    inputs = text_tokenizer(text, return_tensors="pt",
-                            return_attention_mask=False)
+    inputs = text_tokenizer.encode(text, return_tensors="pt")
+    input_embeds = tuned_phi2.get_submodule(
+        'model.embed_tokens')(inputs).to(device)
     prediction = text_tokenizer.batch_decode(
-        base_phi2_text.generate(
-            **inputs,
-            max_new_tokens=count,
-            bos_token_id=text_tokenizer.bos_token_id,
-            eos_token_id=text_tokenizer.eos_token_id,
-            pad_token_id=text_tokenizer.pad_token_id
+        tuned_phi2.generate(
+            inputs_embeds=input_embeds,
+            max_new_tokens=30,
+            bos_token_id=text_tokenizer.bos_token_id,
+            eos_token_id=text_tokenizer.eos_token_id,
+            pad_token_id=text_tokenizer.pad_token_id
         )
     )
     return prediction[0].rstrip('<|endoftext|>').rstrip("\n")
@@ -103,7 +104,7 @@ def generate_response(image, audio, text, count):
     q_tokens = text_tokenizer.encode(
         overall_input,
         return_tensors='pt').to(device)
-    question_token_embeddings = base_phi2_text.get_submodule(
+    question_token_embeddings = tuned_phi2.get_submodule(
         'model.embed_tokens')(q_tokens).to(device)
     inputs = torch.concat(
         (img_tokens.unsqueeze(0), question_token_embeddings),
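
The updated process_text follows the same pattern generate_response already uses: look up the prompt's token embeddings through the model's model.embed_tokens submodule and pass them to generate via inputs_embeds, so that other embeddings (such as projected image tokens) can be concatenated in front of the text. Below is a minimal, self-contained sketch of that pattern; it assumes the public microsoft/phi-2 checkpoint, a hypothetical prompt, and automatic device selection, since the Space's fine-tuned tuned_phi2 weights are loaded elsewhere in app.py and are not shown here.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Sketch of the commit's generation path, using the public microsoft/phi-2
# checkpoint as a stand-in for the Space's tuned_phi2 (an assumption; the
# fine-tuned weights are loaded elsewhere in app.py).
device = "cuda" if torch.cuda.is_available() else "cpu"
text_tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
phi2 = AutoModelForCausalLM.from_pretrained("microsoft/phi-2").to(device)

text = "What is the capital of France?"  # hypothetical prompt
token_ids = text_tokenizer.encode(text, return_tensors="pt").to(device)

# Embed the prompt manually via the named submodule, as the commit does,
# so extra embeddings (e.g. image tokens) could be prepended before decoding.
embed_layer = phi2.get_submodule("model.embed_tokens")
prompt_embeds = embed_layer(token_ids)

output_ids = phi2.generate(
    inputs_embeds=prompt_embeds,
    max_new_tokens=30,
    bos_token_id=text_tokenizer.bos_token_id,
    eos_token_id=text_tokenizer.eos_token_id,
    pad_token_id=text_tokenizer.pad_token_id,
)

# When generating from inputs_embeds, the returned ids contain only the newly
# generated tokens, so they can be decoded directly.
print(text_tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])

Generating from inputs_embeds rather than input_ids is what lets generate_response splice img_tokens ahead of the question embeddings with torch.concat before decoding.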