inechita's picture
Upload 2 files
dd1edb0 verified
raw
history blame
2.65 kB
import numpy as np
import matplotlib.pyplot as plt
def generate_random_complex_gaussian_matrix(n):
    """Return an (n, n, n) complex128 array of standard complex Gaussian samples.

    Every component is ``(x + 1j * y) / sqrt(2)`` with ``x`` and ``y`` drawn
    from the global ``np.random.normal`` stream, so each component has
    ``E|z|^2 = 1``. The vector at ``[row, col, :]`` is filled in row-major
    order of (row, col): first its n real parts are drawn, then its n
    imaginary parts.
    """
    out = np.empty((n, n, n), dtype=np.complex128)
    scale = np.sqrt(2)
    for row, col in np.ndindex(n, n):
        re = np.random.normal(size=n)
        im = np.random.normal(size=n)
        out[row, col] = (re + 1j * im) / scale
    return out
def orthonormalize_vectors_svd(vectors):
    """Replace all singular values of ``vectors`` by 1.

    Computes the thin SVD ``vectors = U @ diag(s) @ Vh`` and returns ``U @ Vh``.
    For a square input this is the unitary factor of its polar decomposition.
    For an (m, k) input with m <= k, the result has orthonormal rows and is the
    nearest such matrix in Frobenius norm. With m > k, it has orthonormal
    columns instead.

    Args:
        vectors: array of shape (..., m, k); each row is one vector. Any
            leading axes are treated as a batch (``np.linalg.svd`` stacking).

    Returns:
        Array with the same shape as ``vectors``.
    """
    # full_matrices=False yields U of shape (..., m, r) and Vh of shape
    # (..., r, k) with r = min(m, k), so the product is defined for rectangular
    # input too. The default full_matrices=True gives (m, m) @ (k, k), which
    # fails whenever m != k.
    u, _, vh = np.linalg.svd(vectors, full_matrices=False)
    return u @ vh
def error_QPM(u):
    """Return how far ``u`` is from being a quantum permutation matrix.

    ``u[i, j, :]`` is the vector stored in matrix entry (i, j). A row slice is
    ``u[i, :, :]`` and a column slice is ``u[:, j, :]``. Treat each slice as a
    stack ``S`` of vectors, one per row. For each slice, the deviation
    ``||S @ S^H - I||_F`` is computed; it is zero exactly when the vectors of
    that slice are orthonormal. The largest deviation over all slices is
    returned.

    Args:
        u: array of shape (rows, cols, dim).

    Returns:
        The maximum deviation (non-negative); 0.0 when ``u`` has no slices.
    """

    def _slice_deviation(vectors):
        # Frobenius norm of (Gram matrix of the rows of ``vectors`` - identity).
        gram = vectors @ vectors.conj().T
        return np.linalg.norm(gram - np.eye(vectors.shape[0]))

    # Every deviation is a norm (>= 0). Starting from 0.0 gives the correct
    # result for an empty array, instead of an impossible negative error.
    err = 0.0
    for i in range(u.shape[0]):
        err = max(err, _slice_deviation(u[i, :, :]))
    # Column slices run along axis 1, so iterate over u.shape[1]. This equals
    # u.shape[0] only when the matrix of vectors is square.
    for j in range(u.shape[1]):
        err = max(err, _slice_deviation(u[:, j, :]))
    return err
def random_quantum_permutation_matrix(n, error_tolerance=1e-6, max_iter=1000):
    """Sample an approximate n x n quantum permutation matrix by alternating projections.

    The process works as follows:

    - Start from ``generate_random_complex_gaussian_matrix(n)``, an (n, n, n)
      array whose entry ``[i, j, :]`` is a vector in C^n.
    - Repeatedly replace every row slice ``u[i]`` and then every column slice
      ``u[:, j]`` by ``orthonormalize_vectors_svd`` of that slice.
    - Stop once ``error_QPM`` is no longer above ``error_tolerance``, or after
      ``max_iter`` sweeps.

    Returns:
        (u, errors): the final array and the list of error values. The list
        starts with the error of the random starting array and gains one entry
        per sweep.
    """
    u = generate_random_complex_gaussian_matrix(n)
    sweeps = 0
    current = error_QPM(u)
    history = [current]
    while current > error_tolerance and sweeps < max_iter:
        # Project each row slice (the n vectors u[row, 0..n-1]).
        for row in range(n):
            u[row] = orthonormalize_vectors_svd(u[row])
        # Then project each column slice (the n vectors u[0..n-1, col]).
        for col in range(n):
            u[:, col] = orthonormalize_vectors_svd(u[:, col])
        current = error_QPM(u)
        history.append(current)
        sweeps += 1
    return u, history
# Example usage: sample one 6 x 6 quantum permutation matrix and keep its
# per-sweep error history for plotting.
u, errors = random_quantum_permutation_matrix(6)

# Two panels side by side: convergence on the left, histogram on the right.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))

# Left panel: error history with a logarithmic y-axis.
error_ax = axs[0]
error_ax.semilogy(errors)
error_ax.set(xlabel='Iteration', ylabel='Error (log scale)', title='Error Convergence')
error_ax.grid(True)
# Squared overlaps |<u[i, j], u[k, l]>|^2 for every ordered pair of entry
# vectors (i, j), (k, l), listed in i, j, k, l loop order.
# Every index covers its full range, including 0. Stacking the entry vectors
# as the rows of a (rows * cols, dim) matrix V turns all scalar products into
# one Gram matrix V @ V^H, whose element [a, b] equals np.dot(V[a], V[b].conj()).
# The list includes the self-pairs (value 1) and the same-row / same-column
# pairs (value ~0 once u has converged).
n_rows, n_cols, vec_dim = u.shape
entry_vectors = u.reshape(n_rows * n_cols, vec_dim)
gram = entry_vectors @ entry_vectors.conj().T
scalar_products = (np.abs(gram) ** 2).ravel().tolist()
# Right panel: distribution of the squared overlaps computed above.
hist_ax = axs[1]
hist_ax.hist(scalar_products, bins=30, edgecolor='black')
hist_ax.set(
    xlabel='Absolute Values Squared of Scalar Products',
    ylabel='Frequency',
    title='Histogram of Absolute Values Squared of Scalar Products',
)
hist_ax.grid(True)

# Re-layout so labels and titles of the two panels do not overlap, then display.
fig.tight_layout()
plt.show()