Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Commit 
							
							·
						
						3b4e474
	
1
								Parent(s):
							
							d9f47ef
								
Change input metrics to be on val set. Add input logging
Browse files- remfx/models.py +15 -6
    	
        remfx/models.py
    CHANGED
    
    | @@ -88,7 +88,21 @@ class RemFXModel(pl.LightningModule): | |
| 88 |  | 
| 89 | 
             
                def on_train_batch_start(self, batch, batch_idx):
         | 
| 90 | 
             
                    if self.log_first:
         | 
| 91 | 
            -
                        x,  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 92 | 
             
                        for metric in self.metrics:
         | 
| 93 | 
             
                            # SISDR returns negative values, so negate them
         | 
| 94 | 
             
                            if metric == "SISDR":
         | 
| @@ -106,12 +120,7 @@ class RemFXModel(pl.LightningModule): | |
| 106 | 
             
                            )
         | 
| 107 | 
             
                        self.log_first = False
         | 
| 108 |  | 
| 109 | 
            -
                def on_validation_epoch_start(self):
         | 
| 110 | 
            -
                    self.log_next = True
         | 
| 111 | 
            -
             | 
| 112 | 
            -
                def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
         | 
| 113 | 
             
                    if self.log_next:
         | 
| 114 | 
            -
                        x, target, label = batch
         | 
| 115 | 
             
                        self.model.eval()
         | 
| 116 | 
             
                        with torch.no_grad():
         | 
| 117 | 
             
                            y = self.model.sample(x)
         | 
|  | |
| 88 |  | 
| 89 | 
             
                def on_train_batch_start(self, batch, batch_idx):
         | 
| 90 | 
             
                    if self.log_first:
         | 
| 91 | 
            +
                        x, y, label = batch
         | 
| 92 | 
            +
                        log_wandb_audio_batch(
         | 
| 93 | 
            +
                            logger=self.logger,
         | 
| 94 | 
            +
                            id="input_target",
         | 
| 95 | 
            +
                            samples=x.cpu(),
         | 
| 96 | 
            +
                            sampling_rate=self.sample_rate,
         | 
| 97 | 
            +
                            caption="Training Data",
         | 
| 98 | 
            +
                        )
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                def on_validation_epoch_start(self):
         | 
| 101 | 
            +
                    self.log_next = True
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
         | 
| 104 | 
            +
                    x, target, label = batch
         | 
| 105 | 
            +
                    if self.log_first:
         | 
| 106 | 
             
                        for metric in self.metrics:
         | 
| 107 | 
             
                            # SISDR returns negative values, so negate them
         | 
| 108 | 
             
                            if metric == "SISDR":
         | 
|  | |
| 120 | 
             
                            )
         | 
| 121 | 
             
                        self.log_first = False
         | 
| 122 |  | 
|  | |
|  | |
|  | |
|  | |
| 123 | 
             
                    if self.log_next:
         | 
|  | |
| 124 | 
             
                        self.model.eval()
         | 
| 125 | 
             
                        with torch.no_grad():
         | 
| 126 | 
             
                            y = self.model.sample(x)
         |