__doc__ = "visualize.py: Contains functions to visualize circuits and draw diagrams"
__author__ = "Eli Weissler, Mohit Bhat"
__version__ = "0.1.0"
__all__ = ["draw_circuit_diagram", "draw_basegraph", "draw_circuit_graph"]
from pathlib import Path
from typing import Union
import numpy as np
import networkx as nx
import matplotlib.transforms
import matplotlib
import matplotlib.pyplot as plt
import schemdraw.elements
from tqdm import tqdm
from scipy.spatial.distance import pdist
import schemdraw
import schemdraw.elements as elm
from sircuitenum import utils
from sircuitenum import reduction as red
from sircuitenum import visualize as viz
G_POS_BG = {2: [{0: np.array([-1/np.sqrt(2), -1/np.sqrt(2)]),
1: np.array([0., 0.])}],
3: [{0: np.array([0., 0.]),
1: np.array([0.5, np.sqrt(3)/2]),
2: np.array([-0.5, np.sqrt(3)/2])}]*2,
4: [{0: np.array([0., 0.]),
1: np.array([1., 1.])/np.sqrt(2),
2: np.array([1., 0.])/np.sqrt(2),
3: np.array([0., 1.])/np.sqrt(2)}]*6,
5: [{0: np.array([0, 1]),
1: np.array([np.sin(2*2*np.pi/5), np.cos(2*2*np.pi/5)]),
2: np.array([np.sin(4*2*np.pi/5), np.cos(4*2*np.pi/5)]),
3: np.array([np.sin(1*2*np.pi/5), np.cos(1*2*np.pi/5)]),
4: np.array([np.sin(3*2*np.pi/5), np.cos(3*2*np.pi/5)])}]*21}
G_POS = {2: [{0: np.array([-1/np.sqrt(2), -1/np.sqrt(2)]),
1: np.array([0., 0.])}],
3: [{0: np.array([0., 0.]),
1: np.array([0.5, np.sqrt(3)/2]),
2: np.array([-0.5, np.sqrt(3)/2])}]*2,
4: [{0: np.array([np.sqrt(3)/2, -0.5]),
1: np.array([0., 1.]),
2: np.array([-np.sqrt(3)/2, -0.5]),
3: np.array([0., 0.])},
{0: -1*np.array([1/np.sqrt(2), 1/np.sqrt(2)]),
1: np.array([1/np.sqrt(2), 1/np.sqrt(2)]),
2: -2*np.array([1/np.sqrt(2), 1/np.sqrt(2)]),
3: np.array([0., 0.])},
{0: np.array([-0.5, np.sqrt(3)/2]),
1: np.array([0., -1.]),
2: np.array([0.5, np.sqrt(3)/2]),
3: np.array([0., 0.])},
{0: np.array([0., 0.]),
1: np.array([1., -1.]),
2: np.array([0., -1.]),
3: np.array([1., 0.])},
{0: np.array([-0.5, np.sqrt(3)/2]),
1: np.array([1., 0.]),
2: np.array([0.5, np.sqrt(3)/2]),
3: np.array([0., 0.])},
{0: (3/np.sqrt(3))*np.array([0.5, np.sqrt(3)/5]),
1: (3/np.sqrt(3))*np.array([1., 0.]),
2: (3/np.sqrt(3))*np.array([0.5, np.sqrt(3)/2]),
3: (3/np.sqrt(3))*np.array([0., 0.])}],
5: [{0: np.array([-1, 0]),
1: np.array([0, 1]),
2: np.array([1, 0]),
3: np.array([0, -1]),
4: np.array([0, 0])},
{0: np.array([0., -1.]),
1: np.array([-0.5, np.sqrt(3)/2]),
2: np.array([0.5, np.sqrt(3)/2]),
3: np.array([0, -2]),
4: np.array([0., 0.])},
{0: np.array([-0.5, -np.sqrt(3)/2]),
1: np.array([-0.5, np.sqrt(3)/2]),
2: np.array([0.5, np.sqrt(3)/2]),
3: np.array([0.5, -np.sqrt(3)/2]),
4: np.array([0., 0.])},
{0: np.array([-0.5, -np.sqrt(3)/2]),
1: np.array([1, 0]),
2: np.array([0, 1]),
3: np.array([0.5, -np.sqrt(3)/2]),
4: np.array([0., 0.])},
{0: np.array([0, -1]),
1: np.array([1, 0]),
2: np.array([0, 1]),
3: np.array([1, -1]),
4: np.array([0., 0.])},
{0: np.array([-0.5, -np.sqrt(3)/2]),
1: np.array([1, 0]),
2: np.array([0.5, np.sqrt(3)/2]),
3: np.array([0.5, -np.sqrt(3)/2]),
4: np.array([0., 0.])},
{0: np.array([0, -1]),
1: np.array([0.5, -0.5]),
2: np.array([1, 0]),
3: np.array([1, -1]),
4: np.array([0., 0.])},
{0: np.array([0, -1]),
1: np.array([0.75, -0.25]),
2: np.array([1, 0]),
3: np.array([1, -1]),
4: np.array([0., 0.])},
{0: np.array([1, 1]),
1: np.array([-1, 1]),
2: np.array([2, 0]),
3: np.array([-2, 0]),
4: np.array([0, 0])},
{0: np.array([0, 0]),
1: np.array([1, 2]),
2: np.array([2, 0]),
3: np.array([1, 3]),
4: np.array([1, 1])},
{0: np.array([0.5, 1-np.sqrt(3)/2]),
1: np.array([1.5, 1+np.sqrt(3)/2]),
2: np.array([1.5, 1-np.sqrt(3)/2]),
3: np.array([0.5, 1+np.sqrt(3)/2]),
4: np.array([1, 1])},
{0: np.array([0, 0]),
1: np.array([1, 1+np.sqrt(2)/2]),
2: np.array([2, 0]),
3: np.array([1, 1-0.6]),
4: np.array([1, 1])},
{0: np.array([0, 0]),
1: np.array([1, 1+np.sqrt(2)/2]),
2: np.array([2, 0]),
3: np.array([1, 1-0.6]),
4: np.array([1, 1])},
{0: np.array([0, 1]),
1: np.array([np.sin(2*2*np.pi/5), np.cos(2*2*np.pi/5)]),
2: np.array([np.sin(4*2*np.pi/5), np.cos(4*2*np.pi/5)]),
3: np.array([np.sin(1*2*np.pi/5), np.cos(1*2*np.pi/5)]),
4: np.array([np.sin(3*2*np.pi/5), np.cos(3*2*np.pi/5)])},
{0: np.array([0, 1]),
1: np.array([np.sin(2*2*np.pi/5), np.cos(2*2*np.pi/5)]),
2: np.array([np.sin(4*2*np.pi/5), np.cos(4*2*np.pi/5)]),
3: np.array([np.sin(1*2*np.pi/5), np.cos(1*2*np.pi/5)]),
4: np.array([np.sin(3*2*np.pi/5), np.cos(3*2*np.pi/5)])},
{0: np.array([0, 0]),
1: np.array([0.5, 1+np.sqrt(3)/2]),
2: np.array([1, 0]),
3: np.array([0, 1]),
4: np.array([1, 1])},
{0: (3/np.sqrt(3))*np.array([0.5, np.sqrt(3)/5]),
1: (3/np.sqrt(3))*np.array([0, 2*np.sqrt(3)/5]),
2: (3/np.sqrt(3))*np.array([1., 0.]),
3: (3/np.sqrt(3))*np.array([0.5, np.sqrt(3)/2]),
4: (3/np.sqrt(3))*np.array([0., 0.])},
{0: (3/np.sqrt(3))*np.array([1., 0.]),
1: (3/np.sqrt(3))*np.array([0.5, np.sqrt(3)/2]),
2: (3/np.sqrt(3))*np.array([0., 0.]),
3: (3/np.sqrt(3))*np.array([1, 2*np.sqrt(3)/5]),
4: (3/np.sqrt(3))*np.array([0.5, np.sqrt(3)/5])},
{0: (3/np.sqrt(3))*np.array([1., 0.]),
1: (3/np.sqrt(3))*np.array([0.5, np.sqrt(3)/2]),
2: (3/np.sqrt(3))*np.array([0., 0.]),
3: (3/np.sqrt(3))*np.array([1, 2*np.sqrt(3)/5]),
4: (3/np.sqrt(3))*np.array([0.5, np.sqrt(3)/5])},
{0: (3/np.sqrt(3))*np.array([1., 0.]),
1: (3/np.sqrt(3))*np.array([0.2, 0.1+np.sqrt(3)/5]),
2: (3/np.sqrt(3))*np.array([-0.25, 0.]),
3: (3/np.sqrt(3))*np.array([0.25, np.sqrt(3)/2]),
4: (3/np.sqrt(3))*np.array([0.5, np.sqrt(3)/5])},
{0: np.array([0, 1]),
1: np.array([np.sin(2*2*np.pi/5), np.cos(2*2*np.pi/5)]),
2: np.array([np.sin(4*2*np.pi/5), np.cos(4*2*np.pi/5)]),
3: np.array([np.sin(1*2*np.pi/5), np.cos(1*2*np.pi/5)]),
4: np.array([np.sin(3*2*np.pi/5), np.cos(3*2*np.pi/5)])}]}
def black_or_white_text(color: tuple):
"""
Determines whether it's more appropriate
to write using black or white text on the
given color.
From: https://stackoverflow.com/questions/3942878/how-to-decide-font-color-in-white-or-black-depending-on-background-color
Args:
color (tuple): (R, G, B) between 0 and 1 each
Returns:
True for black text, False for white text
"""
thresh = color[0]*0.299 + color[1]*0.587 + color[2]*0.114
# 186/255
return thresh > 0.729
[docs]def draw_circuit_graph(circuit: list, edges: list, gtype: str = "component",
out="",
node_size: float = 10000, scale: float = 6,
font_size: int = 30) -> Union[plt.Figure, None]:
"""
Draws the port or component graph representation of a given circuit.
This function visualizes a quantum circuit as either a **port graph** or a **component graph**
using NetworkX and Matplotlib. The generated graph can be displayed interactively
or saved as an image file.
Parameters
----------
circuit : list of list of str
A nested list representing the circuit elements.
Example: `[['C'], ['C'], ['L'], ['C', 'J']]`
edges : list of tuple of int
A list of edge connections specifying how circuit elements are connected.
Example: `[(0,1), (1,2), (2,3), (3,0)]`
gtype : str, optional
Type of graph to draw. Options are:
- `'component'` : Draws a component-level circuit graph.
- `'port'` : Draws a port-level circuit graph.
Default is `'component'`.
out : str, optional
Filename (with extension) to save the plot.
If `""` (empty string), the graph is displayed interactively.
Default is `"circuit_graph.png"`.
node_size : float, optional
Size parameter for nodes in the plotted graph.
scale : float, optional
Scaling factor for overall spacing of nodes.
font_size : int, optional
Font size for node labels.
Returns
-------
matplotlib.figure.Figure or None
- If `out=""`, returns a Matplotlib figure object for interactive viewing.
- Otherwise, saves the figure to `out` and returns `None`.
Examples
--------
>>> draw_circuit_graph(
>>> circuit=[['C'], ['C'], ['L'], ['C', 'J']],
>>> edges=[(0,1), (1,2), (2,3), (3,0)],
>>> gtype="component",
>>> out="circuit.png"
>>> )
"""
# Get the layout and scale it
fig = plt.figure()
fig.set_size_inches(fig.get_size_inches()*scale)
if gtype == "component":
G = red.convert_circuit_to_component_graph(circuit, edges)
else:
G = red.convert_circuit_to_port_graph(circuit, edges)
if nx.is_planar(G):
pos = nx.planar_layout(G)
else:
pos = nx.spring_layout(G)
for n in pos:
pos[n] = pos[n]*scale
# Use a colormap to generate a color for each node
# And use that color to decide on white vs. black text for
# each node
node_color_map = nx.get_node_attributes(G, "color")
node_color = np.array([node_color_map[n] for n in G.nodes()]).astype(float)
node_color /= np.max(node_color)
cmap = matplotlib.cm.get_cmap('viridis')
black_text_nodes = {}
white_text_nodes = {}
node_color_mapped = []
for i, n in enumerate(G.nodes()):
color = cmap(node_color[i])
node_color_mapped.append(color)
if black_or_white_text(color):
black_text_nodes[n] = pos[n]
else:
white_text_nodes[n] = pos[n]
nx.draw(G, pos, node_color=node_color, node_size=node_size)
# Draw black text nodes
Gk = nx.Graph()
for n in black_text_nodes:
Gk.add_node(n)
nx.draw_networkx_labels(Gk, black_text_nodes,
font_color='k', font_size=font_size)
# Draw white text nodes
Gw = nx.Graph()
for n in white_text_nodes:
Gw.add_node(n)
nx.draw_networkx_labels(Gw, white_text_nodes,
font_color='w', font_size=font_size)
if out != "":
plt.savefig(out)
plt.close()
else:
return fig
def print_bare_graphs(all_graphs: list):
"""Plots unlabeled graphs from graph list returned
from get_graphs_from_file() function
Args:
all_graphs (list): a list of unlabeled graphs
"""
for i, G in enumerate(all_graphs):
draw_basegraph(G, f"graph index: {i}")
def draw_all_basegraphs(base_path: str, n_start: int = 2, n_end: int = 4):
"""
Draws all the basegraphs for nodes with number
n_start through n_end.
Args:
base_path (str): folder to save images in
n_start (int, optional): start drawing graphs
for this number of nodes
n_end (int, optional): stop drawing graphs for this
number of nodes (inclusive)
Returns:
dict: positions of nodes for each plot
"""
pos = {}
for n_nodes in range(n_start, n_end + 1):
pos[n_nodes] = []
all_graphs = utils.get_basegraphs(n_nodes)
for i, G in enumerate(all_graphs):
fname = str(Path(base_path, f"n{n_nodes}_g{i}.svg"))
title = f"{n_nodes} nodes, graph {i}"
f, p = draw_basegraph(G, title, fname)
for k in p:
p[k] = np.round(p[k], 2)
pos[n_nodes].append(p)
return pos
[docs]def draw_basegraph(G: nx.Graph, title: str = "",
savename: str = None, **kwargs):
"""Plots unlabeled graphs from graph list returned
from get_graphs_from_file() function
Args:
G (nx.graph): networkx graph to plot
title (str): title for plot
savename (str): location to save the plots
pos (dict): dictionary of node labels -> positions
Returns:
tuple: figure, positioning (x, y) of nodes
"""
edges = [x for x in G.edges()]
n_nodes = G.number_of_nodes()
graph_index = utils.edges_to_graph_index(edges)
if n_nodes in G_POS_BG:
def_pos = G_POS_BG[n_nodes][graph_index]
else:
if nx.is_planar(G):
def_pos = nx.planar_layout(G)
else:
def_pos = nx.spring_layout(G)
pos = kwargs.get("pos", def_pos)
f = plt.figure(figsize=(3, 3))
plt.tight_layout()
nx.draw_networkx(G, pos=pos, node_color="#FFFFFF", linewidths=1, width=2, edgecolors="#000000",
with_labels=False, node_size=800)
plt.gca().set_axis_off()
plt.gca().set_aspect('equal')
plt.tight_layout()
if savename is not None:
plt.savefig(savename, bbox_inches = "tight")
return f, pos
def draw_all_qubits(file: str, n_nodes: int, out_dir: str,
layout: str = 'fixed', format: str = ".svg"):
if n_nodes < 5:
scale = 4.0
else:
scale = 4.0 + n_nodes - 3
# So plots don't pop up
matplotlib.use('agg')
df = utils.get_unique_qubits(file, n_nodes)
for uid, row in tqdm(df.iterrows(), total=df.shape[0]):
circuit = row.circuit
edges = row.edges
graph_index = row.graph_index
viz.draw_circuit_diagram(circuit, edges,
out=Path(out_dir, f"{uid}{format}"),
layout=layout, graph_index=graph_index,
scale=scale)
plt.close()
[docs]def draw_circuit_diagram(circuit: list, edges: list,
out: str = "",
scale: float = 4.0, layout: str = 'fixed',
spread: float = 2/5, graph_index: int = None) -> None:
"""
Draw the circuit diagram using `schemdraw`.
For parallel elements, connections are split 1/4 of the way along and fan out
to display parallel elements. Non-planar graphs are not adjusted to avoid overlap.
Parameters
----------
circuit : list
A list of element labels for the desired circuit.
Example: ``[["J"], ["L", "J"], ["C"]]``.
edges : list
A list of edge connections for the desired circuit.
Example: ``[(0,1), (0,2), (1,2)]``.
out : str
Filename to save the plot. If an empty string (``""``) is provided,
the plot is displayed interactively.
scale : float
Scaling factor for the networkx positions to spread out the plots if needed.
layout : str
Options for graph layouts.
**"spring"** produces aesthetically pleasing circuits but may lead to
overlapping elements, even for planar graphs.
spread : float
Fraction of edge length used to fan out parallel components.
graph_index : int
Graph number for a fixed layout. Can also be inferred from `edges`.
Returns
-------
None
Displays the plot interactively if `out == ""`.
Examples
--------
>>> draw_circuit_diagram(
>>> circuit=[["J"], ["L", "J"], ["C"]],
>>> edges=[(0,1), (0,2), (1,2)],
>>> out="circuit.png",
>>> layout="spring",
>>> scale=1.5
>>> )
"""
edges = utils.zero_start_edges(edges)
elem_dict = {
'C': {'default_unit': 'GHz', 'default_value': 0.2},
'L': {'default_unit': 'GHz', 'default_value': 1.0},
'J': {'default_unit': 'GHz', 'default_value': 5.0},
}
# Get rid of any labels on elements
new_circuit = []
for elems in circuit:
new_elems = []
for elem in elems:
for base_elem in elem_dict:
if base_elem in elem:
new_elems.append(base_elem)
continue
new_circuit.append(tuple(new_elems))
circuit = new_circuit
params = utils.gen_param_dict(circuit, edges, elem_dict)
G = utils.convert_circuit_to_graph(circuit, edges,
params=params)
# Get graph index
if graph_index is None and layout == "fixed":
graph_index = utils.edges_to_graph_index(edges)
# Get layout of vertices
if layout == 'planar':
if nx.is_planar(G):
pos = nx.planar_layout(G)
else:
print("Not a planar graph... reverting to spring layout")
pos = nx.spring_layout(G)
elif layout == 'spring':
pos = nx.spring_layout(G)
elif layout == 'fixed':
pos = G_POS[G.number_of_nodes()][graph_index]
# Scale
scaled_pos = {}
for k in pos:
scaled_pos[k] = pos[k]*scale
# Define the circuit elements
elem_bank = {
'C': lambda: elm.Capacitor(),
'L': lambda: elm.Inductor2(),
'J': lambda: elm.Josephson()
}
# Calculate minimum pairwise distance for plotting
pdist_mat = pdist(np.vstack([x for x in scaled_pos.values()]))
min_r = np.min(pdist_mat[pdist_mat > 0])
with schemdraw.Drawing() as d:
for n0 in G.nodes():
# Draw a dot at every node
d.add(DotCustom(radius=0.1, fill="#FFFFFF", color="#000000",
lw=1).at(scaled_pos[n0]))
# d.add(schemdraw.segments.SegmentCircle(scaled_pos[n0],
# radius=0.2,
# color="#000000",
# fill="#FFFFFF",
# lw=0.01))
for n1 in G.nodes():
if n0 != n1 and n0 < n1 and G.has_edge(n0, n1):
edgesBetweenNodes = G[n0][n1]
# Single element
nEdges = len(edgesBetweenNodes)
if nEdges == 1:
edge = edgesBetweenNodes[0]
# Get the endpoints and specific circuit element
# that correspond to this edge
x0 = scaled_pos[n0]
x1 = scaled_pos[n1]
d.add(elem_bank[edge['element']]().endpoints(x0, x1))
# Parallel Elements
else:
# Draw a line perpendicular to the connection
# Split way there and spread out spread
# out and do them in parallel
split = 1/4
x0 = scaled_pos[n0]
x1 = scaled_pos[n1]
# Unit vectors along displacement and
# perpendicular to it
rhat = x1-x0
r = np.linalg.norm(rhat)
rhat /= r
norm = np.array([-rhat[1], rhat[0]])
perp_step = (spread*min_r/nEdges)*norm
# Anchor points for splitting
adj = split*rhat*(r-min_r)
a0 = x0+split*rhat*r+adj
a1 = x1-split*rhat*r-adj
d.add(elm.Wire().at(x0).to(a0))
d.add(elm.Wire().at(x1).to(a1))
# Start at 0 and go out integer
# steps if odd numberof edges.
# For even number of edges do 1/2 integer steps
if nEdges % 2 == 0:
maxSpread = (perp_step*(nEdges/2-1/2))
d.add(elm.Wire().at(a0).to(a0-maxSpread))
d.add(elm.Wire().at(a0).to(a0+maxSpread))
d.add(elm.Wire().at(a1).to(a1-maxSpread))
d.add(elm.Wire().at(a1).to(a1+maxSpread))
for edgeNum in range(int(nEdges/2)):
edge1 = edgesBetweenNodes[edgeNum]
edge2 = edgesBetweenNodes[nEdges-1-edgeNum]
d.add(elem_bank[edge1['element']]().endpoints(
a0+perp_step*(edgeNum+1/2),
a1+perp_step*(edgeNum+1/2)))
d.add(elem_bank[edge2['element']]().endpoints(
a0-perp_step*(edgeNum+1/2),
a1-perp_step*(edgeNum+1/2)))
else:
maxSpread = (perp_step*(nEdges-1)/2)
d.add(elm.Wire().at(a0).to(a0-maxSpread))
d.add(elm.Wire().at(a0).to(a0+maxSpread))
d.add(elm.Wire().at(a1).to(a1-maxSpread))
d.add(elm.Wire().at(a1).to(a1+maxSpread))
# 0th one straight across
d.add(elem_bank[edgesBetweenNodes[0]
['element']]().endpoints(a0, a1))
# Others mirrored across -- work in pairs
for edgeNum in range(1, int((nEdges-1)/2 + 1)):
edge1 = edgesBetweenNodes[edgeNum]
edge2 = edgesBetweenNodes[nEdges-edgeNum]
d.add(elem_bank[edge1['element']]().endpoints(
a0+perp_step*edgeNum,
a1+perp_step*edgeNum))
d.add(elem_bank[edge2['element']]().endpoints(
a1-perp_step*edgeNum,
a0-perp_step*edgeNum))
if out != "":
d.save(out)
if out != "":
plt.close()
else:
plt.show()
class DotCustom(schemdraw.elements.Element):
''' Connection Dot
Keyword Args:
radius: Radius of dot [default: 0.075]
open: Draw as an open circle [default: False]
'''
_element_defaults = {
'radius': 0.075,
'open': False}
def __init__(self, *,
radius,
fill,
**kwargs):
super().__init__(**kwargs)
fill = fill
self.anchors['start'] = (0, 0)
self.anchors['center'] = (0, 0)
self.anchors['end'] = (0, 0)
self.elmparams['drop'] = (0, 0)
self.elmparams['theta'] = 0
self.elmparams['zorder'] = 4
self.elmparams['fill'] = fill
self.segments.append(schemdraw.segments.SegmentCircle((0, 0),
self.params['radius'],
**kwargs))