- modules.py  CHANGED  (+6 -4)

@@ -31,8 +31,9 @@ class ClassEmbedder(nn.Module):
 
 class TransformerEmbedder(AbstractEncoder):
     """Some transformer encoder layers"""
-    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
+    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cpu"):
         super().__init__()
+        #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.device = device
         self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                               attn_layers=Encoder(dim=n_embed, depth=n_layer))
@@ -48,10 +49,11 @@ class TransformerEmbedder(AbstractEncoder):
 
 class BERTTokenizer(AbstractEncoder):
     """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
-    def __init__(self, device="cuda", vq_interface=True, max_length=77):
+    def __init__(self, device="cpu", vq_interface=True, max_length=77):
         super().__init__()
         from transformers import BertTokenizerFast  # TODO: add to reuquirements
         self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+        #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.device = device
         self.vq_interface = vq_interface
         self.max_length = max_length
@@ -76,7 +78,7 @@ class BERTTokenizer(AbstractEncoder):
 class BERTEmbedder(AbstractEncoder):
     """Uses the BERT tokenizr model and add some transformer encoder layers"""
     def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
-                 device="cuda",use_tokenizer=True, embedding_dropout=0.0):
+                 device="cpu",use_tokenizer=True, embedding_dropout=0.0):
         super().__init__()
         self.use_tknz_fn = use_tokenizer
         if self.use_tknz_fn:
@@ -88,7 +90,7 @@ class BERTEmbedder(AbstractEncoder):
 
     def forward(self, text):
         if self.use_tknz_fn:
-            tokens = self.tknz_fn(text)#.to(self.device)
+            tokens = self.tknz_fn(text) #.to(self.device)
         else:
             tokens = text
         z = self.transformer(tokens, return_embeddings=True)
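The net effect of the patch is that all three text encoders now default to CPU, with CUDA selection left only as a commented-out hint. A minimal sketch of how a caller could still pick the device at runtime instead of relying on the hard-coded default is below; the import path and the n_embed/n_layer values are illustrative assumptions, not taken from the commit:

```python
import torch

# Runtime device selection, mirroring the commented-out line in the patch.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical usage of the patched class; adjust the import to wherever
# modules.py lives in this repo (upstream it is ldm.modules.encoders.modules).
from modules import BERTEmbedder

# n_embed / n_layer are placeholder values chosen for illustration only.
embedder = BERTEmbedder(n_embed=1280, n_layer=32, device=device).to(device)
```

Note that forward() still has the .to(self.device) call commented out, so if the module is moved to a GPU, the token tensors coming out of the tokenizer may also need to be moved explicitly.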
