Spaces:
Sleeping
Sleeping
"""Matplotlib based plotting of quantum circuits. | |
Todo: | |
* Optimize printing of large circuits. | |
* Get this to work with single gates. | |
* Do a better job checking the form of circuits to make sure it is a Mul of | |
Gates. | |
* Get multi-target gates plotting. | |
* Get initial and final states to plot. | |
* Get measurements to plot. Might need to rethink measurement as a gate | |
issue. | |
* Get scale and figsize to be handled in a better way. | |
* Write some tests/examples! | |
""" | |
from __future__ import annotations | |
from sympy.core.mul import Mul | |
from sympy.external import import_module | |
from sympy.physics.quantum.gate import Gate, OneQubitGate, CGate, CGateS | |
# Names exported by ``from <this module> import *``.
__all__ = [
    'CircuitPlot',
    'circuit_plot',
    'labeller',
    'Mz',
    'Mx',
    'CreateOneQubitGate',
    'CreateCGate',
]

# numpy and matplotlib are optional: both are loaded through import_module
# and truth-tested before use (here and in CircuitPlot.__init__).
# NOTE(review): presumably import_module returns a falsy value when the
# package is missing -- confirm against sympy.external.
np = import_module('numpy')
matplotlib = import_module(
    'matplotlib', import_kwargs={'fromlist': ['pyplot']},
    catch=(RuntimeError,))  # This is raised in environments that have no display.

# Bind the matplotlib names used by the drawing code only when both optional
# packages were loaded; otherwise these names stay undefined.
if np and matplotlib:
    pyplot = matplotlib.pyplot
    Line2D = matplotlib.lines.Line2D
    Circle = matplotlib.patches.Circle

#from matplotlib import rc
#rc('text',usetex=True)
class CircuitPlot:
    """A class for managing a circuit plot.

    Creating an instance performs the whole drawing pipeline: the wire and
    gate grids are computed, a matplotlib figure is created, and the wires
    and gates are drawn onto it.

    The class attributes below are drawing defaults. Any of them can be
    overridden for a single instance by passing a keyword argument of the
    same name to the constructor (see :meth:`update`).
    """

    scale = 1.0            # grid spacing between neighbouring wires / gates
    fontsize = 20.0        # text size for wire labels and gate boxes
    linewidth = 1.0        # line width for wires, control lines and outlines
    control_radius = 0.05  # radius of a control dot (multiplied by scale)
    not_radius = 0.15      # radius of the NOT circle
    swap_delta = 0.05      # half-width of the swap cross
    # These two containers are shared class-level defaults. This class only
    # reads them; per-instance values arrive through ``update``.
    labels: list[str] = []       # one label per wire, indexed like the wires
    inits: dict[str, str] = {}   # label -> initial state shown next to it
    label_buffer = 0.5     # horizontal gap between wire start and label

    def __init__(self, c, nqubits, **kwargs):
        """Build and draw the plot for circuit ``c`` on ``nqubits`` wires.

        Parameters
        ==========

        c : circuit
            The circuit to plot; its ``args`` are counted to size the plot.
        nqubits : int
            The number of wires to draw.
        **kwargs
            Overrides for the class-level drawing defaults.

        Raises
        ======

        ImportError
            If numpy or matplotlib is not available.
        """
        if not np or not matplotlib:
            raise ImportError('numpy or matplotlib not available.')
        self.circuit = c
        # NOTE(review): this counts every arg of the circuit, which can
        # differ from len(self._gates()) (non-Gate factors, a single Gate).
        # Left unchanged because callers may rely on the value.
        self.ngates = len(self.circuit.args)
        self.nqubits = nqubits
        self.update(kwargs)
        self._create_grid()
        self._create_figure()
        self._plot_wires()
        self._plot_gates()
        self._finish()

    def update(self, kwargs):
        """Load the kwargs into the instance dict."""
        self.__dict__.update(kwargs)

    def _create_grid(self):
        """Create the grid of wires.

        Stores the wire positions in ``_wire_grid`` and the gate positions in
        ``_gate_grid``, both spaced ``scale`` apart starting at 0.
        """
        scale = self.scale
        self._wire_grid = np.arange(
            0.0, self.nqubits*scale, scale, dtype=float)
        self._gate_grid = np.arange(
            0.0, self.ngates*scale, scale, dtype=float)

    def _create_figure(self):
        """Create the main matplotlib figure.

        The figure is sized from the gate and wire counts, and a single
        axes object with hidden axis decorations is stored in ``_axes``.
        """
        self._figure = pyplot.figure(
            figsize=(self.ngates*self.scale, self.nqubits*self.scale),
            facecolor='w',
            edgecolor='w'
        )
        ax = self._figure.add_subplot(
            1, 1, 1,
            frameon=True
        )
        ax.set_axis_off()
        # Leave half a grid cell of margin around the outermost grid points.
        offset = 0.5*self.scale
        ax.set_xlim(self._gate_grid[0] - offset, self._gate_grid[-1] + offset)
        ax.set_ylim(self._wire_grid[0] - offset, self._wire_grid[-1] + offset)
        ax.set_aspect('equal')
        self._axes = ax

    def _plot_wires(self):
        """Plot the wires of the circuit diagram.

        Each wire extends one ``scale`` unit beyond the first and last gate.
        When ``labels`` is non-empty, the label of each wire is drawn to the
        left of it, shifted further left when the label has an entry in
        ``inits``. Measured wires are doubled afterwards.
        """
        xstart = self._gate_grid[0]
        xstop = self._gate_grid[-1]
        xdata = (xstart - self.scale, xstop + self.scale)
        for i in range(self.nqubits):
            ydata = (self._wire_grid[i], self._wire_grid[i])
            line = Line2D(
                xdata, ydata,
                color='k',
                lw=self.linewidth
            )
            self._axes.add_line(line)
            if self.labels:
                # Extra room for the longer "label = init" text.
                init_label_buffer = 0
                if self.inits.get(self.labels[i]):
                    init_label_buffer = 0.25
                self._axes.text(
                    xdata[0] - self.label_buffer - init_label_buffer,
                    ydata[0],
                    render_label(self.labels[i], self.inits),
                    size=self.fontsize,
                    color='k', ha='center', va='center')
        self._plot_measured_wires()

    def _plot_measured_wires(self):
        """Draw a second, slightly shifted line for measured wires.

        A measured wire is doubled from the gate that measures it to the end
        of the wire. Controlled gates that come after the measurement of one
        of their wires also get a doubled vertical control line.
        """
        ismeasured = self._measurements()
        xstop = self._gate_grid[-1]
        dy = 0.04  # amount to shift wires when doubled

        # Plot doubled wires after they are measured
        for wire, gate_idx in ismeasured.items():
            y = self._wire_grid[wire] + dy
            line = Line2D(
                (self._gate_grid[gate_idx], xstop + self.scale), (y, y),
                color='k',
                lw=self.linewidth
            )
            self._axes.add_line(line)

        # Also double any controlled lines off these wires
        for i, g in enumerate(self._gates()):
            if not isinstance(g, (CGate, CGateS)):
                continue
            wires = g.controls + g.targets
            follows_measurement = any(
                wire in ismeasured and
                self._gate_grid[i] > self._gate_grid[ismeasured[wire]]
                for wire in wires
            )
            if not follows_measurement:
                continue
            # Map wire indices to grid coordinates so the doubled line stays
            # aligned with control_line() for any scale; one line per gate is
            # enough even if several of its wires were measured.
            x = self._gate_grid[i] - dy
            ydata = (self._wire_grid[min(wires)], self._wire_grid[max(wires)])
            line = Line2D(
                (x, x), ydata,
                color='k',
                lw=self.linewidth
            )
            self._axes.add_line(line)

    def _gates(self):
        """Create a list of all gates in the circuit plot.

        For a ``Mul`` the ``Gate`` factors are returned in reversed argument
        order and other factors are skipped. A single ``Gate`` yields a
        one-element list, and anything else yields an empty list.
        """
        if isinstance(self.circuit, Mul):
            return [g for g in reversed(self.circuit.args)
                    if isinstance(g, Gate)]
        if isinstance(self.circuit, Gate):
            return [self.circuit]
        return []

    def _plot_gates(self):
        """Iterate through the gates and plot each of them."""
        for i, gate in enumerate(self._gates()):
            gate.plot_gate(self, i)

    def _measurements(self):
        """Return a dict ``{i:j}`` where i is the index of the wire that has
        been measured, and j is the gate where the wire is measured.

        A gate counts as a measurement when it has a truthy ``measurement``
        attribute. Only the first measuring gate of each wire is recorded.
        """
        ismeasured = {}
        for i, g in enumerate(self._gates()):
            if not getattr(g, 'measurement', False):
                continue
            for target in g.targets:
                # Gate indices only increase, so the first hit is the
                # earliest measurement of this wire.
                ismeasured.setdefault(target, i)
        return ismeasured

    def _finish(self):
        """Turn clipping off on every artist of the figure."""
        # Disable clipping to make panning work well for large circuits.
        for o in self._figure.findobj():
            o.set_clip_on(False)

    def one_qubit_box(self, t, gate_idx, wire_idx):
        """Draw a box for a single qubit gate.

        The text ``t`` is centred on the grid point of gate ``gate_idx`` and
        wire ``wire_idx`` inside a white box with a black outline.
        """
        x = self._gate_grid[gate_idx]
        y = self._wire_grid[wire_idx]
        self._axes.text(
            x, y, t,
            color='k',
            ha='center',
            va='center',
            bbox={"ec": 'k', "fc": 'w', "fill": True, "lw": self.linewidth},
            size=self.fontsize
        )

    def two_qubit_box(self, t, gate_idx, wire_idx):
        """Draw a box for a two qubit gate. Does not work yet.

        Currently this only prints the gate and wire grids; nothing is drawn.
        """
        # x = self._gate_grid[gate_idx]
        # y = self._wire_grid[wire_idx]+0.5
        print(self._gate_grid)
        print(self._wire_grid)
        # unused:
        # obj = self._axes.text(
        #     x, y, t,
        #     color='k',
        #     ha='center',
        #     va='center',
        #     bbox=dict(ec='k', fc='w', fill=True, lw=self.linewidth),
        #     size=self.fontsize
        # )

    def control_line(self, gate_idx, min_wire, max_wire):
        """Draw a vertical control line.

        The line is drawn at gate ``gate_idx`` between the grid positions of
        wires ``min_wire`` and ``max_wire``.
        """
        xdata = (self._gate_grid[gate_idx], self._gate_grid[gate_idx])
        ydata = (self._wire_grid[min_wire], self._wire_grid[max_wire])
        line = Line2D(
            xdata, ydata,
            color='k',
            lw=self.linewidth
        )
        self._axes.add_line(line)

    def control_point(self, gate_idx, wire_idx):
        """Draw a control point.

        The point is a filled black circle of radius
        ``control_radius*scale``.
        """
        x = self._gate_grid[gate_idx]
        y = self._wire_grid[wire_idx]
        radius = self.control_radius
        c = Circle(
            (x, y),
            radius*self.scale,
            ec='k',
            fc='k',
            fill=True,
            lw=self.linewidth
        )
        self._axes.add_patch(c)

    def not_point(self, gate_idx, wire_idx):
        """Draw a NOT gate as a circle with a vertical stroke through it."""
        x = self._gate_grid[gate_idx]
        y = self._wire_grid[wire_idx]
        # NOTE(review): unlike control_point, this radius is not multiplied
        # by ``scale`` -- confirm whether that is intentional.
        radius = self.not_radius
        c = Circle(
            (x, y),
            radius,
            ec='k',
            fc='w',
            fill=False,
            lw=self.linewidth
        )
        self._axes.add_patch(c)
        l = Line2D(
            (x, x), (y - radius, y + radius),
            color='k',
            lw=self.linewidth
        )
        self._axes.add_line(l)

    def swap_point(self, gate_idx, wire_idx):
        """Draw a swap point as a cross.

        The cross is made of two diagonal strokes reaching ``swap_delta`` in
        each direction from the grid point.
        """
        x = self._gate_grid[gate_idx]
        y = self._wire_grid[wire_idx]
        d = self.swap_delta
        l1 = Line2D(
            (x - d, x + d),
            (y - d, y + d),
            color='k',
            lw=self.linewidth
        )
        l2 = Line2D(
            (x - d, x + d),
            (y + d, y - d),
            color='k',
            lw=self.linewidth
        )
        self._axes.add_line(l1)
        self._axes.add_line(l2)
def circuit_plot(c, nqubits, **kwargs):
    """Build a :class:`CircuitPlot` for a circuit on ``nqubits`` wires.

    Parameters
    ==========

    c : circuit
        The circuit to plot. Should be a product of Gate instances.
    nqubits : int
        The number of qubits to include in the circuit. Must be at least
        as big as the largest ``min_qubits`` of the gates.
    **kwargs
        Passed through unchanged to the ``CircuitPlot`` constructor.

    Returns
    =======

    CircuitPlot
        The newly constructed plot object.
    """
    plot = CircuitPlot(c, nqubits, **kwargs)
    return plot
def render_label(label, inits=None):
    """Slightly more flexible way to render labels.

    Parameters
    ==========

    label : str
        The wire label to render as a ket.
    inits : dict, optional
        Mapping from label to initial state. When it holds a truthy entry
        for ``label``, the result also shows ``= |init>``. ``None`` (the
        default) behaves like an empty mapping.

    Returns
    =======

    str
        A ``$...$`` delimited LaTeX math string.

    >>> from sympy.physics.quantum.circuitplot import render_label
    >>> render_label('q0')
    '$\\\\left|q0\\\\right\\\\rangle$'
    >>> render_label('q0', {'q0':'0'})
    '$\\\\left|q0\\\\right\\\\rangle=\\\\left|0\\\\right\\\\rangle$'
    """
    # A None default avoids sharing one mutable dict between all calls.
    init = inits.get(label) if inits is not None else None
    if init:
        return r'$\left|%s\right\rangle=\left|%s\right\rangle$' % (label, init)
    return r'$\left|%s\right\rangle$' % label
def labeller(n, symbol='q'):
    """Autogenerate labels for wires of quantum circuits.

    The labels are numbered downwards, from ``n - 1`` to ``0``.

    Parameters
    ==========

    n : int
        number of qubits in the circuit.
    symbol : string
        A character string to precede all gate labels. E.g. 'q_0', 'q_1', etc.

    >>> from sympy.physics.quantum.circuitplot import labeller
    >>> labeller(2)
    ['q_1', 'q_0']
    >>> labeller(3,'j')
    ['j_2', 'j_1', 'j_0']
    """
    names = []
    for position in range(n):
        index = n - position - 1
        names.append('%s_%d' % (symbol, index))
    return names
class Mz(OneQubitGate):
    """Placeholder gate used to draw a z measurement.

    It lives in circuitplot rather than gate.py because it is not a real
    gate; it only exists so that a measurement can be drawn. The
    ``measurement`` flag is what marks the gate as a measurement for the
    plotting code.
    """
    gate_name = 'Mz'
    gate_name_latex = 'M_z'
    measurement = True
class Mx(OneQubitGate):
    """Placeholder gate used to draw an x measurement.

    It lives in circuitplot rather than gate.py because it is not a real
    gate; it only exists so that a measurement can be drawn. The
    ``measurement`` flag is what marks the gate as a measurement for the
    plotting code.
    """
    gate_name = 'Mx'
    gate_name_latex = 'M_x'
    measurement = True
class CreateOneQubitGate(type):
    """Factory for named one-qubit gate classes.

    Calling ``CreateOneQubitGate(name, latexname)`` returns a new subclass of
    ``OneQubitGate`` called ``name + "Gate"``. The class is built with the
    builtin ``type``, so it is not an instance of this metaclass.
    """

    def __new__(mcl, name, latexname=None):
        # Fall back to the plain name when no LaTeX name is supplied.
        class_name = name + "Gate"
        namespace = {
            'gate_name': name,
            'gate_name_latex': latexname or name,
        }
        return type(class_name, (OneQubitGate,), namespace)
def CreateCGate(name, latexname=None):
    """Use a lexical closure to make a controlled gate.

    A one-qubit gate class is created once from ``name`` and ``latexname``
    (which falls back to ``name`` when empty). The returned function takes
    ``(ctrls, target)`` and wraps that gate, applied to ``target``, in a
    ``CGate`` controlled by ``ctrls``.
    """
    target_gate = CreateOneQubitGate(name, latexname or name)

    def ControlledGate(ctrls, target):
        controls = tuple(ctrls)
        return CGate(controls, target_gate(target))

    return ControlledGate