Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	Update loss.py
Browse files
    	
        loss.py
    CHANGED
    
    | @@ -34,7 +34,7 @@ def discriminator_loss(generator, discriminator, mol_graph, adj, annot, batch_si | |
| 34 | 
             
                return node, edge,d_loss
         | 
| 35 |  | 
| 36 |  | 
| 37 | 
            -
            def generator_loss(generator, discriminator, v, adj, annot, batch_size, penalty, matrices2mol, fps_r,submodel):
         | 
| 38 |  | 
| 39 | 
             
                # Compute loss with fake molecules.
         | 
| 40 |  | 
| @@ -53,7 +53,7 @@ def generator_loss(generator, discriminator, v, adj, annot, batch_size, penalty, | |
| 53 | 
             
                g_edges_hat_sample = torch.max(edge_sample, -1)[1] 
         | 
| 54 | 
             
                g_nodes_hat_sample = torch.max(node_sample , -1)[1]   
         | 
| 55 |  | 
| 56 | 
            -
                fake_mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True) 
         | 
| 57 | 
             
                        for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]        
         | 
| 58 | 
             
                g_loss = prediction_fake
         | 
| 59 | 
             
                # Compute penalty loss.
         | 
| @@ -116,7 +116,7 @@ def discriminator2_loss(generator, discriminator, mol_graph, adj, annot, batch_s | |
| 116 |  | 
| 117 | 
             
                return d2_loss
         | 
| 118 |  | 
| 119 | 
            -
            def generator2_loss(generator, discriminator, v, adj, annot, batch_size, penalty, matrices2mol, fps_r,ak1_adj,akt1_annot, submodel):
         | 
| 120 |  | 
| 121 | 
             
                # Generate molecules.
         | 
| 122 |  | 
| @@ -140,7 +140,7 @@ def generator2_loss(generator, discriminator, v, adj, annot, batch_size, penalty | |
| 140 | 
             
                g2_loss_fake = - torch.mean(g_tra_logits_fake2)                                                            
         | 
| 141 |  | 
| 142 | 
             
                # Reward
         | 
| 143 | 
            -
                fake_mol_g = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True) 
         | 
| 144 | 
             
                            for e_, n_ in zip(dr_g_edges_hat_sample, dr_g_nodes_hat_sample)]       
         | 
| 145 | 
             
                g2_loss =  g2_loss_fake    
         | 
| 146 | 
             
                if submodel == "RL":
         | 
|  | |
| 34 | 
             
                return node, edge,d_loss
         | 
| 35 |  | 
| 36 |  | 
| 37 | 
            +
            def generator_loss(generator, discriminator, v, adj, annot, batch_size, penalty, matrices2mol, fps_r,submodel, dataset_name):
         | 
| 38 |  | 
| 39 | 
             
                # Compute loss with fake molecules.
         | 
| 40 |  | 
|  | |
| 53 | 
             
                g_edges_hat_sample = torch.max(edge_sample, -1)[1] 
         | 
| 54 | 
             
                g_nodes_hat_sample = torch.max(node_sample , -1)[1]   
         | 
| 55 |  | 
| 56 | 
            +
                fake_mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=dataset_name) 
         | 
| 57 | 
             
                        for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]        
         | 
| 58 | 
             
                g_loss = prediction_fake
         | 
| 59 | 
             
                # Compute penalty loss.
         | 
|  | |
| 116 |  | 
| 117 | 
             
                return d2_loss
         | 
| 118 |  | 
| 119 | 
            +
            def generator2_loss(generator, discriminator, v, adj, annot, batch_size, penalty, matrices2mol, fps_r,ak1_adj,akt1_annot, submodel, drugs_name):
         | 
| 120 |  | 
| 121 | 
             
                # Generate molecules.
         | 
| 122 |  | 
|  | |
| 140 | 
             
                g2_loss_fake = - torch.mean(g_tra_logits_fake2)                                                            
         | 
| 141 |  | 
| 142 | 
             
                # Reward
         | 
| 143 | 
            +
                fake_mol_g = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=drugs_name) 
         | 
| 144 | 
             
                            for e_, n_ in zip(dr_g_edges_hat_sample, dr_g_nodes_hat_sample)]       
         | 
| 145 | 
             
                g2_loss =  g2_loss_fake    
         | 
| 146 | 
             
                if submodel == "RL":
         | 

