BaggerOfWords committed on
Commit 743ec89 · 1 Parent(s): 2812121

add app file and mosaic file

Files changed (2)
  1. gradio_app.py +94 -0
  2. mosaic.py +344 -0
gradio_app.py ADDED
@@ -0,0 +1,94 @@
+ import gradio as gr
+ from mosaic import Mosaic  # adjust import as needed
+
+ # Maximum number of model textboxes
+ MAX_MODELS = 10
+
+ def update_textboxes(n_visible):
+     """
+     Given the current visible count, increment it by 1 (up to MAX_MODELS)
+     and return updated visibility settings for all model textboxes.
+     """
+     if n_visible < MAX_MODELS:
+         n_visible += 1
+     # Create an update object for each textbox: visible if its index is less than n_visible.
+     updates = []
+     for i in range(MAX_MODELS):
+         if i < n_visible:
+             updates.append(gr.update(visible=True))
+         else:
+             updates.append(gr.update(visible=False))
+     return n_visible, *updates
+
+ def run_scoring(input_text, model1, model2, model3, model4, model5, model6, model7, model8, model9, model10, threshold_choice, custom_threshold):
+     """
+     Collect all non-empty model paths, instantiate Mosaic, compute the score,
+     and return a message based on the threshold.
+     """
+     model_paths = []
+     for m in [model1, model2, model3, model4, model5, model6, model7, model8, model9, model10]:
+         if m.strip() != "":
+             model_paths.append(m.strip())
+     if len(model_paths) < 2:
+         return "Please enter at least two model paths.", None, None
+     # Choose the threshold value.
+     if threshold_choice == "default":
+         threshold = 0.0
+     elif threshold_choice == "raid":
+         threshold = 0.23
+     elif threshold_choice == "custom":
+         threshold = custom_threshold
+     else:
+         threshold = 0.0
+     # Instantiate the Mosaic class with the selected model paths.
+     mosaic_instance = Mosaic(model_name_or_paths=model_paths, one_model_mode=False)
+     final_score = mosaic_instance.compute_end_score(input_text)
+     if final_score < threshold:
+         result_message = "This text was probably generated."
+     else:
+         result_message = "This text is likely human-written."
+     return result_message, final_score, threshold
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# MOSAIC Scoring App")
+     with gr.Row():
+         input_text = gr.Textbox(lines=10, placeholder="Enter text here...", label="Input Text")
+     with gr.Column():
+         gr.Markdown("### Model Paths (at least 2 required)")
+         gr.Markdown("Order matters only for model 1, the reference model: use the model with the best perplexity on human text (the largest LLM, if applicable). GPT-2 models are enough to detect easy ChatGPT outputs.")
+         # State to keep track of the number of visible textboxes (starting with 2)
+         n_models_state = gr.State(2)
+         # Create 10 textboxes, named model1 through model10.
+         model1 = gr.Textbox(value="openai-community/gpt2-large", label="Model 1 Path", visible=True)
+         model2 = gr.Textbox(value="openai-community/gpt2-medium", label="Model 2 Path", visible=True)
+         model3 = gr.Textbox(value="", label="Model 3 Path", visible=False)
+         model4 = gr.Textbox(value="", label="Model 4 Path", visible=False)
+         model5 = gr.Textbox(value="", label="Model 5 Path", visible=False)
+         model6 = gr.Textbox(value="", label="Model 6 Path", visible=False)
+         model7 = gr.Textbox(value="", label="Model 7 Path", visible=False)
+         model8 = gr.Textbox(value="", label="Model 8 Path", visible=False)
+         model9 = gr.Textbox(value="", label="Model 9 Path", visible=False)
+         model10 = gr.Textbox(value="", label="Model 10 Path", visible=False)
+         # Add a plus button to reveal one more textbox.
+         plus_button = gr.Button("+", elem_id="plus_button")
+         # When plus_button is clicked, update n_models_state and all model textboxes.
+         plus_button.click(
+             fn=update_textboxes,
+             inputs=n_models_state,
+             outputs=[n_models_state, model1, model2, model3, model4, model5, model6, model7, model8, model9, model10]
+         )
+     with gr.Row():
+         threshold_choice = gr.Radio(choices=["default", "raid", "custom"], value="default", label="Threshold Choice")
+         custom_threshold = gr.Number(value=0.0, label="Custom Threshold (if 'custom' selected)")
+     with gr.Row():
+         output_message = gr.Textbox(label="Result Message")
+         output_score = gr.Number(label="Final Score")
+         output_threshold = gr.Number(label="Threshold Used")
+     run_button = gr.Button("Run Scoring")
+     run_button.click(
+         fn=run_scoring,
+         inputs=[input_text, model1, model2, model3, model4, model5, model6, model7, model8, model9, model10, threshold_choice, custom_threshold],
+         outputs=[output_message, output_score, output_threshold]
+     )
+
+ demo.launch()
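For a quick smoke test outside the UI, run_scoring can also be called directly. Below is a minimal sketch, assuming the two default GPT-2 paths above and the "default" threshold; note that importing gradio_app also executes demo.launch() at the bottom of the file, so it may be simpler to copy the function into a scratch script first.

from gradio_app import run_scoring

message, score, threshold = run_scoring(
    "Paste the text to score here...",   # input_text
    "openai-community/gpt2-large",       # model 1: the reference model
    "openai-community/gpt2-medium",      # model 2
    "", "", "", "", "", "", "", "",      # models 3-10 left empty
    "default",                           # threshold_choice
    0.0,                                 # custom_threshold (ignored unless "custom")
)
print(message, score, threshold)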
mosaic.py ADDED
@@ -0,0 +1,344 @@
+ from typing import List, Optional
+ import numpy as np
+ import torch
+ import transformers
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch.nn.functional as F
+
+ torch.set_grad_enabled(False)
+
+ def apply_top_p_with_epsilon(logits: torch.Tensor, top_p: float, epsilon: float = 1e-10) -> torch.Tensor:
+     """
+     Applies top-p (nucleus) filtering to logits but, instead of setting the logits of
+     non-selected tokens to -inf (which would give them zero probability), sets them to
+     log(epsilon), so that the support stays the same.
+
+     Parameters:
+         logits: Tensor of shape (batch, seq_len, vocab_size)
+         top_p: The nucleus threshold (e.g. 0.7, 0.8, etc.)
+         epsilon: The small value assigned to tokens outside the nucleus.
+
+     Returns:
+         new_logits: Tensor with the same shape as logits.
+     """
+     # Compute probabilities from logits.
+     probs = F.softmax(logits, dim=-1)
+     # Sort probabilities (descending) along the vocabulary dimension.
+     sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
+     # Compute the cumulative sum of the sorted probabilities.
+     cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+     # Create a mask in sorted order: keep tokens while cumulative_probs <= top_p.
+     keep_mask = cumulative_probs <= top_p
+
+     # Ensure that at least one token is kept per position: if none are kept, keep the top one.
+     no_token_kept = keep_mask.sum(dim=-1, keepdim=True) == 0
+     if no_token_kept.any():
+         # For positions where no token was kept, set the first (highest-probability) token to True.
+         fix_mask = torch.zeros_like(keep_mask, dtype=torch.bool)
+         fix_mask.scatter_(-1, torch.zeros_like(keep_mask[..., :1], dtype=torch.long), True)
+         keep_mask = torch.where(no_token_kept, fix_mask, keep_mask)
+
+     # Map the keep mask back from sorted order to the original vocabulary order.
+     keep_mask = torch.zeros_like(keep_mask).scatter_(-1, sorted_indices, keep_mask)
+
+     # Copy the original logits, then set tokens outside the nucleus to log(epsilon).
+     new_logits = logits.clone()
+     new_logits[~keep_mask] = torch.log(torch.tensor(epsilon, device=logits.device, dtype=logits.dtype))
+     return new_logits
+
+ class Mosaic(object):
+     def __init__(self,
+                  model_name_or_paths: List[str],
+                  use_bfloat16: bool = True,
+                  max_token_observed: int = 512,
+                  unigram: Optional[str] = None,
+                  custom_config: Optional[List[bool]] = None,
+                  stupid_mode: bool = False,
+                  one_model_mode: bool = False
+                  ) -> None:
+         self.models = []
+         for i, model_name_or_path in enumerate(model_name_or_paths):
+             model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
+                                                          device_map="auto",
+                                                          trust_remote_code=True,
+                                                          torch_dtype=torch.bfloat16 if use_bfloat16
+                                                          else torch.float32
+                                                          )
+             model.eval()  # Set the model to evaluation mode
+             self.models.append(model)
+             print(f"Loaded model: {model_name_or_path}")
+             # Print the device map if needed:
+             # print(f"Device map for {model_name_or_path}: {model.hf_device_map}")
+
+         if stupid_mode:
+             self.max_iters = 0
+         else:
+             self.max_iters = 1000
+
+         self.one_model_mode = one_model_mode
+
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_paths[-1])
+         if not self.tokenizer.pad_token:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         self.max_token_observed = max_token_observed
+
+         self.nb_models = len(self.models)
+         self.unigram_path = unigram
+
+         if custom_config is None:
+             custom_config = [False] * self.nb_models
+         self.custom_config = custom_config
+
+     def _tokenize(self, batch: List[str]) -> transformers.BatchEncoding:
+         encodings = self.tokenizer(
+             batch,
+             return_tensors="pt",
+             padding="longest",
+             truncation=True,
+             max_length=self.max_token_observed,
+             return_token_type_ids=False)
+         return encodings
+
+     def trim_logits(self, logits, max_length=32000):
+         # Check the shape of the logits tensor
+         if logits.shape[2] > max_length:
+             # Keep only the first max_length entries along the vocabulary dimension
+             logits = logits[:, :, :max_length]
+         return logits
+
+     @torch.inference_mode()
+     def _get_logits(self, encodings: transformers.BatchEncoding) -> List[torch.Tensor]:
+         # If one_model_mode is active, simulate multiple models by applying top-p with different thresholds.
+         if self.one_model_mode:
+             # Compute base logits from the single model.
+             model = self.models[0]
+             device = next(model.parameters()).device
+             model_encodings = encodings.to(device)
+             base_logits = model(**model_encodings).logits
+             # Optionally trim logits:
+             # base_logits = self.trim_logits(base_logits)
+             # Define the top-p thresholds (e.g., four different values)
+             top_p_values = [0.7, 0.8, 0.9, 0.95]
+             # Epsilon value for non-selected tokens (adjust if needed)
+             epsilon = 1e-10
+             logits_list = []
+             for top_p in top_p_values:
+                 warped_logits = apply_top_p_with_epsilon(base_logits, top_p, epsilon)
+                 logits_list.append(warped_logits)
+         else:
+             # Normal mode: use each model in self.models.
+             logits_list = []
+             for i, model in enumerate(self.models):
+                 device = next(model.parameters()).device
+                 model_encodings = encodings.to(device)
+                 logits = model(**model_encodings).logits
+                 # Optionally trim logits:
+                 # logits = self.trim_logits(logits)
+                 logits_list.append(logits)
+                 if device.type == "cuda":
+                     torch.cuda.synchronize(device)
+
+         if self.unigram_path:
+             batch_size, seq_len, voc_size = logits_list[0].shape
+             unigram_proba = torch.load(self.unigram_path)
+             unigram_proba += 1e-10
+             unigram_logits = torch.log(unigram_proba)
+             # Optionally center logits if needed:
+             logits = logits_list[0] - logits_list[0].mean(dim=-1, keepdim=True)
+             expanded_unigram_logits = unigram_logits.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, voc_size)
+             logits_list.append(expanded_unigram_logits)
+         return logits_list
+
+     def get_softmax_probabilities(self, input_text):
+         encodings = self._tokenize(input_text)
+         logits_list = self._get_logits(encodings)
+         probabilities_list = softmax_probabilities_all_models(logits_list)
+         return encodings, logits_list, probabilities_list
+
+     def compute_arimoto_torch(self, input_text, max_iters=1000):
+         encodings, logits_list, tensors_list = self.get_softmax_probabilities(input_text)
+         nb_models = len(tensors_list)
+         seq_len = len(encodings.input_ids[0])
+         voc_size = tensors_list[0].shape[-1]
+
+         device = tensors_list[0].device
+         # Move all tensors in tensors_list to the device of the first tensor
+         tensors_list = [tensor.to(device) for tensor in tensors_list]
+
+         # Stack the model predictions along a new dimension to form a (seq_len, nb_models, voc_size) tensor
+         probabilities_tensor = torch.stack([t[0] for t in tensors_list], dim=1).to(tensors_list[0].device)
+
+         # Run the Blahut-Arimoto algorithm on the whole sequence at once
+         capacity, p = blahut_arimoto_torch(probabilities_tensor, max_iters=max_iters)
+
+         # Prepare the weighted-sum tensor, initially zeros
+         weighted_sum_tensor = torch.zeros_like(tensors_list[0])
+
+         # 'p' has shape (seq_len, nb_models): apply each model's per-position weight to its probabilities
+         for i in range(nb_models):
+             weighted_sum_tensor += p[:, i:i+1] * tensors_list[i]
+
+         return encodings, weighted_sum_tensor, tensors_list, p, logits_list
+
+     def compute_scores(self, input_text):
+         encodings, weighted_sum_tensor, probabilities_list, arimoto_weights, logits_list = self.compute_arimoto_torch(input_text, max_iters=self.max_iters)
+         log_ppl, ppl, nll = perplexity(encodings, weighted_sum_tensor)
+         ppl_list = perplexity_all_models(encodings, logits_list)
+         x_ppl_list = cross_entropy(weighted_sum_tensor, probabilities_list)
+         return log_ppl, x_ppl_list, arimoto_weights, nll, ppl_list
+
+     def compute_end_score(self, input_text):
+         encodings, weighted_sum_tensor, probabilities_list, arimoto_weights, logits_list = self.compute_arimoto_torch(input_text)
+         log_ppl, ppl, nll = perplexity(encodings, weighted_sum_tensor)
+         ppl_list = perplexity_all_models(encodings, logits_list)
+         x_ppl_list = cross_entropy(weighted_sum_tensor, probabilities_list)
+         log_ppl_value = log_ppl.item()
+         x_ppl_values = [x.item() for x in x_ppl_list]
+         final_score = log_ppl_value - x_ppl_values[0]  # The "reference model" must be given as the first argument
+         return final_score
+
+ def perplexity(encodings, weighted_sum_tensor):
+     shifted_probabilities = weighted_sum_tensor[..., :-1, :].contiguous()
+     shifted_labels = encodings.input_ids[..., 1:].contiguous()
+     shifted_attention_mask = encodings.attention_mask[..., 1:].contiguous()
+
+     device = shifted_probabilities.device
+
+     # Ensure all tensors are on the same device
+     shifted_probabilities = shifted_probabilities.to(device)
+     shifted_labels = shifted_labels.to(device)
+     shifted_attention_mask = shifted_attention_mask.to(device)
+
+     actual_next_token_probabilities = torch.gather(shifted_probabilities, 2, shifted_labels.unsqueeze(-1)).squeeze(-1)
+
+     nll = -torch.log(actual_next_token_probabilities + 1e-12)
+     nll_masked = nll * shifted_attention_mask
+
+     # Average NLL per sequence, counting only the valid (non-padded) tokens
+     average_nll = torch.sum(nll_masked, dim=1) / torch.sum(shifted_attention_mask, dim=1)
+
+     # Perplexity per sequence
+     perplexity = torch.exp(average_nll)
+     return average_nll, perplexity, nll_masked
+
+ def cross_entropy(weighted_sum_tensor, probabilities_list):
+     device = weighted_sum_tensor.device
+     x_ppl_list = []
+
+     # The log of the weighted mixture is computed once, outside the loop
+     log_M1 = torch.log(weighted_sum_tensor).to(device)
+
+     for m2_probabilities in probabilities_list:
+         m2_probabilities = m2_probabilities.to(device)
+         # log_M1 has shape (batch_size, sequence_length, vocabulary_size);
+         # transpose m2_probabilities to (batch_size, vocabulary_size, sequence_length) for bmm
+         m2_probabilities_transposed = m2_probabilities.transpose(1, 2)
+
+         # Batch matrix multiplication; resulting shape: (batch_size, sequence_length, sequence_length).
+         # Summing over the vocabulary dimension gives the dot product for every pair of positions.
+         dot_products = torch.bmm(log_M1, m2_probabilities_transposed)
+
+         # Only the dot products of corresponding positions are needed, i.e. the diagonal;
+         # einsum extracts the diagonal for every item in the batch.
+         dot_products_diagonal = torch.einsum('bii->bi', dot_products)
+
+         # Mean dot product per sequence, negated
+         x_ppl = -torch.mean(dot_products_diagonal, dim=1)
+
+         x_ppl_list.append(x_ppl)
+     x_ppl_tensor = torch.stack(x_ppl_list)
+     return x_ppl_list  # , x_ppl_tensor
+
+ def softmax_probabilities_all_models(logits_list: List[torch.Tensor]) -> List[torch.Tensor]:
+     """
+     Calculates the softmax probabilities over the entire sequence of tokens for each model.
+
+     Parameters:
+     - logits_list: List[torch.Tensor]
+         A list containing the logits tensor for each model.
+
+     Returns:
+     - List[torch.Tensor]: A list of tensors, where each tensor holds the softmax probabilities
+       for one model across the entire sequence of tokens.
+     """
+     softmax_fn = torch.nn.Softmax(dim=-1)
+     probabilities_list = []
+
+     for logits in logits_list:
+         # Softmax over the vocabulary for each token position
+         softmax_probabilities = softmax_fn(logits)
+         probabilities_list.append(softmax_probabilities)
+
+     return probabilities_list
+
+ def perplexity_logits(encoding, logits):
+     # Ensure the encoding tensors are on the same device as the logits
+     device = logits.device
+     logits = torch.clamp(logits, min=-20, max=50)
+
+     encoding_input_ids = encoding.input_ids.to(device)
+     encoding_attention_mask = encoding.attention_mask.to(device)
+
+     ce_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
+     shifted_logits = logits[..., :-1, :].contiguous()
+     shifted_labels = encoding_input_ids[..., 1:].contiguous()
+     shifted_attention_mask = encoding_attention_mask[..., 1:].contiguous()
+
+     # Per-token cross-entropy loss
+     cross_entropy_loss = ce_loss_fn(shifted_logits.transpose(1, 2), shifted_labels)
+     # Apply the attention mask
+     masked_ce_loss = cross_entropy_loss * shifted_attention_mask
+     # Average loss per sequence (log-perplexity)
+     ppl = masked_ce_loss.sum(1) / shifted_attention_mask.sum(1)
+     # Move the result to CPU and convert to numpy for further processing if needed
+     ppl = ppl.to("cpu").float().numpy()
+
+     return ppl
+
+ def perplexity_all_models(encoding, logits_list):
+     ppl_list = []
+     for logits in logits_list:
+         ppl = perplexity_logits(encoding, logits)
+         ppl_list.append(ppl)
+     return ppl_list
+
+ def blahut_arimoto_torch(W, epsilon=1e-6, max_iters=1000):
+     """
+     Batched Blahut-Arimoto in PyTorch: solves one channel-capacity problem per sequence position.
+     """
+     seq_len, nb_models, voc_size = W.shape
+     p = torch.full((seq_len, nb_models), 1.0 / nb_models, device=W.device, dtype=W.dtype)
+     prod_exp = torch.ones((seq_len, nb_models), device=W.device, dtype=W.dtype)
+
+     for _ in range(max_iters):
+         # Marginal probabilities over the vocabulary
+         sum_p_w = torch.bmm(p.unsqueeze(1), W).squeeze(1)  # resulting shape: (seq_len, voc_size)
+
+         # Normalized channel probabilities
+         W_normalized = W / sum_p_w.unsqueeze(1)  # broadcast to shape (seq_len, nb_models, voc_size)
+
+         # Avoid numerical issues with logarithms
+         W_normalized[W_normalized == 0] = torch.finfo(W.dtype).eps
+         log_term = torch.log(W_normalized)
+         log_term[torch.isnan(log_term) | torch.isinf(log_term)] = 0
+
+         # Compute the product exponentials and update the input distribution
+         prod_exp = torch.exp(torch.sum(W * log_term, dim=2))  # sum over voc_size
+         p_new = (p * prod_exp) / torch.sum(p * prod_exp, dim=1, keepdim=True)
+
+         # Check convergence
+         if torch.max(torch.abs(p - p_new)) < epsilon:
+             break
+         p = p_new
+
+     # Channel capacity in bits
+     capacity = torch.log(torch.sum(p * prod_exp, dim=1)) / torch.log(torch.tensor(2.0, device=W.device))
+     return capacity, p
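The Gradio app is only a thin wrapper around this class. Below is a minimal sketch of using Mosaic directly, assuming the same two GPT-2 checkpoints as in the app; with the app's default threshold of 0.0, scores below the threshold are treated as machine-generated.

from mosaic import Mosaic

detector = Mosaic(
    model_name_or_paths=[
        "openai-community/gpt2-large",   # reference model first (best perplexity on human text)
        "openai-community/gpt2-medium",
    ],
    one_model_mode=False,
)
score = detector.compute_end_score("Paste the text to score here...")
print(score, "generated" if score < 0.0 else "human")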