Spaces:
Running
on
Zero
Running
on
Zero
Update videomind/model/model.py
Browse files- videomind/model/model.py +5 -1
videomind/model/model.py
CHANGED
@@ -18,6 +18,10 @@ from .generator import PointGenerator
|
|
18 |
from .loss import BundleLoss
|
19 |
|
20 |
|
|
|
|
|
|
|
|
|
21 |
class AgentQwen2VLConfig(Qwen2VLConfig):
|
22 |
model_type = 'agent_qwen2_vl'
|
23 |
|
@@ -52,7 +56,7 @@ class AgentQwen2VLModel(Qwen2VLModel):
|
|
52 |
|
53 |
def __init__(self, config):
|
54 |
super().__init__(config)
|
55 |
-
self.norm.register_forward_pre_hook(
|
56 |
|
57 |
def forward(self, input_ids=None, inputs_embeds=None, **kwargs):
|
58 |
# ensure gradient tracking (in case that embed_tokens has been frozen)
|
|
|
18 |
from .loss import BundleLoss
|
19 |
|
20 |
|
21 |
+
def cache_state_hook(module, args):
|
22 |
+
module.state = args[0]
|
23 |
+
|
24 |
+
|
25 |
class AgentQwen2VLConfig(Qwen2VLConfig):
|
26 |
model_type = 'agent_qwen2_vl'
|
27 |
|
|
|
56 |
|
57 |
def __init__(self, config):
|
58 |
super().__init__(config)
|
59 |
+
self.norm.register_forward_pre_hook(cache_state_hook)
|
60 |
|
61 |
def forward(self, input_ids=None, inputs_embeds=None, **kwargs):
|
62 |
# ensure gradient tracking (in case that embed_tokens has been frozen)
|