"""Module for quantum simulation."""
from dataclasses import dataclass
from functools import partial
from typing import List, Union
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax.typing import ArrayLike
from sax.saxtypes import Model
from sax.utils import get_ports
from scipy.stats import multivariate_normal
from simphony.exceptions import ShapeMismatchError
from simphony.simulation import SimDevice, Simulation, SimulationResult
from simphony.utils import dict_to_matrix, xpxp_to_xxpp, xxpp_to_xpxp
def plot_mode(means, cov, n=100, x_range=None, y_range=None, ax=None, **kwargs):
    r"""Plots the Wigner function of a single mode state.

    The state is drawn as a filled contour plot of the Gaussian probability
    density with the given means and covariance matrix.

    Parameters
    ----------
    means : ArrayLike
        The means of the X and P quadratures of the quantum state. For example,
        a coherent state :math:`\alpha = 3+4i` has means defined as
        :math:`\begin{bmatrix} 3 & 4 \end{bmatrix}`. The means must have a
        length of 2.
    cov : ArrayLike
        The covariance matrix of the quantum state. For example, all coherent
        states have a covariance matrix of :math:`\begin{bmatrix} 1/4 & 0 \\ 0 &
        1/4 \end{bmatrix}`. The shape of the matrix must be 2 x 2.
    n : int
        The number of points per axis to plot. Default is 100.
    x_range : tuple
        The range of the x axis to plot as a tuple, (eg. (-5,5)). Default is
        the X mean plus or minus five standard deviations.
    y_range : tuple
        The range of the y axis to plot as a tuple, (eg. (-5,5)). Default is
        the P mean plus or minus five standard deviations.
    ax : matplotlib.axes.Axes
        The axis to plot on, by default it creates a new figure.
    **kwargs :
        Keyword arguments to pass to matplotlib.pyplot.contourf.

    Returns
    -------
    matplotlib.axes.Axes
        The axis that was drawn on.

    Notes
    -----
    The plotted window is always a square centered on the origin that is just
    large enough to contain both ``x_range`` and ``y_range``. The ranges
    therefore set the extent of the plot rather than its exact limits.
    """
    # Accept any array-like input (lists, tuples, numpy arrays).
    means = jnp.asarray(means)
    cov = jnp.asarray(cov)

    if ax is None:
        _, ax = plt.subplots()

    if x_range is None:
        sigma_x = jnp.sqrt(cov[0, 0])
        x_range = (means[0] - 5 * sigma_x, means[0] + 5 * sigma_x)
    if y_range is None:
        sigma_p = jnp.sqrt(cov[1, 1])
        y_range = (means[1] - 5 * sigma_p, means[1] + 5 * sigma_p)

    # Use one symmetric extent for both axes so the origin stays centered and
    # the equal aspect ratio below yields a square plot.
    r_max = jnp.max(jnp.abs(jnp.array((*x_range, *y_range))))
    x = jnp.linspace(-r_max, r_max, n)
    y = jnp.linspace(-r_max, r_max, n)
    X, Y = jnp.meshgrid(x, y)
    pos = jnp.dstack((X, Y))

    pdf = multivariate_normal(means, cov).pdf(pos)
    ax.contourf(X, Y, pdf, **kwargs)
    ax.set_aspect("equal")
    ax.set_xlabel("X")
    ax.set_ylabel("P")
    return ax
class QuantumState(SimDevice):
    r"""Represents a quantum state in a quantum model as a covariance matrix.

    States use the xpxp convention by default (the X and P quadratures of
    each mode are interleaved). Pass ``convention="xxpp"`` when the means and
    covariance list all X quadratures first and all P quadratures second.

    Parameters
    ----------
    means : ArrayLike
        The means of the X and P quadratures of the quantum state. For example,
        a coherent state :math:`\alpha = 3+4i` has means defined as
        :math:`\begin{bmatrix} 3 & 4 \end{bmatrix}`. The shape of the means
        must be 2 * N.
    cov : ArrayLike
        The covariance matrix of the quantum state. For example, all coherent
        states have a covariance matrix of :math:`\begin{bmatrix} 1/4 & 0 \\ 0 &
        1/4 \end{bmatrix}`. The shape of the matrix must be 2 * N x 2 * N.
    ports : str or list of str
        The ports to which the quantum state is connected. Each mode
        corresponds in order to each port provided. A single string is
        treated as a single mode. If omitted, the number of modes is
        inferred from the length of ``means``.
    convention : str
        The convention of the means and covariance matrix. Default is 'xpxp'.

    Raises
    ------
    ShapeMismatchError
        If ``means`` does not have shape ``(2 * N,)`` or ``cov`` does not have
        shape ``(2 * N, 2 * N)``.
    """

    def __init__(
        self,
        means: ArrayLike,
        cov: ArrayLike,
        ports: Union[str, List[str]] = None,
        convention: str = "xpxp",
    ) -> None:
        super().__init__(ports)
        # Accept any array-like input; the shape checks below need ``.shape``.
        means = jnp.asarray(means)
        cov = jnp.asarray(cov)
        if ports is None:
            self.N = len(means) // 2
        elif isinstance(ports, str):
            # A single port name means a single mode; len() would count characters.
            self.N = 1
        else:
            self.N = len(ports)
        if means.shape != (2 * self.N,):
            raise ShapeMismatchError("The shape of the means must be 2 * N.")
        if cov.shape != (2 * self.N, 2 * self.N):
            raise ShapeMismatchError(
                "The shape of the covariance matrix must be 2 * N x 2 * N."
            )
        self.means = means
        self.cov = cov
        self.convention = convention

    def to_xpxp(self) -> None:
        """Converts the means and covariance matrix to the xpxp convention.

        Does nothing unless the current convention is 'xxpp'.
        """
        if self.convention == "xxpp":
            self.means = xxpp_to_xpxp(self.means)
            self.cov = xxpp_to_xpxp(self.cov)
            self.convention = "xpxp"

    def to_xxpp(self) -> None:
        """Converts the means and covariance matrix to the xxpp convention.

        Does nothing unless the current convention is 'xpxp'.
        """
        if self.convention == "xpxp":
            self.means = xpxp_to_xxpp(self.means)
            self.cov = xpxp_to_xxpp(self.cov)
            self.convention = "xxpp"

    def modes(self, modes: Union[int, List[int]]):
        """Returns the mean and covariance matrix of the specified modes.

        The modes are returned in the order given and in the state's current
        convention.

        Parameters
        ----------
        modes : int or list
            The modes to return.

        Returns
        -------
        means, cov : tuple of Array
            The means (length ``2 * len(modes)``) and the covariance matrix
            (``2 * len(modes)`` square) of the selected modes.

        Raises
        ------
        ValueError
            If any mode index is negative or not less than ``N``.
        """
        if not hasattr(modes, "__iter__"):
            modes = [modes]
        # Negative indices would wrap around and select the wrong quadratures.
        if not all(0 <= mode < self.N for mode in modes):
            raise ValueError(
                "Modes must be non-negative and less than the number of modes."
            )
        modes = jnp.array(modes)
        # In xxpp ordering, mode m's X sits at index m and its P at m + N.
        inds = jnp.concatenate((modes, (modes + self.N)))
        if self.convention == "xpxp":
            means = xpxp_to_xxpp(self.means)
            cov = xpxp_to_xxpp(self.cov)
            means = means[inds]
            cov = cov[jnp.ix_(inds, inds)]
            means = xxpp_to_xpxp(means)
            cov = xxpp_to_xpxp(cov)
        else:
            means = self.means[inds]
            cov = self.cov[jnp.ix_(inds, inds)]
        return means, cov

    def _add_vacuums(self, n_vacuums: int):
        """Adds vacuum states to the quantum state.

        The vacuum modes are appended after the existing modes. They have zero
        means and a covariance of 1/4 times the identity, and are uncorrelated
        with the existing modes. The state's convention is preserved.

        Parameters
        ----------
        n_vacuums : int
            The number of vacuum states to add.
        """
        # The padding and block embedding below assume xpxp ordering, so an
        # xxpp state is converted temporarily and converted back afterwards.
        restore_xxpp = self.convention == "xxpp"
        self.to_xpxp()
        N = self.N + n_vacuums
        means = jnp.concatenate((self.means, jnp.zeros(2 * n_vacuums)))
        cov = 0.25 * jnp.eye(2 * N)
        cov = cov.at[: 2 * self.N, : 2 * self.N].set(self.cov)
        self.means = means
        self.cov = cov
        self.N = N
        if restore_xxpp:
            self.to_xxpp()

    def __repr__(self) -> str:
        return (
            super().__repr__()
            + f"\nConvention: {self.convention}\nMeans: {self.means}\nCov: \n{self.cov}"
        )

    def plot_mode(self, mode, n=100, x_range=None, y_range=None, ax=None, **kwargs):
        """Plots the Wigner function of the specified mode.

        Parameters
        ----------
        mode : int
            The mode to plot.
        n : int
            The number of points per axis to plot. Default is 100.
        x_range : tuple
            The range of the x axis to plot as a tuple, (eg. (-5,5)). Default
            attempts to find the range automatically.
        y_range : tuple
            The range of the y axis to plot as a tuple, (eg. (-5,5)). Default
            attempts to find the range automatically.
        ax : matplotlib.axes.Axes
            The axis to plot on, by default it creates a new figure.
        **kwargs
            Keyword arguments to pass to matplotlib.pyplot.contourf.

        Returns
        -------
        matplotlib.axes.Axes
            The axis that was drawn on.
        """
        means, cov = self.modes(mode)
        # Name resolution finds the module-level plot_mode function here, not this method.
        return plot_mode(means, cov, n, x_range, y_range, ax, **kwargs)
def compose_qstate(*args: QuantumState) -> QuantumState:
    """Combines the quantum states of the input ports into a single quantum
    state.

    The means are concatenated and the covariance matrices are placed
    block-diagonally, so the input states are uncorrelated with one another.
    Modes and ports keep the order in which the states are given. The input
    states themselves are not modified.

    Parameters
    ----------
    args : QuantumState
        The quantum states to combine.

    Returns
    -------
    QuantumState
        A new state in the xpxp convention whose ports are the input ports
        concatenated in order.

    Raises
    ------
    ValueError
        If no states are given.
    """
    if not args:
        raise ValueError("At least one quantum state is required.")

    n_total = sum(qstate.N for qstate in args)
    mean_list = []
    port_list = []
    covs = jnp.zeros((2 * n_total, 2 * n_total), dtype=float)
    offset = 0
    for qstate in args:
        means, cov = qstate.means, qstate.cov
        if qstate.convention == "xxpp":
            # Convert local copies instead of mutating the caller's state.
            means = xxpp_to_xpxp(means)
            cov = xxpp_to_xpxp(cov)
        end = offset + 2 * qstate.N
        covs = covs.at[offset:end, offset:end].set(cov)
        mean_list.append(means)
        port_list += qstate.ports
        offset = end

    return QuantumState(
        jnp.concatenate(mean_list), covs, port_list, convention="xpxp"
    )
class CoherentState(QuantumState):
    """Represents a coherent state in a quantum model as a covariance matrix.

    The quadrature means are the real and imaginary parts of ``alpha``. The
    covariance matrix is 1/4 times the 2x2 identity.

    Parameters
    ----------
    port : str
        The port to which the coherent state is connected.
    alpha : complex
        The complex amplitude of the coherent state.
    """

    def __init__(self, port: str, alpha: complex) -> None:
        self.alpha = alpha
        self.N = 1
        # X mean is Re(alpha) and P mean is Im(alpha) (xpxp ordering).
        means = jnp.array([alpha.real, alpha.imag])
        cov = jnp.array([[1 / 4, 0], [0, 1 / 4]])
        ports = [port]
        super().__init__(means, cov, ports)
class SqueezedState(QuantumState):
    """Represents a squeezed state in a quantum model as a covariance matrix.

    The covariance is ``diag(exp(-2r), exp(2r)) / 4`` rotated by ``phi / 2``.
    With ``phi = 0`` the X quadrature is squeezed.

    Parameters
    ----------
    port : str
        The port to which the squeezed state is connected.
    r : float
        The squeezing parameter of the squeezed state.
    phi : float
        The squeezing phase of the squeezed state.
    alpha: complex, optional
        The complex displacement of the squeezed state. Default is 0.
    """

    def __init__(self, port: str, r: float, phi: float, alpha: complex = 0) -> None:
        self.r = r
        self.phi = phi
        self.N = 1
        means = jnp.array([alpha.real, alpha.imag])
        # Rotation by half the squeezing phase.
        c, s = jnp.cos(phi / 2), jnp.sin(phi / 2)
        rot_mat = jnp.array([[c, -s], [s, c]])
        squeezed = (1 / 4) * jnp.array([[jnp.exp(-2 * r), 0], [0, jnp.exp(2 * r)]])
        cov = rot_mat @ squeezed @ rot_mat.T
        ports = [port]
        super().__init__(means, cov, ports)
class TwoModeSqueezedState(QuantumState):
    """Represents a two mode squeezed state in a quantum model as a covariance
    matrix.

    This state is described by three parameters: a two-mode squeezing
    parameter r, and the two initial thermal occupations n_a and n_b.
    The means are zero. The covariance matrix is built in the xpxp convention,
    ordered (x_a, p_a, x_b, p_b)::

        [[ ca,    0,  cab,    0],
         [  0,   ca,    0, -cab],
         [cab,    0,   cb,    0],
         [  0, -cab,    0,   cb]] / 2

    with ``ca = (n_a + 1/2) cosh^2 r + (n_b + 1/2) sinh^2 r``,
    ``cb = (n_b + 1/2) cosh^2 r + (n_a + 1/2) sinh^2 r`` and
    ``cab = (n_a + n_b + 1) sinh r cosh r``. For r = 0 and zero occupations
    this reduces to 1/4 times the identity.

    Parameters
    ----------
    r : float
        The two-mode squeezing parameter of the two mode squeezed state.
    n_a : float
        The initial thermal occupation of the first mode.
    n_b : float
        The initial thermal occupation of the second mode.
    port_a : str
        The port to which the first mode is connected.
    port_b : str
        The port to which the second mode is connected.
    """

    def __init__(
        self, r: float, n_a: float, n_b: float, port_a: str, port_b: str
    ) -> None:
        self.r = r
        self.n_a = n_a
        self.n_b = n_b
        self.N = 2
        means = jnp.array([0, 0, 0, 0])
        ca = (n_a + 1 / 2) * jnp.cosh(r) ** 2 + (n_b + 1 / 2) * jnp.sinh(r) ** 2
        cb = (n_b + 1 / 2) * jnp.cosh(r) ** 2 + (n_a + 1 / 2) * jnp.sinh(r) ** 2
        cab = (n_a + n_b + 1) * jnp.sinh(r) * jnp.cosh(r)
        # Each mode has equal X and P variance. The X-X and P-P cross
        # correlations must have opposite signs: with equal signs the matrix
        # violates the uncertainty principle for any r > 0.
        cov = (
            jnp.array(
                [
                    [ca, 0, cab, 0],
                    [0, ca, 0, -cab],
                    [cab, 0, cb, 0],
                    [0, -cab, 0, cb],
                ]
            )
            / 2
        )
        ports = [port_a, port_b]
        super().__init__(means, cov, ports)
class ThermalState(QuantumState):
    """A single-mode thermal state described by its covariance matrix.

    The means are zero and the covariance is ``(2 * nbar + 1) / 4`` times the
    2x2 identity. This reduces to 1/4 times the identity when ``nbar`` is 0.

    Parameters
    ----------
    port : str
        The port to which the thermal state is connected.
    nbar : float
        The thermal occupation or average photon number of the thermal state.
    """

    def __init__(self, port: str, nbar: float) -> None:
        self.nbar = nbar
        self.N = 1
        # Per-quadrature variance of a thermal mode.
        variance = (2 * nbar + 1) / 4
        super().__init__(jnp.array([0, 0]), variance * jnp.eye(2), [port])
@dataclass
class QuantumResult(SimulationResult):
    """Quantum simulation results.

    Per-wavelength quantities are stacked along their first axis, so index
    ``wl_ind`` selects the entry for ``wl[wl_ind]``.
    """

    # S-parameter matrices of the circuit, as produced by dict_to_matrix.
    s_params: jnp.ndarray
    # Means of the vacuum-padded input state (xxpp convention).
    input_means: jnp.ndarray
    # Covariance matrix of the vacuum-padded input state (xxpp convention).
    input_cov: jnp.ndarray
    # Real quadrature transform applied at each wavelength.
    transforms: jnp.ndarray
    # Output means, one entry per wavelength (xxpp convention).
    means: jnp.ndarray
    # Output covariance matrices, one per wavelength (xxpp convention).
    cov: jnp.ndarray
    # Wavelengths that were simulated.
    wl: jnp.ndarray
    # Number of circuit ports, not counting the added loss modes.
    n_ports: int

    def state(self, wl_ind: int = 0) -> QuantumState:
        """Returns the quantum state at a specific wavelength.

        The state is built in the xxpp convention without any ports attached.

        Parameters
        ----------
        wl_ind : int, optional
            The wavelength index. Defaults to 0.
        """
        return QuantumState(self.means[wl_ind], self.cov[wl_ind], convention="xxpp")
def plot_quantum_result(
    result: QuantumResult,
    modes: list = None,
    wl_ind: int = 0,
    include_loss_modes=False,
):
    """Plot the Wigner function of each selected output mode of a quantum result.

    Parameters
    ----------
    result : QuantumResult
        The quantum simulation result.
    modes : list, optional
        The modes to plot. Defaults to all modes.
    wl_ind : int, optional
        The wavelength index to plot. Defaults to 0.
    include_loss_modes : bool, optional
        Whether to include the loss modes in the plot. Defaults to False.

    Returns
    -------
    numpy.ndarray
        Flattened array of the axes in the square plot grid. Axes beyond the
        number of plotted modes are hidden.
    """
    if modes is None:
        n_modes = result.n_ports * 2 if include_loss_modes else result.n_ports
        modes = list(range(int(n_modes)))
    n_plots = len(modes)

    # Smallest square grid that holds one plot per mode (at least 1x1).
    side = max(1, int(n_plots**0.5))
    if side * side < n_plots:
        side += 1
    # squeeze=False always returns a 2-D array, so flatten() also works for 1x1.
    _, axs = plt.subplots(side, side, figsize=(10, 10), squeeze=False)
    axs = axs.flatten()

    # Output states are in the xxpp convention: all X quadratures first, then
    # all P quadratures, so mode m's (X, P) pair is at indices (m, m + n_total).
    means = result.means[wl_ind]
    cov = result.cov[wl_ind]
    n_total = means.shape[0] // 2
    for ax, mode in zip(axs, modes):
        mode = int(mode)
        inds = jnp.array([mode, mode + n_total])
        mu = means[inds]
        c = cov[jnp.ix_(inds, inds)]
        plot_mode(mu, c, x_range=(-6, 6), y_range=(-6, 6), ax=ax)
        ax.set_title(f"Mode {mode}")
    # Hide grid cells that have no mode to show.
    for ax in axs[n_plots:]:
        ax.set_visible(False)
    return axs
class QuantumSim(Simulation):
    """Quantum simulation.

    Propagates a Gaussian input state through the circuit's S-parameters. The
    circuit is extended with one extra loss mode per port, fed with vacuum.

    Parameters
    ----------
    ckt : sax.saxtypes.Model
        The circuit to simulate.
    wl : ArrayLike
        The array of wavelengths to simulate (in microns).
    **params
        Any other parameters to pass to the circuit.

    Examples
    --------
    >>> sim = QuantumSim(ckt=mzi, wl=wl, top={"length": 150.0}, bottom={"length": 50.0})
    """

    def __init__(self, ckt: Model, **kwargs) -> None:
        if "wl" not in kwargs:
            raise ValueError("Must specify 'wl' (wavelengths to simulate).")
        # Bind every keyword (including wl) so the circuit can be called with no arguments.
        ckt = partial(ckt, **kwargs)
        super().__init__(ckt, kwargs["wl"])

    def add_qstate(self, qstate: QuantumState) -> None:
        """Add a quantum state to the simulation.

        Parameters
        ----------
        qstate : QuantumState
            The quantum state to add.
        """
        self.input = qstate

    @staticmethod
    def to_unitary(s_params):
        """This method converts s-parameters into a unitary transform by adding
        vacuum ports.

        The original ports maintain their index while new vacuum ports will
        always be the last n_ports. For each frequency the result is
        ``[[S, -D], [D, S]]``, where ``D`` is diagonal with
        ``D_ii = sqrt(1 - sum_j |S_ji|^2)``.

        NOTE(review): this completion is exactly unitary only when the columns
        of ``S`` are mutually orthogonal. A general lossy ``S`` would need a
        full unitary dilation.

        Parameters
        ----------
        s_params : ArrayLike
            s-parameters in the shape of (n_freq, n_ports, n_ports).

        Returns
        -------
        unitary : Array
            The unitary s-parameters of the shape (n_freq, 2*n_ports,
            2*n_ports).
        """
        s_params = jnp.asarray(s_params)
        n_freqs, n_ports, _ = s_params.shape
        new_n_ports = n_ports * 2
        unitary = jnp.zeros((n_freqs, new_n_ports, new_n_ports), dtype=complex)
        # Copy S onto both diagonal blocks for every frequency at once.
        unitary = unitary.at[:, :n_ports, :n_ports].set(s_params)
        unitary = unitary.at[:, n_ports:, n_ports:].set(s_params)
        # Power leaving input port i through all outputs: sum_j |S_ji|^2.
        s_block = unitary[:, :n_ports, :n_ports]
        col_power = jnp.sum(s_block * s_block.conj(), axis=1)
        # col_power is complex, so a column with power above 1 yields an
        # imaginary value here rather than NaN.
        loss = jnp.sqrt(1 - col_power)
        idx = jnp.arange(n_ports)
        unitary = unitary.at[:, n_ports + idx, idx].set(loss)
        unitary = unitary.at[:, idx, n_ports + idx].set(-loss)
        return unitary

    def run(self) -> QuantumResult:
        """Run the simulation.

        The input state is padded with vacuum modes up to the size of the
        unitary and reordered so that its modes line up with the circuit
        ports. It is then transformed at every wavelength. The state passed
        to :meth:`add_qstate` is not modified.

        Returns
        -------
        QuantumResult
            Input and output means and covariances (xxpp convention), the
            per-wavelength transforms, and the S-parameters.
        """
        import copy

        # Evaluate the circuit once and reuse the resulting S-parameter dict.
        sdict = self.ckt()
        ports = get_ports(sdict)
        n_ports = len(ports)
        # get the unitary s-parameters of the circuit
        s_params = dict_to_matrix(sdict)
        unitary = self.to_unitary(s_params)
        n_modes = unitary.shape[1]

        # The state helpers below rebind means/cov/N/convention, so work on a
        # shallow copy. Repeated run() calls then stay correct, and the
        # caller's state is left untouched.
        state = copy.copy(self.input)
        # Vacuum padding is done in xpxp ordering.
        state.to_xpxp()
        # Circuit-port index of each input mode, followed by the unused ports,
        # which receive the vacuum modes.
        input_indices = [ports.index(port) for port in state.ports]
        state._add_vacuums(n_modes - len(input_indices))
        input_indices += [i for i in range(n_modes) if i not in input_indices]
        # input_indices[k] is the port fed by state mode k. Invert it so that
        # entry p of the input vector is the mode attached to port p.
        mode_for_port = sorted(range(n_modes), key=input_indices.__getitem__)
        state.to_xxpp()
        input_means, input_cov = state.modes(mode_for_port)

        # Real xxpp representation of each complex unitary:
        # [[Re U, -Im U], [Im U, Re U]], built for all wavelengths at once.
        re, im = unitary.real, unitary.imag
        transforms = jnp.concatenate(
            (
                jnp.concatenate((re, -im), axis=2),
                jnp.concatenate((im, re), axis=2),
            ),
            axis=1,
        )
        means = transforms @ input_means
        covs = transforms @ input_cov @ jnp.swapaxes(transforms, 1, 2)
        # TODO: Possibly zero out numerically tiny entries (e.g. |x| < 1e-10).

        return QuantumResult(
            s_params=s_params,
            input_means=input_means,
            input_cov=input_cov,
            transforms=transforms,
            means=means,
            cov=covs,
            n_ports=n_ports,
            wl=self.wl,
        )