"""Interactive Gradio demo of temperature, top-k and top-p (nucleus) filtering.

A fixed random distribution over a small toy vocabulary is rescaled and
filtered according to three sliders, and the result is drawn as a bar chart.
"""
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr
from matplotlib.figure import Figure


def get_initial_distribution(seed=42, num_tokens=10):
    """Return a reproducible random probability distribution over ``num_tokens`` tokens.

    A private ``RandomState`` is used so the global NumPy RNG is left
    untouched; it yields exactly the same values as
    ``np.random.seed(seed); np.random.rand(num_tokens)``.

    Args:
        seed: Seed for reproducibility.
        num_tokens: Number of tokens in the toy vocabulary (default 10).

    Returns:
        1-D float array of length ``num_tokens`` that sums to 1.
    """
    rng = np.random.RandomState(seed)
    token_probs = rng.rand(num_tokens)
    token_probs /= np.sum(token_probs)  # Normalize to sum to 1
    return token_probs


def _apply_temperature(probs, temperature):
    """Rescale ``probs`` as softmax(log(probs) / temperature).

    Raises:
        ValueError: if ``temperature`` is not positive (division by zero /
            sign flip would make the result meaningless).
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    # log(0) -> -inf is harmless here (exp(-inf) == 0); silence the warning.
    with np.errstate(divide="ignore"):
        logits = np.log(probs) / temperature
    # Shift so the largest logit is 0: its exp is exactly 1, so the sum below
    # can never underflow to 0 even for very small temperatures.
    logits -= np.max(logits)
    scaled = np.exp(logits)
    return scaled / np.sum(scaled)


def _apply_top_k(probs, top_k):
    """Zero all but the ``top_k`` most likely tokens and renormalize.

    ``top_k <= 0`` disables the filter; values larger than the vocabulary
    keep every token.
    """
    top_k = int(top_k)  # sliders may deliver floats; slicing requires an int
    if top_k <= 0:
        return probs
    keep = np.argsort(probs)[-top_k:]
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / np.sum(filtered)  # Normalize after filtering


def _apply_top_p(probs, top_p):
    """Keep the smallest set of most-likely tokens whose mass reaches ``top_p``.

    ``top_p >= 1`` disables the filter; ``top_p <= 0`` keeps only the single
    most likely token.
    """
    if top_p >= 1.0:
        return probs
    # Sort probabilities in descending order and compute the running total.
    order = np.argsort(probs)[::-1]
    cumulative = np.cumsum(probs[order])
    # searchsorted (side='left') finds the first position whose running total
    # is >= top_p; +1 keeps that token too.
    cutoff = np.searchsorted(cumulative, top_p) + 1
    keep = order[:cutoff]
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / np.sum(filtered)  # Normalize after filtering


def _plot_distribution(token_probs):
    """Return a bar-chart ``Figure`` of ``token_probs`` (one bar per token)."""
    # A bare Figure is not registered with pyplot, so it is garbage-collected
    # once Gradio has rendered it instead of accumulating on every update.
    fig = Figure(figsize=(10, 6))
    ax = fig.add_subplot()
    n_tokens = len(token_probs)
    ax.bar(range(n_tokens), token_probs,
           tick_label=[f'Token {i}' for i in range(n_tokens)])
    ax.set_xlabel('Tokens')
    ax.set_ylabel('Probabilities')
    ax.set_title('Token Probability Distribution')
    ax.set_ylim(0, 1)
    ax.grid(True)
    fig.tight_layout()
    return fig


def adjust_distribution(temperature, top_k, top_p, initial_probs):
    """Apply temperature, top-k and top-p filtering, then plot the result.

    Args:
        temperature: Positive scaling factor; < 1 sharpens, > 1 flattens.
        top_k: Keep only the ``top_k`` most likely tokens (``<= 0`` disables).
        top_p: Nucleus threshold in [0, 1] (``>= 1`` disables).
        initial_probs: 1-D array of token probabilities summing to 1.

    Returns:
        A ``matplotlib.figure.Figure`` bar chart of the filtered distribution.

    Raises:
        ValueError: if ``temperature <= 0``.
    """
    token_probs = _apply_temperature(np.asarray(initial_probs, dtype=float), temperature)
    token_probs = _apply_top_k(token_probs, top_k)
    token_probs = _apply_top_p(token_probs, top_p)
    return _plot_distribution(token_probs)


initial_probs = get_initial_distribution()


def update_plot(temperature, top_k, top_p):
    """Gradio callback: re-plot the fixed initial distribution with new settings."""
    return adjust_distribution(temperature, top_k, top_p, initial_probs)


interface = gr.Interface(
    fn=update_plot,
    inputs=[
        gr.Slider(0.1, 2.0, step=0.1, value=1.0, label="Temperature"),
        gr.Slider(0, 10, step=1, value=5, label="Top-k"),
        gr.Slider(0.0, 1.0, step=0.01, value=0.9, label="Top-p"),
    ],
    outputs=gr.Plot(label="Token Probability Distribution"),
    live=True,
)

if __name__ == "__main__":
    # Only start the server when run as a script, not on import.
    interface.launch()