habdine committed on
Commit 71ad8b3 · verified · 1 parent: b5bd029

Update modeling_prot2text.py

Files changed (1)
  1. modeling_prot2text.py +28 -6
modeling_prot2text.py CHANGED
@@ -123,9 +123,17 @@ class Prot2TextModel(PreTrainedModel):
 
     @torch.no_grad()
     def generate_protein_description(self,
-                                     protein_sequence=None,
-                                     tokenizer=None,
-                                     device='cpu'
+                                     protein_sequence=None,
+                                     tokenizer=None,
+                                     device='cpu',
+                                     streamer=None,
+                                     max_new_tokens=None,
+                                     do_sample=None,
+                                     top_p=None,
+                                     top_k=None,
+                                     temperature=None,
+                                     num_beams=1,
+                                     repetition_penalty=None
                                      ):
 
         if self.config.esm and not self.config.rgcn and protein_sequence==None:
@@ -147,9 +155,23 @@ class Prot2TextModel(PreTrainedModel):
         inputs = {k: v.to(device=device, non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()}
         encoder_state = dict()
         encoder_state['hidden_states'] = self(**inputs, get_graph_emb=True, output_attentions=True)
-        generated = tokenizer.batch_decode(self.decoder.generate(input_ids=inputs['decoder_input_ids'], encoder_outputs=encoder_state, use_cache=True), skip_special_tokens=True)
-
-        return generated[0].replace('<|stop_token|>', '').replace('<|graph_token|>', '')
+        if streamer is None:
+            generated = tokenizer.batch_decode(self.decoder.generate(input_ids=inputs['decoder_input_ids'], encoder_outputs=encoder_state, use_cache=True), skip_special_tokens=True)
+            return generated[0].replace('<|stop_token|>', '').replace('<|graph_token|>', '')
+        else:
+            return self.decoder.generate(input_ids=inputs['decoder_input_ids'],
+                                         encoder_outputs=encoder_state,
+                                         use_cache=True,
+                                         streamer=streamer,
+                                         max_new_tokens=max_new_tokens,
+                                         do_sample=do_sample,
+                                         top_p=top_p,
+                                         top_k=top_k,
+                                         temperature=temperature,
+                                         num_beams=1,
+                                         repetition_penalty=repetition_penalty)
+
+
 
     @torch.no_grad()
     def generate(self,
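
The commit extends generate_protein_description with an optional streamer and pass-through sampling controls; when a streamer is supplied, the raw decoder output is returned instead of a decoded string. Below is a minimal usage sketch, assuming a Prot2Text checkpoint loaded via trust_remote_code — the repository id and example sequence are illustrative, not taken from this commit:

# Sketch: calling the updated generate_protein_description with and without
# a streamer. Repository id and protein sequence below are illustrative.
import torch
from transformers import AutoModel, AutoTokenizer, TextStreamer

repo_id = "habdine/Prot2Text-Base-v1-0"  # assumption: any Prot2Text checkpoint
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Without a streamer: behaves as before and returns the cleaned description string.
description = model.generate_protein_description(
    protein_sequence="AEQAERYEEMVEFMEKL",
    tokenizer=tokenizer,
    device=device,
)
print(description)

# With a streamer: tokens are printed as they are generated, and the raw
# decoder.generate() output is returned instead of a decoded string.
streamer = TextStreamer(tokenizer, skip_special_tokens=True)
_ = model.generate_protein_description(
    protein_sequence="AEQAERYEEMVEFMEKL",
    tokenizer=tokenizer,
    device=device,
    streamer=streamer,
    max_new_tokens=256,
    do_sample=True,
    top_p=0.9,
    temperature=0.8,
)

Note that the streaming branch hardcodes num_beams=1 rather than forwarding the num_beams argument, which is consistent with transformers, where a streamer cannot be combined with beam search.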