Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	
		Li
		
	commited on
		
		
					Commit 
							
							·
						
						c941456
	
1
								Parent(s):
							
							631dd30
								
update app.py
Browse files
    	
        open_flamingo/open_flamingo/src/flamingo_lm.py
    CHANGED
    
    | @@ -92,7 +92,9 @@ class FlamingoLayer(nn.Module): | |
| 92 | 
             
                        elif not self.training:
         | 
| 93 | 
             
                            if self.add_visual_token:
         | 
| 94 | 
             
                                if self.input_ids is None:
         | 
|  | |
| 95 | 
             
                                    self.input_ids = decoder_layer_kwargs["input_ids"]
         | 
|  | |
| 96 | 
             
                                else:
         | 
| 97 | 
             
                                    self.input_ids = torch.cat([self.input_ids, decoder_layer_kwargs["input_ids"]], dim=-1)
         | 
| 98 | 
             
                                visual_token_position = (self.input_ids[..., -1] == self.visual_token_id).nonzero().reshape(-1)
         | 
|  | |
| 92 | 
             
                        elif not self.training:
         | 
| 93 | 
             
                            if self.add_visual_token:
         | 
| 94 | 
             
                                if self.input_ids is None:
         | 
| 95 | 
            +
                                    print(decoder_layer_kwargs)
         | 
| 96 | 
             
                                    self.input_ids = decoder_layer_kwargs["input_ids"]
         | 
| 97 | 
            +
             | 
| 98 | 
             
                                else:
         | 
| 99 | 
             
                                    self.input_ids = torch.cat([self.input_ids, decoder_layer_kwargs["input_ids"]], dim=-1)
         | 
| 100 | 
             
                                visual_token_position = (self.input_ids[..., -1] == self.visual_token_id).nonzero().reshape(-1)
         |