Update modeling_prot2text.py
Browse files- modeling_prot2text.py +28 -6
modeling_prot2text.py
CHANGED
@@ -123,9 +123,17 @@ class Prot2TextModel(PreTrainedModel):
|
|
123 |
|
124 |
@torch.no_grad()
|
125 |
def generate_protein_description(self,
|
126 |
-
|
127 |
-
|
128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
):
|
130 |
|
131 |
if self.config.esm and not self.config.rgcn and protein_sequence==None:
|
@@ -147,9 +155,23 @@ class Prot2TextModel(PreTrainedModel):
|
|
147 |
inputs = {k: v.to(device=device, non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()}
|
148 |
encoder_state = dict()
|
149 |
encoder_state['hidden_states'] = self(**inputs, get_graph_emb=True, output_attentions=True)
|
150 |
-
|
151 |
-
|
152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
|
154 |
@torch.no_grad()
|
155 |
def generate(self,
|
|
|
123 |
|
124 |
@torch.no_grad()
|
125 |
def generate_protein_description(self,
|
126 |
+
protein_sequence=None,
|
127 |
+
tokenizer=None,
|
128 |
+
device='cpu',
|
129 |
+
streamer=None,
|
130 |
+
max_new_tokens=None,
|
131 |
+
do_sample=None,
|
132 |
+
top_p=None,
|
133 |
+
top_k=None,
|
134 |
+
temperature=None,
|
135 |
+
num_beams=1,
|
136 |
+
repetition_penalty=None
|
137 |
):
|
138 |
|
139 |
if self.config.esm and not self.config.rgcn and protein_sequence==None:
|
|
|
155 |
inputs = {k: v.to(device=device, non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()}
|
156 |
encoder_state = dict()
|
157 |
encoder_state['hidden_states'] = self(**inputs, get_graph_emb=True, output_attentions=True)
|
158 |
+
if streamer is None:
|
159 |
+
generated = tokenizer.batch_decode(self.decoder.generate(input_ids=inputs['decoder_input_ids'], encoder_outputs=encoder_state, use_cache=True), skip_special_tokens=True)
|
160 |
+
return generated[0].replace('<|stop_token|>', '').replace('<|graph_token|>', '')
|
161 |
+
else:
|
162 |
+
return self.decoder.generate(input_ids=inputs['decoder_input_ids'],
|
163 |
+
encoder_outputs=encoder_state,
|
164 |
+
use_cache=True,
|
165 |
+
streamer=streamer,
|
166 |
+
max_new_tokens=max_new_tokens,
|
167 |
+
do_sample=do_sample,
|
168 |
+
top_p=top_p,
|
169 |
+
top_k=top_k,
|
170 |
+
temperature=temperature,
|
171 |
+
num_beams=1,
|
172 |
+
repetition_penalty=repetition_penalty)
|
173 |
+
|
174 |
+
|
175 |
|
176 |
@torch.no_grad()
|
177 |
def generate(self,
|