Update __main__.py
__main__.py  +8 -3  CHANGED
@@ -1,4 +1,6 @@
 import argparse
+import sys
+
 import numpy as np
 import torch
 import torch.nn as nn
@@ -120,7 +122,7 @@ def answer_question(
     except EOFError:
         inp = ""
     if not inp:
-
+        sys.exit("exiting..")
 
     question = '<image>' + inp
 
@@ -163,6 +165,8 @@ def answer_question(
     }
 
     while True:
+        print('assistant: ')
+
         generated_ids = model.generate(
            inputs_embeds=new_embeds,
            attention_mask=attn_mask,
@@ -176,7 +180,8 @@ def answer_question(
         except EOFError:
             inp = ""
         if not inp:
-            print("
+            print("exiting...")
+            break
 
         new_text = generated_text + "<|start_header_id|>user<|end_header_id|>\n\n" + inp + "<|start_header_id|>assistant<|end_header_id|>\n\n"
         new_input_ids = tokenizer(new_text, return_tensors='pt').input_ids.to(device)
@@ -201,4 +206,4 @@ if __name__ == "__main__":
         vision_model,
         processor,
         projection_module,
-    )
+    )
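For reference, a minimal, self-contained sketch of the interactive flow these changes produce. The model call is stubbed out as a hypothetical generate_reply(); the prompt strings and exit behavior follow the diff, while the surrounding model/tokenizer plumbing from __main__.py is assumed and omitted here.

import sys


def generate_reply(prompt: str) -> str:
    # Hypothetical stand-in for the real model.generate()/decode path.
    return f"(model reply to: {prompt!r})"


def main() -> None:
    try:
        inp = input("user: ")
    except EOFError:
        inp = ""
    if not inp:
        # First prompt: an empty line (or EOF) terminates the program.
        sys.exit("exiting..")

    question = '<image>' + inp

    while True:
        print('assistant: ')
        print(generate_reply(question))

        try:
            inp = input("user: ")
        except EOFError:
            inp = ""
        if not inp:
            # Inside the loop: an empty line ends the conversation cleanly.
            print("exiting...")
            break

        # Follow-up turns are appended with the Llama-3 chat header tokens,
        # mirroring the new_text construction in the diff.
        question = (question
                    + "<|start_header_id|>user<|end_header_id|>\n\n" + inp
                    + "<|start_header_id|>assistant<|end_header_id|>\n\n")


if __name__ == "__main__":
    main()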