# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2025 Scipp contributors (https://github.com/scipp)
from __future__ import annotations
from itertools import chain
from types import MappingProxyType
import matplotlib.pyplot as plt
import numpy as np
import scipp as sc
from matplotlib.collections import LineCollection
from .chopper import ChopperReading
from .component import ComponentReading
from .detector import DetectorReading
from .source import SourceReading
from .utils import Plot, extract_component_group, one_mask
def _get_rays(
components: list[ComponentReading], pulse: int, inds: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
x = []
y = []
c = []
data = components[0].data["pulse", pulse]
xstart = data.coords["toa"].values[inds]
ystart = np.full_like(xstart, components[0].distance.value)
color = data.coords["wavelength"].values[inds]
for comp in components[1:]:
xend = comp.data["pulse", pulse].coords["toa"].values[inds]
yend = np.full_like(xend, comp.distance.value)
x.append([xstart, xend])
y.append([ystart, yend])
c.append(color)
xstart, ystart = xend, yend
color = comp.data["pulse", pulse].coords["wavelength"].values[inds]
return (
np.array(x).transpose((0, 2, 1)),
np.array(y).transpose((0, 2, 1)),
np.array(c),
)
def _add_rays(
ax: plt.Axes,
x: np.ndarray,
y: np.ndarray,
color: np.ndarray | str,
cbar: bool = True,
cmap: str = "gist_rainbow_r",
vmin: float | None = None,
vmax: float | None = None,
cax: plt.Axes | None = None,
zorder: int = 1,
):
coll = LineCollection(np.stack((x, y), axis=2), zorder=zorder)
if isinstance(color, str):
coll.set_color(color)
else:
coll.set_cmap(plt.colormaps[cmap])
coll.set_array(color.ravel())
coll.set_norm(plt.Normalize(vmin, vmax))
if cbar:
cb = plt.colorbar(coll, ax=ax, cax=cax)
cb.ax.yaxis.set_label_coords(-0.9, 0.5)
cb.set_label("Wavelength [Å]")
ax.add_collection(coll)
[docs]
class Result:
"""
Result of a simulation.
Parameters
----------
source:
The source of neutrons.
results:
The state of neutrons at each component in the model.
"""
[docs]
def __init__(self, source: SourceReading, readings: dict[str, dict]):
self._source = source
self._components = MappingProxyType(readings)
@property
def choppers(self) -> MappingProxyType[str, ChopperReading]:
"""
A dictionary of the choppers in the instrument.
"""
return extract_component_group(self._components, "chopper")
@property
def detectors(self) -> MappingProxyType[str, DetectorReading]:
"""
A dictionary of the detectors in the instrument.
"""
return extract_component_group(self._components, "detector")
@property
def samples(self) -> MappingProxyType[str, ComponentReading]:
"""
A dictionary of the samples in the instrument.
"""
return extract_component_group(self._components, "sample")
@property
def source(self) -> SourceReading:
"""The source of neutrons."""
return self._source
def __iter__(self):
return iter(self._components)
def __getitem__(self, name: str) -> ComponentReading:
return self._components[name]
[docs]
def plot(
self,
visible_rays: int = 1000,
blocked_rays: int = 0,
figsize: tuple[float, float] | None = None,
ax: plt.Axes | None = None,
cax: plt.Axes | None = None,
cbar: bool = True,
cmap: str = "gist_rainbow_r",
seed: int | None = None,
vmin: float | None = None,
vmax: float | None = None,
title: str | None = None,
) -> Plot:
"""
Plot the time-distance diagram for the instrument, including the rays of
neutrons that make it to the furthest detector.
As plotting many lines can be slow, the number of rays to plot can be
limited by setting ``visible_rays``.
In addition, it is possible to also plot the rays that are blocked by
choppers along the flight path by setting ``blocked_rays > 0``.
Parameters
----------
visible_rays:
Maximum number of rays to plot.
blocked_rays:
Number of blocked rays to plot.
figsize:
Figure size.
ax:
Axes to plot on.
cax:
Axes to use for the colorbar.
cbar:
Show a colorbar for the wavelength if ``True``.
cmap:
Colormap to use for the wavelength colorbar.
seed:
Random seed for reproducibility.
vmin:
Minimum value for the colorbar.
vmax:
Maximum value for the colorbar.
"""
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
else:
fig = ax.get_figure()
components = sorted(
chain((self.source,), self._components.values()), key=lambda c: c.distance
)
furthest_component = components[-1]
rng = np.random.default_rng(seed)
# Make ids for neutrons per pulse, instead of using their id coord
ids = np.arange(self.source.neutrons)
rays = {"x": [], "y": [], "color": []}
for i in range(self._source.data.sizes["pulse"]):
# Plot visible rays
blocked = one_mask(furthest_component.data["pulse", i].masks).values
nblocked = int(blocked.sum())
if nblocked < self.source.neutrons:
inds = rng.choice(
ids[~blocked],
size=min(self.source.neutrons - nblocked, visible_rays),
replace=False,
)
x, y, c = _get_rays(components, pulse=i, inds=inds)
rays["x"].append(x)
rays["y"].append(y)
rays["color"].append(c)
# Plot blocked rays
inds = rng.choice(
ids[blocked], size=min(blocked_rays, nblocked), replace=False
)
x, y, _ = _get_rays(components, pulse=i, inds=inds)
blocked_by_others = np.stack(
[
comp.data["pulse", i].masks["blocked_by_others"].values[inds]
for comp in components[1:]
],
axis=1,
).T
line_selection = np.broadcast_to(
blocked_by_others.reshape((*blocked_by_others.shape, 1)), x.shape
)
x[line_selection] = np.nan
y[line_selection] = np.nan
_add_rays(
ax=ax,
x=x.reshape((-1, 2)),
y=y.reshape((-1, 2)),
color="lightgray",
zorder=-1,
)
# Plot pulse
self.source.plot_on_time_distance_diagram(ax, pulse=i)
# Add coloured rays in one go so that they share the same colorbar, thus
# enabling using zoom on the colorbar to select a wavelength range across all
# pulses.
if len(rays["x"]) > 0:
wavelengths = sc.DataArray(
data=furthest_component.data.coords["wavelength"],
masks=furthest_component.data.masks,
)
if vmin is None:
vmin = wavelengths.nanmin().value
if vmax is None:
vmax = wavelengths.nanmax().value
_add_rays(
ax=ax,
x=np.concatenate([r.reshape((-1, 2)) for r in rays["x"]], axis=0),
y=np.concatenate([r.reshape((-1, 2)) for r in rays["y"]], axis=0),
color=np.concatenate([r.ravel() for r in rays["color"]], axis=0),
cbar=cbar,
cmap=cmap,
vmin=vmin,
vmax=vmax,
cax=cax,
)
if furthest_component.toa.data.sum().value > 0:
toa_max = furthest_component.toa.nanmax().value
else:
toa_max = furthest_component.toa.data.coords["toa"].nanmax().value
# Plot components
for comp in self._components.values():
comp.plot_on_time_distance_diagram(ax=ax, tmax=toa_max)
dx = 0.05 * toa_max
ax.set(xlabel="Time [μs]", ylabel="Distance [m]")
ax.set_xlim(0 - dx, toa_max + dx)
if figsize is None:
inches = fig.get_size_inches()
fig.set_size_inches((min(inches[0] * self.source.pulses, 12.0), inches[1]))
fig.tight_layout()
if title is not None:
ax.set_title(title)
return Plot(fig=fig, ax=ax)
def __repr__(self) -> str:
out = (
f"Result:\n Source: {self.source.pulses} pulses, "
f"{self.source.neutrons} neutrons per pulse.\n"
)
groups = {}
for comp in self._components.values():
if comp.kind not in groups:
groups[comp.kind] = []
groups[comp.kind].append(comp)
for group, comps in groups.items():
out += f" {group.capitalize()}s:\n"
for comp in sorted(comps, key=lambda c: c.distance):
out += f" {comp.name}: {comp._repr_stats()}\n"
return out
def __str__(self) -> str:
return self.__repr__()
[docs]
def to_nxevent_data(self, key: str | None = None) -> sc.DataArray:
"""
Convert a detector reading to event data that resembles event data found in a
NeXus file.
Parameters
----------
key:
Name of the detector. If ``None``, all detectors are included.
"""
start = sc.datetime("2024-01-01T12:00:00.000000")
period = sc.reciprocal(self.source.frequency)
detectors = self.detectors
keys = list(detectors.keys()) if key is None else [key]
event_data = []
for name in keys:
raw_data = detectors[name].data.flatten(to="event")
events = (
raw_data[~raw_data.masks["blocked_by_others"]]
.copy()
.drop_masks("blocked_by_others")
)
events.coords["distance"] = sc.broadcast(
events.coords["distance"], sizes=events.sizes
).copy()
event_data.append(events)
event_data = sc.concat(event_data, dim=event_data[0].dim)
dt = period.to(unit=event_data.coords["toa"].unit)
event_time_zero = (dt * (event_data.coords["toa"] // dt)).to(dtype=int) + start
event_data.coords["event_time_zero"] = event_time_zero
event_data.coords["event_time_offset"] = event_data.coords.pop(
"toa"
) % period.to(unit=dt.unit)
out = (
event_data.drop_coords(["speed", "birth_time", "wavelength"])
.group("distance")
.rename_dims(distance="detector_number")
)
out.coords["Ltotal"] = out.coords.pop("distance")
return out
@property
def data(self) -> sc.DataGroup:
"""
Get the data for the source, choppers, and detectors, as a DataGroup.
The components are sorted by distance.
"""
out = {"source": self.source.data}
components = sorted(
chain(self.choppers.values(), self.detectors.values()),
key=lambda c: c.distance.value,
)
for comp in components:
out[comp.name] = comp.data
return sc.DataGroup(out)