codewithdark committed
Commit 68e98e4 · verified · 1 Parent(s): 9cde2c9

Update modeling_latent_recurrent_depth.py

Files changed (1):
  modeling_latent_recurrent_depth.py (+75 -75)
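The only functional change in this commit is the casing of the import package: the old file imports LatentRecurrentDepthLM from Model.latent_Recurrent, while the new file imports it from model.latent_Recurrent. The rest of the file is identical, which is why the diff reports all 75 lines as replaced.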
modeling_latent_recurrent_depth.py CHANGED
@@ -1,75 +1,75 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from typing import Optional, Tuple
- import math
- from transformers import PretrainedConfig, PreTrainedModel
- from Model.latent_Recurrent import LatentRecurrentDepthLM
-
- # Configuration for the Latent Recurrent Depth Model
- class LatentRecurrentDepthConfig(PretrainedConfig):
-     model_type = "latent_recurrent_depth"
-
-     def __init__(self, vocab_size=50257, d_model=768, num_heads=12, dropout=0.1, **kwargs):
-         super().__init__(**kwargs)
-         self.vocab_size = vocab_size
-         self.d_model = d_model
-         self.num_heads = num_heads
-         self.dropout = dropout
-
-
- # Hugging Face-Compatible Model Wrapper
- class LatentRecurrentDepthModel(PreTrainedModel):
-     config_class = LatentRecurrentDepthConfig
-     base_model_prefix = "latent_recurrent_depth"
-
-     def __init__(self, config: LatentRecurrentDepthConfig):
-         super().__init__(config)
-         self.latent_model = LatentRecurrentDepthLM(config.vocab_size, config.d_model, config.num_heads, config.dropout)
-         self.init_weights()
-
-     def forward(self, input_ids: torch.Tensor, num_iterations: int, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-         return self.latent_model(input_ids, num_iterations, mask)
-
-     def generate(
-         self,
-         input_ids: torch.Tensor,
-         max_length: int = 20,
-         num_iterations: int = 3,
-         temperature: float = 1.0,
-         top_k: Optional[int] = 50,
-     ) -> torch.Tensor:
-         """
-         Generate a sequence of tokens given input_ids.
-
-         Args:
-             input_ids: torch.Tensor of shape (batch, seq_length) containing the prompt.
-             max_length: The number of tokens to generate.
-             num_iterations: The number of recurrent iterations to use in each forward pass.
-             temperature: Temperature for scaling logits.
-             top_k: If set, only sample from the top k tokens.
-
-         Returns:
-             generated: torch.Tensor containing the generated sequence.
-         """
-         generated = input_ids.clone()
-         self.eval()
-         with torch.no_grad():
-             for _ in range(max_length):
-                 # Get logits from the model for the current sequence.
-                 logits = self.forward(generated, num_iterations=num_iterations)
-                 # Use only the logits for the last token in the sequence.
-                 next_token_logits = logits[:, -1, :] / temperature
-                 if top_k is not None:
-                     # Top-k filtering
-                     top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
-                     probabilities = F.softmax(top_k_logits, dim=-1)
-                     next_token = top_k_indices.gather(-1, torch.multinomial(probabilities, num_samples=1))
-                 else:
-                     probabilities = F.softmax(next_token_logits, dim=-1)
-                     next_token = torch.multinomial(probabilities, num_samples=1)
-                 generated = torch.cat([generated, next_token], dim=1)
-                 # Optionally, break if the EOS token is generated.
-                 if next_token.item() == self.config.eos_token_id:
-                     break
-         return generated
 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Optional, Tuple
+ import math
+ from transformers import PretrainedConfig, PreTrainedModel
+ from model.latent_Recurrent import LatentRecurrentDepthLM
+
+ # Configuration for the Latent Recurrent Depth Model
+ class LatentRecurrentDepthConfig(PretrainedConfig):
+     model_type = "latent_recurrent_depth"
+
+     def __init__(self, vocab_size=50257, d_model=768, num_heads=12, dropout=0.1, **kwargs):
+         super().__init__(**kwargs)
+         self.vocab_size = vocab_size
+         self.d_model = d_model
+         self.num_heads = num_heads
+         self.dropout = dropout
+
+
+ # Hugging Face-Compatible Model Wrapper
+ class LatentRecurrentDepthModel(PreTrainedModel):
+     config_class = LatentRecurrentDepthConfig
+     base_model_prefix = "latent_recurrent_depth"
+
+     def __init__(self, config: LatentRecurrentDepthConfig):
+         super().__init__(config)
+         self.latent_model = LatentRecurrentDepthLM(config.vocab_size, config.d_model, config.num_heads, config.dropout)
+         self.init_weights()
+
+     def forward(self, input_ids: torch.Tensor, num_iterations: int, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+         return self.latent_model(input_ids, num_iterations, mask)
+
+     def generate(
+         self,
+         input_ids: torch.Tensor,
+         max_length: int = 20,
+         num_iterations: int = 3,
+         temperature: float = 1.0,
+         top_k: Optional[int] = 50,
+     ) -> torch.Tensor:
+         """
+         Generate a sequence of tokens given input_ids.
+
+         Args:
+             input_ids: torch.Tensor of shape (batch, seq_length) containing the prompt.
+             max_length: The number of tokens to generate.
+             num_iterations: The number of recurrent iterations to use in each forward pass.
+             temperature: Temperature for scaling logits.
+             top_k: If set, only sample from the top k tokens.
+
+         Returns:
+             generated: torch.Tensor containing the generated sequence.
+         """
+         generated = input_ids.clone()
+         self.eval()
+         with torch.no_grad():
+             for _ in range(max_length):
+                 # Get logits from the model for the current sequence.
+                 logits = self.forward(generated, num_iterations=num_iterations)
+                 # Use only the logits for the last token in the sequence.
+                 next_token_logits = logits[:, -1, :] / temperature
+                 if top_k is not None:
+                     # Top-k filtering
+                     top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
+                     probabilities = F.softmax(top_k_logits, dim=-1)
+                     next_token = top_k_indices.gather(-1, torch.multinomial(probabilities, num_samples=1))
+                 else:
+                     probabilities = F.softmax(next_token_logits, dim=-1)
+                     next_token = torch.multinomial(probabilities, num_samples=1)
+                 generated = torch.cat([generated, next_token], dim=1)
+                 # Optionally, break if the EOS token is generated.
+                 if next_token.item() == self.config.eos_token_id:
+                     break
+         return generated
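For context, a minimal usage sketch of the wrapper as it stands after this commit. This is not part of the commit: it assumes the repository's model.latent_Recurrent package is importable, that the forward pass returns logits of shape (batch, seq_len, vocab_size), and the prompt token ids and eos_token_id value below are hypothetical placeholders.

    import torch
    from modeling_latent_recurrent_depth import LatentRecurrentDepthConfig, LatentRecurrentDepthModel

    # Hypothetical configuration; eos_token_id is passed through PretrainedConfig's
    # **kwargs and is what generate() compares sampled tokens against.
    config = LatentRecurrentDepthConfig(
        vocab_size=50257,
        d_model=768,
        num_heads=12,
        dropout=0.1,
        eos_token_id=50256,  # hypothetical GPT-2-style EOS id
    )
    model = LatentRecurrentDepthModel(config)

    # Hypothetical prompt of token ids, shape (batch=1, seq_len=2).
    prompt = torch.tensor([[464, 3290]])

    # Sample up to 10 new tokens, running 3 recurrent iterations per forward pass.
    output = model.generate(prompt, max_length=10, num_iterations=3, temperature=0.8, top_k=40)
    print(output.shape)  # (1, at most 12): the prompt plus up to max_length new tokens

One caveat worth noting: generate() calls next_token.item() for the EOS check, which raises for batch sizes greater than one, and the check never fires when config.eos_token_id is unset (None). Batched prompts would need a per-sequence EOS mask instead.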