"""Module for classical simulation."""
import warnings
from dataclasses import dataclass
from functools import partial
from typing import Callable, List, Union
import jax.numpy as jnp
import matplotlib.pyplot as plt
import sax
from jax.typing import ArrayLike
from sax.saxtypes import Model
from simphony.simulation import SimDevice, Simulation, SimulationResult
class Laser(SimDevice):
    """Ideal laser source.

    Parameters
    ----------
    ports : str or list of str
        The port, or ports, to which the laser is connected. A single port
        name is treated as a one-element list.
    power : float, optional
        The power of the laser (in mW), by default 1.0
    phase : float, optional
        The phase of the laser (in radians), by default 0.0
    mod_function : Callable, optional
        The modulation function, by default None. It is stored on the
        instance but is not otherwise used by this class.
    """

    def __init__(
        self,
        ports: Union[str, List[str]],
        power: float = 1.0,
        phase: float = 0.0,
        mod_function: Union[Callable, None] = None,
    ) -> None:
        # A bare string must be wrapped rather than iterated: ``list("in")``
        # yields the characters ``["i", "n"]``, not the port name ``"in"``.
        port_names = [ports] if isinstance(ports, str) else list(ports)
        super().__init__(port_names)
        self.power = power
        self.phase = phase
        # TODO: Implement mod_function
        self.mod_function = mod_function
class Detector(SimDevice):
    """Ideal photodetector.

    Attributes
    ----------
    wl : ArrayLike
        The wavelengths at which the detector was simulated. Only set once
        ``set_result`` has been called.
    power : ArrayLike
        The power at each wavelength. Only set once ``set_result`` has been
        called.
    responsivity : float
        The responsivity passed to the constructor.

    Parameters
    ----------
    port : str
        The port to which the detector is connected.
    responsivity : float, optional
        The responsivity of the detector (in A/W), by default 1.0
    """

    def __init__(self, port: str, responsivity: float = 1.0) -> None:
        # ``port`` is a single port name, so wrap it in a list. Calling
        # ``list(port)`` would split the name into individual characters.
        super().__init__([port])
        if responsivity != 1.0:
            # NOTE(review): ClassicalSim.run multiplies the detected power by
            # ``responsivity``, so this message may be outdated -- confirm.
            warnings.warn("Responsivity is not yet implemented, so it is ignored.")
        self.responsivity = responsivity

    def set_result(self, wl: ArrayLike, power: ArrayLike) -> None:
        """Set the result of the detector.

        Parameters
        ----------
        wl : ArrayLike
            The wavelengths at which the detector was simulated.
        power : ArrayLike
            The power at each wavelength.
        """
        self.wl = wl
        self.power = power

    def plot(self, ax=None, **kwargs):
        """Plot the detector response.

        ``set_result`` must have been called first, since the plot reads
        ``self.wl`` and ``self.power``.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            The axes on which to plot, by default None (create new axes).
        **kwargs
            Any other keyword arguments to pass to ``ax.plot``.

        Returns
        -------
        matplotlib.axes.Axes
            The axes that were plotted on.
        """
        if ax is None:
            # Only the axes are needed; the figure handle is discarded.
            _, ax = plt.subplots()
        ax.plot(self.wl, self.power, **kwargs)
        ax.set_xlabel("Wavelength (um)")
        ax.set_ylabel("Power (mW)")
        return ax
@dataclass
class ClassicalResult(SimulationResult):
    """Classical simulation results.

    Attributes
    ----------
    wl : ArrayLike
        The wavelengths at which the simulation was run.
    sdict : sax.SDict
        The values computed by ``ClassicalSim.run``: for each detector port,
        the sum over all laser ports of the circuit S-parameter multiplied by
        that laser's complex field amplitude.
    detectors : dict[str, Detector]
        The detectors and their measurements from the simulation, keyed by
        the name of the port each detector is attached to.
    """

    # TODO: Add a function to convert the sdict to a matrix and easily get
    # port ordering.
    wl: ArrayLike
    # NOTE(review): ClassicalSim.run keys this mapping by a single output
    # port name -- confirm that this matches the sax.SDict key convention.
    sdict: sax.SDict
    # ClassicalSim.run passes its ``detectors`` dict here, not a list.
    detectors: dict[str, Detector]
class ClassicalSim(Simulation):
    """Classical simulation.

    Lasers and detectors are registered per port with ``add_laser`` and
    ``add_detector``; ``run`` then evaluates the circuit and stores the
    detected power on every registered detector.
    """

    def __init__(self, ckt: Model, **kwargs) -> None:
        """Initialize the classical simulation.

        Parameters
        ----------
        ckt : sax.saxtypes.Model
            The circuit to simulate.
        wl : ArrayLike
            The array of wavelengths to simulate (in microns). Required, and
            must be passed by keyword.
        **kwargs
            Any other parameters to pass to the circuit. All keyword
            arguments, including ``wl``, are bound to ``ckt``.

        Raises
        ------
        ValueError
            If ``wl`` is not given.

        Examples
        --------
        >>> sim = ClassicalSim(ckt=mzi, wl=wl, top={"length": 150.0}, bottom={"length": 50.0})
        """
        # TODO: Add shot noise option to classical simulation
        if "wl" not in kwargs:
            raise ValueError("Must specify 'wl' (wavelengths to simulate).")
        # Bind every keyword argument now so the circuit can later be
        # evaluated with no arguments.
        super().__init__(partial(ckt, **kwargs), kwargs["wl"])
        self.lasers: dict[str, Laser] = {}
        self.detectors: dict[str, Detector] = {}

    def add_laser(
        self,
        ports: Union[str, List[str]],
        power: float = 1.0,
        phase: float = 0.0,
        mod_function: Callable = None,
    ) -> Laser:
        """Add an ideal laser source.

        If multiple ports are specified, the same laser will be connected
        to all of them. A laser previously registered on one of the ports is
        replaced.

        Parameters
        ----------
        ports : str or list of str
            The ports to which the laser is connected.
        power : float, optional
            The power of the laser (in mW), by default 1.0
        phase : float, optional
            The phase of the laser (in radians), by default 0.0
        mod_function : Callable, optional
            The modulation function, by default None (not yet implemented).

        Returns
        -------
        Laser
            The created laser.

        Examples
        --------
        >>> laser = sim.add_laser(ports=["in"], power=1.0)
        """
        # Lists and tuples are taken as collections of ports; anything else
        # (e.g. a single port name) is wrapped as one port.
        port_list = list(ports) if isinstance(ports, (list, tuple)) else [ports]
        laser = Laser(port_list, power, phase, mod_function)
        for port in port_list:
            self.lasers[port] = laser
        return laser

    def add_detector(
        self, ports: Union[str, List[str]], responsivity: float = 1.0
    ) -> List[Detector]:
        """Add an ideal photodetector.

        If multiple ports are specified, multiple detectors will be created
        and returned. A detector previously registered on one of the ports is
        replaced.

        Parameters
        ----------
        ports : str or list of str
            The ports to which the detector is connected.
        responsivity : float, optional
            The responsivity of the detector (in A/W), by default 1.0

        Returns
        -------
        list of Detector
            A list of the created detector(s) (potentially a list of length 1).

        Examples
        --------
        >>> detector = sim.add_detector(ports=["out"], responsivity=0.8)
        """
        # Same normalization as add_laser: one detector per port.
        port_list = list(ports) if isinstance(ports, (list, tuple)) else [ports]
        detectors = [Detector(port, responsivity) for port in port_list]
        self.detectors.update(zip(port_list, detectors))
        return detectors

    def run(self) -> ClassicalResult:
        """Run the classical simulation.

        The circuit is evaluated once. For every detector port, the
        contributions of all lasers are summed as complex fields, and the
        squared magnitude scaled by the detector's responsivity is stored on
        the detector via ``set_result``.

        Returns
        -------
        ClassicalResult
            The simulation results: the wavelengths, the summed field at each
            detector port, and the detectors keyed by port.
        """
        S = self.ckt()

        # Complex field amplitude injected at each laser port. It does not
        # depend on the detector, so it is computed once per laser instead of
        # once per (detector, laser) pair.
        # NOTE: ``mod_function`` is not applied here (not yet implemented).
        fields = {
            input_port: jnp.sqrt(laser.power) * jnp.exp(1j * laser.phase)
            for input_port, laser in self.lasers.items()
        }

        sdict = {}
        for output_port in self.detectors:
            responses = [
                S[output_port, input_port] * field
                for input_port, field in fields.items()
            ]
            # Sum the fields from all lasers before taking the magnitude.
            sdict[output_port] = jnp.sum(jnp.asarray(responses), axis=0)

        for port, detector in self.detectors.items():
            power = (jnp.abs(sdict[port]) ** 2) * detector.responsivity
            detector.set_result(wl=self.wl, power=power)

        return ClassicalResult(
            wl=self.wl,
            sdict=sdict,
            detectors=self.detectors,
        )
# class MonteCarloSim(Simulation):
# """Monte Carlo simulation."""
# def __init__(self, ckt: Circuit, wl: jnp.ndarray) -> None:
# super().__init__(ckt, wl)
# class LayoutAwareSim(Simulation):
# """Layout-aware simulation."""
# def __init__(self, cir: Circuit, wl: jnp.ndarray) -> None:
# super().__init__(cir, wl)
# class SamplingSim(Simulation):
# """Sampling simulation."""
# def __init__(self, ckt: Circuit, wl: jnp.ndarray) -> None:
# super().__init__(ckt, wl)
# class TimeDomainSim(Simulation):
# """Time-domain simulation."""
# def __init__(self, ckt: Circuit, wl: jnp.ndarray) -> None:
# super().__init__(ckt, wl)