yeliudev commited on
Commit
655ce8d
·
verified ·
1 Parent(s): 6e4dd99

Update videomind/model/model.py

Browse files
Files changed (1) hide show
  1. videomind/model/model.py +5 -1
videomind/model/model.py CHANGED
@@ -18,6 +18,10 @@ from .generator import PointGenerator
18
  from .loss import BundleLoss
19
 
20
 
 
 
 
 
21
  class AgentQwen2VLConfig(Qwen2VLConfig):
22
  model_type = 'agent_qwen2_vl'
23
 
@@ -52,7 +56,7 @@ class AgentQwen2VLModel(Qwen2VLModel):
52
 
53
  def __init__(self, config):
54
  super().__init__(config)
55
- self.norm.register_forward_pre_hook(lambda module, args: setattr(module, 'state', args[0]))
56
 
57
  def forward(self, input_ids=None, inputs_embeds=None, **kwargs):
58
  # ensure gradient tracking (in case that embed_tokens has been frozen)
 
18
  from .loss import BundleLoss
19
 
20
 
21
+ def cache_state_hook(module, args):
22
+ module.state = args[0]
23
+
24
+
25
  class AgentQwen2VLConfig(Qwen2VLConfig):
26
  model_type = 'agent_qwen2_vl'
27
 
 
56
 
57
  def __init__(self, config):
58
  super().__init__(config)
59
+ self.norm.register_forward_pre_hook(cache_state_hook)
60
 
61
  def forward(self, input_ids=None, inputs_embeds=None, **kwargs):
62
  # ensure gradient tracking (in case that embed_tokens has been frozen)