check loading of UniRep parameters
Browse files
    	
        app.py
    CHANGED
    
    | @@ -89,10 +89,9 @@ def display_dti(): | |
| 89 | 
             
                            encoder = SeqVecEmbedder()
         | 
| 90 | 
             
                            embedding = encoder([sequence])
         | 
| 91 | 
             
                            embedding = encoder.reduce_per_protein(embedding)
         | 
| 92 | 
            -
                            st.write(f'SeqVec embedding: {embedding}')
         | 
| 93 | 
             
                        elif selected_encoder == 'UniRep':
         | 
| 94 | 
            -
                             | 
| 95 | 
            -
                             | 
| 96 | 
             
                            from jax_unirep.featurize import get_reps                               
         | 
| 97 | 
             
                            embedding, h_final, c_final = get_reps([sequence])
         | 
| 98 | 
             
                            embedding = embedding.mean(axis=0)
         | 
| @@ -108,6 +107,9 @@ def display_dti(): | |
| 108 | 
             
                            embedding = encoder.reduce_per_protein(embedding)
         | 
| 109 | 
             
                        else: 
         | 
| 110 | 
             
                            st.write('No pre-trained version of HyperPCM is available for the chosen encoder.')
         | 
|  | |
|  | |
|  | |
| 111 |  | 
| 112 |  | 
| 113 | 
             
            def display_protein():
         | 
|  | |
| 89 | 
             
                            encoder = SeqVecEmbedder()
         | 
| 90 | 
             
                            embedding = encoder([sequence])
         | 
| 91 | 
             
                            embedding = encoder.reduce_per_protein(embedding)
         | 
|  | |
| 92 | 
             
                        elif selected_encoder == 'UniRep':
         | 
| 93 | 
            +
                            from jax_unirep.utils import load_params
         | 
| 94 | 
            +
                            params = load_params() 
         | 
| 95 | 
             
                            from jax_unirep.featurize import get_reps                               
         | 
| 96 | 
             
                            embedding, h_final, c_final = get_reps([sequence])
         | 
| 97 | 
             
                            embedding = embedding.mean(axis=0)
         | 
|  | |
| 107 | 
             
                            embedding = encoder.reduce_per_protein(embedding)
         | 
| 108 | 
             
                        else: 
         | 
| 109 | 
             
                            st.write('No pre-trained version of HyperPCM is available for the chosen encoder.')
         | 
| 110 | 
            +
                            embeddning = None
         | 
| 111 | 
            +
                        if embedding:
         | 
| 112 | 
            +
                            st.write(f'{selected_encoder} embedding: {embedding}')
         | 
| 113 |  | 
| 114 |  | 
| 115 | 
             
            def display_protein():
         | 
