# Source code for toolbox_scs.detectors.hrixs

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import leastsq
from scipy.optimize import curve_fit
from scipy.signal import fftconvolve
import xarray as xr

from extra_data import open_run
import pasha as psh
import toolbox_scs as tb


# Names exported by ``from ... import *``: only the two analysis classes.
# The module-level helpers below remain importable by their explicit names.
__all__ = ['hRIXS', 'MaranaX']


# -----------------------------------------------------------------------------
# Curvature

def find_curvature(image, frangex=None, frangey=None,
                   deg=2, axis=1, offset=100):
    """Estimate the curvature of lines in a 2-D image by cross-correlation.

    The image is averaged along ``axis`` over the region of interest to get
    a reference profile. Each column (``axis=1``) or row (``axis=0``) of the
    image is cross-correlated with that profile. The position of the
    correlation maximum gives one shift per column/row, and a polynomial of
    degree ``deg`` is fitted to these shifts.

    NOTE: a second ``find_curvature`` is defined further down in this
    module and replaces this one at import time, so this variant is only
    reachable if referenced before that redefinition.

    Parameters
    ----------
    image : 2-D array
    frangex, frangey : (int, int), optional
        ``(start, stop)`` pixel range of the region of interest along the
        second and the first array axis respectively. Each range is clipped
        to the image. Defaults to the full image.
    deg : int
        Degree of the fitted polynomial.
    axis : {0, 1}
        Axis along which the reference profile is averaged.
    offset : int
        Margin, in pixels, added on both sides of the region of interest
        when searching for the shifted profile.

    Returns
    -------
    numpy.ndarray
        The ``deg`` non-constant polynomial coefficients, ordered from the
        linear term upwards (the constant term is dropped).
    """
    n_rows, n_cols = image.shape

    # Resolve the region of interest, clipped to the image bounds.
    x_range = (0, n_cols)
    if frangex is not None:
        x_range = (max(frangex[0], x_range[0]), min(frangex[1], x_range[1]))
    y_range = (0, n_rows)
    if frangey is not None:  # was wrongly testing `frangex`
        y_range = (max(frangey[0], y_range[0]), min(frangey[1], y_range[1]))

    axis_range = y_range if axis == 1 else x_range
    # Length of the axis that is searched for the shift (used for clipping)
    search_dim = image.shape[axis - 1]
    # ... and length of the axis along which the shifts are sampled/fitted.
    # The original used `search_dim` for both, which only works for square
    # images.
    fit_dim = image.shape[axis]

    # Reference: the averaged profile, repeated for every column/row.
    integral = image[slice(*y_range), slice(*x_range)].mean(axis=axis)
    roi = np.ones([axis_range[1] - axis_range[0], fit_dim])
    ref = roi * integral[:, np.newaxis]

    # Region searched for the profile: the ROI widened by `offset`.
    slice_ = [slice(None), slice(None)]
    slice_[axis - 1] = slice(max(axis_range[0] - offset, 0),
                             min(axis_range[1] + offset, search_dim))
    sliced = image[tuple(slice_)]
    if axis == 0:
        sliced = sliced.T

    # Correlating with the flipped reference gives a cross-correlation; its
    # maximum per column/row is the shift of the profile there.
    crosscorr = fftconvolve(sliced, ref[::-1, :], axes=0)
    shifts = np.argmax(crosscorr, axis=0)
    curv = np.polyfit(np.arange(fit_dim), shifts, deg=deg)
    # polyfit returns the highest degree first: drop the constant term and
    # reverse so the linear coefficient comes first.
    return curv[:-1][::-1]


def find_curvature(img, args, plot=False, **kwargs):
    """Fit a curved Gaussian ridge to ``img`` and return the fit parameters.

    The model evaluated at (row, column) is
    ``h * exp(-((row - p(column)) / (2 * s))**2) + o``, where
    ``p(column) = (a * column + b) * column + c``. ``args`` is the starting
    guess ``(a, b, c, s, h[, o])`` handed to ``scipy.optimize.leastsq``.

    With ``plot=True`` the image is shown with ``plt.imshow`` (extra
    ``kwargs`` are forwarded to it). The initial and the fitted parabola
    are drawn on top.

    Returns the optimised parameter array from ``leastsq``.
    """
    def ridge(col, a, b, c, s=0, h=0, o=0):
        # Extra parameters are accepted so the full parameter vector can be
        # splatted in.
        return (a*col + b)*col + c

    def model(row, col, a, b, c, s, h, o=0):
        return h * np.exp(-((row - ridge(col, a, b, c)) / (2 * s))**2) + o

    cols = np.arange(img.shape[1])[None, :]
    rows = np.arange(img.shape[0])[:, None]
    col_axis = cols[0, :]

    if plot:
        plt.figure(figsize=(10,10))
        plt.imshow(img, cmap='gray', aspect='auto', interpolation='nearest', **kwargs)
        plt.plot(col_axis, ridge(col_axis, *args))

    def residuals(params):
        return (model(rows, cols, *params) - img).ravel()

    fitted, _ = leastsq(residuals, args)

    if plot:
        plt.plot(col_axis, ridge(col_axis, *fitted))
    return fitted


def correct_curvature(image, factor=None, axis=1):
    """Straighten curved lines by re-binning every pixel at a shifted position.

    Each pixel centre is moved along one axis by
    ``-(factor[0] * r + factor[1] * r**2)``, where ``r`` is the pixel-centre
    coordinate along the other axis. The pixel values are then histogrammed
    (weighted) back onto the original pixel grid. Values shifted outside the
    grid are lost.

    Parameters
    ----------
    image : 2-D array
    factor : (float, float), optional
        Linear and quadratic shift coefficients.
    axis : {0, 1}
        With ``axis=1`` the image is transposed before shifting (and the
        result comes out in the original orientation).

    Returns
    -------
    numpy.ndarray or None
        Float array with the same shape as ``image``, or ``None`` when
        ``factor`` is ``None``.
    """
    if factor is None:
        return None

    work = image.T if axis == 1 else image

    n_rows, n_cols = work.shape
    col_edges = np.arange(n_cols + 1)
    row_edges = np.arange(n_rows + 1)
    col_centres, row_centres = np.meshgrid(col_edges[:-1] + 0.5,
                                           row_edges[:-1] + 0.5)
    shifted = (col_centres
               - factor[0] * row_centres
               - factor[1] * row_centres ** 2)
    # histogramdd returns (n_cols, n_rows), i.e. transposed relative to `work`.
    hist, _ = np.histogramdd((shifted.flatten(), row_centres.flatten()),
                             bins=[col_edges, row_edges],
                             weights=work.flatten())

    if axis == 1:
        return hist
    return hist.T


def get_spectrum(image, factor=None, axis=0,
                 pixel_range=None, energy_range=None, ):
    """Sum a 2-D image along ``axis`` and return ``(positions, spectrum)``.

    Parameters
    ----------
    image : 2-D array
    factor : array, optional
        Calibration polynomial. When given, the pixel centres are converted
        with :func:`calibrate`, which also crops to ``energy_range``.
    axis : {0, 1}
        Axis that is summed over.
    pixel_range : (int or None, int or None), optional
        Pixel window kept from the summed profile. ``None`` (or 0) entries
        leave that end open. The window is clipped to the profile length.
    energy_range : (float, float), optional
        Forwarded to :func:`calibrate` as ``range_``.

    Returns
    -------
    (positions, spectrum)
        Pixel centres (``i + 0.5``), or calibrated values, and the summed
        intensities.
    """
    n_pixels = image.shape[axis - 1]
    first, last = 0, n_pixels
    if pixel_range is not None:
        low, high = pixel_range[0], pixel_range[1]
        first = max(low or first, first)
        last = min(high or last, last)

    profile = image.sum(axis=axis)[first:last]
    edges = np.arange(first, last + 1)
    centres = 0.5 * (edges[1:] + edges[:-1])

    if factor is None:
        return centres, profile
    return calibrate(centres, profile, factor=factor, range_=energy_range)


# -----------------------------------------------------------------------------
# Energy calibration


def energy_calibration(channels, energies):
    """Return the linear pixel-to-energy calibration ``[slope, intercept]``.

    This is a first-degree least-squares fit through the
    ``(channels, energies)`` pairs, with coefficients ordered
    highest-degree first as ``np.polyfit`` returns them (so they can be
    used with ``np.polyval``).
    """
    slope_and_intercept = np.polyfit(channels, energies, 1)
    return slope_and_intercept


def calibrate(x, y=None, factor=None, range_=None):
    """Convert pixel positions to energies and optionally crop to a range.

    Parameters
    ----------
    x : array
        Pixel positions, or already calibrated values when ``factor`` is
        ``None``.
    y : array, optional
        Values sampled at ``x``. Cropped together with ``x`` when
        ``range_`` is given.
    factor : array, optional
        Polynomial coefficients, highest degree first (as returned by
        ``np.polyfit`` / :func:`energy_calibration`), applied with
        ``np.polyval``.
    range_ : (float, float), optional
        Limits of the (calibrated) ``x`` to keep. Each limit is snapped to
        the nearest sample and the upper index is exclusive. Only applied
        when ``y`` is given.

    Returns
    -------
    (x, y)
    """
    if factor is not None:
        x = np.polyval(factor, x)

    if y is not None and range_ is not None:
        # Indices of the samples closest to the two requested limits.
        i_first = np.argmin(np.abs(x - range_[0]))
        i_second = np.argmin(np.abs(x - range_[1]))
        # With a negative dispersion (decreasing x) the limits come out in
        # reverse index order. Sorting keeps that case unchanged and also
        # gives a non-empty slice for increasing x (previously empty).
        start, stop = sorted((i_first, i_second))
        x, y = x[start:stop], y[start:stop]

    return x, y


# -----------------------------------------------------------------------------
# Gaussian-related functions


# Ratio FWHM / sigma of a Gaussian: 2 * sqrt(2 ln 2) ~= 2.3548.
FWHM_COEFF = 2.0 * np.sqrt(2.0 * np.log(2.0))


def gaussian_fit(x_data, y_data, offset=0):
    """Fit ``gauss1d`` (a Gaussian on a constant background) to the data.

    The starting guess is built from the data itself:

    - height: the data maximum;
    - centre: the intensity-weighted mean of ``x_data``, shifted by ``offset``;
    - width: the intensity-weighted standard deviation;
    - background: the data minimum.

    Returns the optimised ``(height, x0, sigma, offset)``. If
    ``curve_fit`` raises ``RuntimeError`` or ``TypeError``, the error is
    printed and ``None`` is returned.
    """
    centre = np.average(x_data, weights=y_data)
    spread = np.sqrt(np.average((x_data - centre) ** 2, weights=y_data))

    initial = (y_data.max(), centre + offset, spread, y_data.min())
    try:
        popt, _ = curve_fit(gauss1d, x_data, y_data, initial, maxfev=10000)
    except (RuntimeError, TypeError) as err:
        print(err)
        return None
    return popt


def gauss1d(x, height, x0, sigma, offset):
    """Evaluate ``height * exp(-(x - x0)**2 / (2 * sigma**2)) + offset``."""
    z = (x - x0) / sigma
    return height * np.exp(-0.5 * z ** 2) + offset


def to_fwhm(sigma):
    """Convert a Gaussian ``sigma`` to its (non-negative) full width at half maximum."""
    fwhm = FWHM_COEFF * sigma
    return abs(fwhm)


def decentroid(res):
    """Rebuild a hit-count image from a list of centroided hit positions.

    Parameters
    ----------
    res : sequence of (float, float)
        Sub-pixel hit positions. The first coordinate indexes the first
        axis of the returned image and the second coordinate the second
        axis. Rows containing NaN are ignored, e.g. the NaN fill of the
        ``*_hits_list`` arrays that ``MaranaX.centroid_from_run`` allocates
        with ``fill=np.nan``.

    Returns
    -------
    numpy.ndarray
        Float image whose shape is ``int(max + 1)`` per axis, taken over the
        finite positions. Each hit adds one to pixel
        ``(int(first), int(second))``. Hits with a coordinate ``<= 0`` are
        not counted. A ``(0, 0)`` array is returned when there is no finite
        position; the previous version raised here, or produced a garbage
        shape when NaN padding was present.
    """
    res = np.asarray(res, dtype=float).reshape(-1, 2)
    finite = np.isfinite(res).all(axis=1)
    if not finite.any():
        return np.zeros((0, 0))

    shape = tuple(int(n) for n in np.max(res[finite], axis=0) + 1)
    ret = np.zeros(shape)

    first, second = res[:, 0], res[:, 1]
    keep = finite & (first > 0) & (second > 0)
    # np.add.at accumulates repeated indices (a plain fancy-index `+=` would not).
    np.add.at(ret, (first[keep].astype(int), second[keep].astype(int)), 1)
    return ret


class hRIXS:
    """The hRIXS analysis, especially curvature correction

    The objects of this class contain the meta-information about the
    settings of the spectrometer, not the actual data, except possibly a
    dark image for background subtraction.

    The actual data is loaded into `xarray`s, and stays there.

    Attributes
    ----------
    PROPOSAL: int
        the number of the proposal
    DETECTOR: str
        the detector to be used. Can be ['hRIXS_det', 'MaranaX'];
        defaults to 'MaranaX'.
    X_RANGE: slice
        the slice to take in the dispersive direction, in pixels. Defaults
        to the entire width.
    Y_RANGE: slice
        the slice to take in the energy direction. Defaults to the entire
        width; open ends are treated as 0 / the full image width.
    THRESHOLD: float
        pixel counts above which a hit candidate is assumed, for
        centroiding with `centroid_one`. Use None if you want to give it
        in standard deviations instead.
    STD_THRESHOLD:
        same as THRESHOLD, in standard deviations.
    DBL_THRESHOLD:
        factor of the hit threshold above which `centroid_one` counts a hit
        as a double hit.
    CENTROID_THRESHOLD: [float, float]
        lower and upper thresholds used by `centroid_two`. Both are
        multiplied by the estimated signal of one photon,
        ``energy / 3.6 / 1.06``.
    BINS: int
        the number of bins used in centroiding
    CURVE_A, CURVE_B: float
        linear and quadratic coefficients of the parabola for the curvature
        correction
    USE_DARK: bool
        whether to do dark subtraction (used by `integrate`). Is initially
        `False`, switches to `True` when a dark is loaded, but may be reset.
    AVOID_DBL: bool
        if True, the double-hit threshold is raised to 100000, so that in
        practice no hit is classified as double.
    ENERGY_INTERCEPT, ENERGY_SLOPE:
        The calibration from pixel to energy
    FIELDS:
        the fields to be loaded from the data. Add additional fields if so
        desired (each instance owns its own list).

    Example
    -------
    proposal = 3145
    h = hRIXS(proposal)
    h.Y_RANGE = slice(700, 900)
    h.CURVE_B = -3.695346575286939e-07
    h.CURVE_A = 0.024084479232443695
    h.ENERGY_SLOPE = 0.018387
    h.ENERGY_INTERCEPT = 498.27
    h.STD_THRESHOLD = 3.5
    """

    # Mnemonics loaded by `from_run` for each supported detector.
    DETECTOR_FIELDS = {
        'hRIXS_det': ['hRIXS_det', 'hRIXS_index', 'hRIXS_delay',
                      'hRIXS_norm', 'nrj'],
        'MaranaX': ['MaranaX', 'nrj'],
    }

    def __init__(self, proposalNB, detector='MaranaX'):
        self.PROPOSAL = proposalNB
        self.DETECTOR = detector
        assert detector in self.DETECTOR_FIELDS

        # image range
        self.X_RANGE = np.s_[:]
        self.Y_RANGE = np.s_[:]

        # centroid
        # centroid_one threshold
        self.THRESHOLD = None  # pixel counts above which a hit candidate is assumed
        self.STD_THRESHOLD = 3.5  # same as THRESHOLD, in standard deviations
        self.DBL_THRESHOLD = 0.1  # factor used for double hits in centroid_one
        # centroid_two threshold
        self.CENTROID_THRESHOLD = [0.2, 1]

        self.CURVE_A = 0  # curvature parameters as determined elsewhere
        self.CURVE_B = 0

        # integral
        self.BINS = 100
        # ['centroid', 'integral']; not read by the methods in this module
        self.METHOD = 'centroid'
        self.USE_DARK = False

        # Ignore double hits
        self.AVOID_DBL = False

        self.ENERGY_INTERCEPT = 0
        self.ENERGY_SLOPE = 1

        # Copy, so that appending to FIELDS does not mutate the shared
        # class-level DETECTOR_FIELDS (seen by every instance).
        self.FIELDS = list(self.DETECTOR_FIELDS[detector])

    def set_params(self, **params):
        """Set attributes from keyword arguments; names are upper-cased."""
        for key, value in params.items():
            setattr(self, key.upper(), value)

    def get_params(self, *params):
        """Return ``{name: value}`` for the given (case-insensitive) names.

        Without arguments, a default selection of settings is returned.
        """
        if not params:
            params = ('proposal', 'x_range', 'y_range', 'threshold',
                      'std_threshold', 'dbl_threshold', 'curve_a',
                      'curve_b', 'bins', 'method', 'fields')
        return {param: getattr(self, param.upper()) for param in params}

    def from_run(self, runNB, proposal=None, extra_fields=(),
                 drop_first=False, subset=None):
        """load a run

        Load the run `runNB`. A thin wrapper around `toolbox.load`.

        Parameters
        ----------
        runNB: int
            the run number
        proposal: int, optional
            the proposal; defaults to `self.PROPOSAL`
        extra_fields: sequence of str
            mnemonics loaded in addition to `self.FIELDS`
        drop_first: bool
            if True, the first image in the run is removed from the
            dataset (this replaces any `subset`).
        subset: optional
            passed through to `toolbox.load`

        Example
        -------
        data = h.from_run(145)  # load run 145

        data1 = h.from_run(145)  # load run 145
        data2 = h.from_run(155)  # load run 155
        data = xarray.concat([data1, data2], 'trainId')  # combine both
        """
        if proposal is None:
            proposal = self.PROPOSAL

        if drop_first:
            subset = slice(1, None)

        run, data = tb.load(proposal, runNB=runNB, subset=subset,
                            fields=self.FIELDS + list(extra_fields))
        return data

    def load_dark(self, runNB, proposal=None):
        """load a dark run

        Load the dark run `runNB` from `proposal`. The latter defaults to
        the current proposal. The train-averaged dark is stored in this
        `hRIXS` object and `USE_DARK` is switched on, so that `integrate`
        subtracts it.

        Example
        -------
        h.load_dark(166)  # load dark run 166
        """
        data = self.from_run(runNB, proposal)
        self.dark_image = data[self.DETECTOR].mean(dim='trainId')
        self.USE_DARK = True

    def find_curvature(self, runNB, proposal=None, plot=True, args=None,
                       **kwargs):
        """find the curvature correction coefficients

        The hRIXS has some aberrations which lead to the spectroscopic lines
        being curved on the detector. We approximate these aberrations with
        a parabola for later correction.

        Load a run and determine the curvature. The curvature is set in
        `self` (`CURVE_A`, `CURVE_B`) and returned as a pair of floats.

        Parameters
        ----------
        runNB: int
            the run number to use
        proposal: int
            the proposal to use, default to the current proposal
        plot: bool
            whether to plot the found curvature onto the data
        args: sequence of float, optional
            starting guess ``(a, b, c, s, h[, o])`` for the module-level
            `find_curvature` fit. Estimated from the data when omitted.
        kwargs:
            forwarded to the module-level `find_curvature` (plotting)

        Example
        -------
        h.find_curvature(155)  # use run 155 to fit the curvature
        """
        data = self.from_run(runNB, proposal)

        image = data[self.DETECTOR].sum(dim='trainId') \
            .values[self.X_RANGE, self.Y_RANGE].T

        if args is None:
            spec = (image - image[:10, :].mean()).mean(axis=1)
            mean = np.average(np.arange(len(spec)), weights=spec)
            args = (-2e-7, 0.02, mean - 0.02 * image.shape[1] / 2, 3,
                    spec.max(), image.mean())
        # Module-level function (not this method).
        args = find_curvature(image, args, plot=plot, **kwargs)
        # Fit parameters are (quadratic, linear, constant, ...).
        self.CURVE_B, self.CURVE_A, *_ = args
        return self.CURVE_A, self.CURVE_B

    def centroid_one(self, image):
        """find the position of photons with sub-pixel precision

        A photon is supposed to have hit the detector if the intensity
        within a 2-by-2 square exceeds a threshold and is a local maximum.
        In this case the position of the photon is calculated as the
        center-of-mass in a 4-by-4 square.

        Returns ``(hits, double_hits)``: lists of ``(x, y)`` positions,
        i.e. along the second and the first axis of `image`. A double hit
        is appended twice to `hits` and once to `double_hits`. The positions
        are NOT curvature corrected; `centroid` applies that correction.
        """
        base = image.mean()
        # Sum of every 2x2 block of pixels.
        corners = image[1:, 1:] + image[:-1, 1:] \
            + image[1:, :-1] + image[:-1, :-1]
        if self.THRESHOLD is None:
            threshold = corners.mean() + self.STD_THRESHOLD * corners.std()
        else:
            threshold = self.THRESHOLD

        # Threshold for double photons chosen to be the same ratio to single
        # photons as found in the ESRF method.
        # NOTE(review): with the default DBL_THRESHOLD = 0.1 this is below
        # the detection threshold itself, so most hits may count as double
        # -- confirm the intended value.
        SpotHIGH = self.DBL_THRESHOLD * threshold

        if self.AVOID_DBL:
            SpotHIGH = 100000

        middle = corners[1:-1, 1:-1]
        candidates = (
            (middle > threshold)
            * (middle >= corners[:-2, 1:-1])
            * (middle > corners[2:, 1:-1])
            * (middle >= corners[1:-1, :-2])
            * (middle > corners[1:-1, 2:])
            * (middle >= corners[:-2, :-2])
            * (middle > corners[2:, :-2])
            * (middle >= corners[:-2, 2:])
            * (middle > corners[2:, 2:]))
        cp = np.argwhere(candidates)
        if len(cp) > 10000:
            raise RuntimeError(
                "too many peaks, threshold low or acquisition time too high")

        res = []
        dres = []
        for cy, cx in cp:
            spot = image[cy: cy + 4, cx: cx + 4] - base
            mx = np.average(np.arange(cx, cx + 4), weights=spot.sum(axis=0))
            my = np.average(np.arange(cy, cy + 4), weights=spot.sum(axis=1))
            if spot.sum() < SpotHIGH:
                res.append((mx, my))
            else:
                res.append((mx, my))
                res.append((mx, my))
                dres.append((mx, my))
        return res, dres

    def centroid_two(self, image, energy):
        """determine position of photon hits on detector

        The algorithm is taken from the ESRF RIXS toolbox. The thresholds
        for determining photon hits are given by the incident photon
        energy: the signal of one photon is estimated as
        ``energy / 3.6 / 1.06`` and scaled by `CENTROID_THRESHOLD`.

        Returns ``(hits, double_hits)``: lists of ``(x, y)`` positions,
        i.e. along the second and the first axis of `image`. A double hit
        is appended twice to `hits` and once to `double_hits`.
        """
        # Prepare
        threshold = self.CENTROID_THRESHOLD
        # Multiplication factor * ADU/photon
        photons = energy/3.6/1.06
        SpotLOW = threshold[0] * photons
        SpotHIGH = threshold[1] * photons
        low_th_px = threshold[0] * photons
        high_th_px = threshold[1] * photons

        if self.AVOID_DBL:
            SpotHIGH = 100000

        img = image-image.mean()
        gs = 2
        # Find potential hits on a clipped image where one row/column at
        # each edge is excluded, because centroiding can't be done at the
        # edge.
        inner = img[gs//2:-gs//2, gs//2:-gs//2]
        cp = (np.argwhere((inner > low_th_px) * (inner < high_th_px))
              + np.array([gs//2, gs//2]))
        if len(cp) > 100000:
            raise RuntimeError('Threshold to low or acquisition time to long')

        res = []
        dres = []
        for cy, cx in cp:
            spot = img[cy - gs//2:cy + gs//2 + 1, cx - gs//2:cx + gs//2 + 1]
            # Keep only local maxima of the 3x3 neighbourhood.
            if (spot > img[cy, cx]).sum() == 0:
                mx = np.average(np.arange(cx - gs//2, cx + gs//2 + 1),
                                weights=spot.sum(axis=0))
                my = np.average(np.arange(cy - gs//2, cy + gs//2 + 1),
                                weights=spot.sum(axis=1))
                if (spot.sum() >= SpotLOW) and (spot.sum() < SpotHIGH):
                    res.append((mx, my))
                elif (spot.sum() > SpotHIGH):
                    res.append((mx, my))
                    res.append((mx, my))
                    dres.append((mx, my))
        return res, dres

    def centroid(self, data, bins=None, method='auto'):
        """calculate a spectrum by finding the centroid of individual photons

        This takes the `xarray.Dataset` `data` and returns it with the
        `xarray.DataArray`s `spectrum`, `dbl_spectrum`, `total_hits` and
        `double_hits` and the coordinate `energy` added (the dataset is
        modified in place).

        `method` selects how photon hits are found:

        - "auto": `centroid_two`, with thresholds derived from the photon
          energy in `data['nrj']`;
        - "manual": `centroid_one`, with `THRESHOLD` / `STD_THRESHOLD`.

        Raises
        ------
        ValueError
            if `method` is neither "auto" nor "manual".

        Example
        -------
        h.centroid(data)  # find photons in all images of the run
        data.spectrum[0, :].plot()  # plot the spectrum of the first image
        """
        if method not in ('manual', 'auto'):
            # Previously an unknown method failed later with an
            # UnboundLocalError.
            raise ValueError(
                f"method must be 'manual' or 'auto', got {method!r}")
        if bins is None:
            bins = self.BINS

        images = data[self.DETECTOR]
        n_images = len(images)
        ret = np.zeros((n_images, bins))
        retd = np.zeros((n_images, bins))
        total_hits = np.zeros(n_images)
        dbl_hits = np.zeros(n_images)

        # Pixel limits of Y_RANGE. Open ends default to the full image
        # width (the default np.s_[:] used to raise a TypeError).
        y_start = self.Y_RANGE.start or 0
        y_stop = self.Y_RANGE.stop or images.shape[-1]
        # Hit positions are relative to the cropped image.
        hist_range = (0, y_stop - y_start)

        for i, (image, r, rd) in enumerate(zip(images, ret, retd)):
            cropped = image.values[self.X_RANGE, self.Y_RANGE]
            if method == 'manual':
                c, d = self.centroid_one(cropped)
            else:
                energy = data['nrj'][i].data
                c, d = self.centroid_two(cropped, energy)

            if not len(c):
                continue
            rc = np.array(c)
            r[:], _ = np.histogram(
                rc[:, 0] - self.parabola(rc[:, 1]),
                bins=bins, range=hist_range)
            total_hits[i] = rc.shape[0]
            rcd = np.array(d)
            dbl_hits[i] = rcd.shape[0]
            # Account for case where no double hits are found
            if rcd.shape[0] == 0:
                continue
            rd[:], _ = np.histogram(
                rcd[:, 0] - self.parabola(rcd[:, 1]),
                bins=bins, range=hist_range)

        data.coords["energy"] = (
            np.linspace(y_start, y_stop, bins)
            * self.ENERGY_SLOPE + self.ENERGY_INTERCEPT)
        data['spectrum'] = (("trainId", "energy"), ret)
        data['dbl_spectrum'] = (("trainId", "energy"), retd)
        data['total_hits'] = ("trainId", total_hits)
        data['double_hits'] = ("trainId", dbl_hits)
        return data

    def parabola(self, x):
        """Curvature shift ``(CURVE_B * x + CURVE_A) * x`` at position `x`."""
        return (self.CURVE_B * x + self.CURVE_A) * x

    def integrate(self, data):
        """calculate a spectrum by integration

        This takes the `xarray` `data` and returns it with a new dataarray
        named `spectrum` added (in place), which contains the energy
        spectrum calculated for each hRIXS image.

        First the energy that corresponds to each pixel is calculated.
        Then all pixels within an energy range are summed, where the
        intensity of one pixel is distributed among the two energy ranges
        the pixel spans, proportionally to the overlap between the pixel
        and bin energy ranges.

        The resulting data is normalized to one pixel, so the average
        intensity that arrived on one pixel. Bins that receive no pixel at
        all end up as NaN (0/0). `margin` = 10 bins are dropped at both
        ends of Y_RANGE.

        Example
        -------
        h.integrate(data)  # create spectrum by summing pixels
        data.spectrum[0, :].plot()  # plot the spectrum of the first image
        """
        detector = data[self.DETECTOR]
        # Open ends of Y_RANGE default to the full image width.
        y_start = self.Y_RANGE.start or 0
        y_stop = self.Y_RANGE.stop or detector.shape[-1]
        bins = y_stop - y_start
        margin = 10
        n_out = bins - 2 * margin
        ret = np.zeros((len(detector), n_out))
        if self.USE_DARK:
            dark_image = self.dark_image.values[self.X_RANGE, self.Y_RANGE]
        images = detector.values[:, self.X_RANGE, self.Y_RANGE]

        x, y = np.ogrid[:images.shape[1], :images.shape[2]]
        # Curvature-corrected position p = quo + rem of the lower pixel edge.
        quo, rem = divmod(y - self.parabola(x), 1)
        # Pixel [p, p + 1) overlaps bin quo by (1 - rem) and bin quo + 1 by
        # rem. The original used the weights the other way round, which put
        # each pixel one bin off and made the mapping jump by two bins as
        # rem crossed an integer.
        quo = np.array([quo, quo + 1])
        rem = np.array([1 - rem, rem])
        wrong = (quo < margin) | (quo >= bins - margin)
        quo[wrong] = margin
        rem[wrong] = 0
        quo = (quo - margin).astype(int).ravel()

        for image, r in zip(images, ret):
            if self.USE_DARK:
                image = image - dark_image
            # minlength: the highest bins may receive no pixel at all.
            r[:] = np.bincount(quo, weights=(rem * image).ravel(),
                               minlength=n_out)
        ret /= np.bincount(quo, weights=rem.ravel(), minlength=n_out)
        data.coords["energy"] = (
            np.arange(y_start + margin, y_stop - margin)
            * self.ENERGY_SLOPE + self.ENERGY_INTERCEPT)
        data['spectrum'] = (("trainId", "energy"), ret)
        return data

    # How each variable is combined by `aggregate`. Variables not listed
    # here are dropped from the aggregated dataset.
    aggregators = dict(
        hRIXS_det=lambda x, dim: x.sum(dim=dim),
        MaranaX=lambda x, dim: x.sum(dim=dim),
        Delay=lambda x, dim: x.mean(dim=dim),
        hRIXS_delay=lambda x, dim: x.mean(dim=dim),
        hRIXS_norm=lambda x, dim: x.sum(dim=dim),
        spectrum=lambda x, dim: x.sum(dim=dim),
        dbl_spectrum=lambda x, dim: x.sum(dim=dim),
        total_hits=lambda x, dim: x.sum(dim=dim),
        # `centroid` stores the double-hit count as 'double_hits'; the old
        # 'dbl_hits' key is kept for backward compatibility.
        double_hits=lambda x, dim: x.sum(dim=dim),
        dbl_hits=lambda x, dim: x.sum(dim=dim),
        counts=lambda x, dim: x.sum(dim=dim)
    )

    def aggregator(self, da, dim):
        """Aggregate one DataArray; None if its name has no aggregator."""
        agg = self.aggregators.get(da.name)
        if agg is None:
            return None
        return agg(da, dim=dim)

    def aggregate(self, ds, var=None, dim="trainId"):
        """aggregate (i.e. mostly sum) all data within one dataset

        Take all images in a dataset and aggregate them and their metadata.
        For images, spectra and normalizations that means adding them; for
        others (e.g. delays) adding would not make sense, so we treat them
        properly. The aggregation functions of each variable are defined
        in the aggregators attribute of the class.
        If var is specified, group the dataset by var prior to aggregation.
        A new variable "counts" gives the number of frames aggregated in
        each group (it is also added to `ds` itself).

        Parameters
        ----------
        ds: xarray Dataset
            the dataset containing RIXS data
        var: string
            One of the variables in the dataset. If var is specified, the
            dataset is grouped by var prior to aggregation. This is useful
            for sorting e.g. a dataset that contains multiple delays.
        dim: string
            the dimension over which to aggregate the data

        Example
        -------
        h.centroid(data)  # create spectra from finding photons
        agg = h.aggregate(data)  # sum all spectra
        agg.spectrum.plot()  # plot the resulting spectrum

        agg2 = h.aggregate(data, 'hRIXS_delay')  # group data by delay
        agg2.spectrum[0, :].plot()  # plot the spectrum for first value
        """
        ds["counts"] = xr.ones_like(ds[dim])
        if var is not None:
            groups = ds.groupby(var)
            return groups.map(self.aggregate_ds, dim=dim)
        return self.aggregate_ds(ds, dim)

    def aggregate_ds(self, ds, dim='trainId'):
        """Aggregate every variable of `ds` over `dim`; drop unknown ones."""
        ret = ds.map(self.aggregator, dim=dim)
        ret = ret.drop_vars([n for n in ret if n not in self.aggregators])
        return ret

    def normalize(self, data, which="hRIXS_norm"):
        """ Adds a 'normalized' variable to the dataset defined as the ratio
        between 'spectrum' and 'which'

        Parameters
        ----------
        data: xarray Dataset
            the dataset containing hRIXS data
        which: string, default="hRIXS_norm"
            one of the variables of the dataset, usually "hRIXS_norm" or
            "counts"
        """
        return data.assign(normalized=data["spectrum"] / data[which])
class MaranaX(hRIXS):
    """
    A spin-off of the hRIXS class: with parallelized centroiding

    Photon hits are found with `centroid_two`, image by image, through
    `pasha` (`psh.map`). The per-image results are written into arrays
    allocated with `psh.alloc` and kept in `self._centroid_result`.
    """

    # Number of hit positions stored per image when hit lists are requested
    # (`centroid_from_run(..., return_hits=True)`); further hits are dropped.
    NUM_MAX_HITS = 30

    # Class-body statement: runs when the class is defined, i.e. on import.
    psh.set_default_context('processes', num_workers=20)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._centroid_result = None

    def centroid(self, data, bins=None, **kwargs):
        """calculate spectra by centroiding every image of `data` in parallel

        Parallel counterpart of `hRIXS.centroid` using `centroid_two`
        (thresholds from `data['nrj']`). Adds `spectrum`, `dbl_spectrum`,
        `total_hits`, `double_hits` and the coordinate `energy` to `data`
        and returns it. Extra keyword arguments (e.g. `method`) are accepted
        for call compatibility and ignored.
        """
        # 0.a Setup constants that we might use
        if bins is None:
            bins = self.BINS
        images = data[self.DETECTOR]

        # 0.b. Allocate output arrays
        num_images, _, Nx = images.shape
        self._allocate_results(num_images, bins)

        # 0.c Attach the photon energy as a coordinate so that
        # `_centroid_tb_map` gets it together with each image.
        data_array = images.assign_coords(nrj=data['nrj'])

        # 1. Calculate with parallelization
        psh.map(self._centroid_tb_map, data_array)

        # 2. Finalize: set the results back to the dataset
        return self._store_results(data, bins, Nx)

    def _allocate_results(self, num_images, bins, return_hits=False):
        """Allocate (via `psh.alloc`) the output arrays the workers fill in."""
        self._centroid_result = {
            'total_hist': psh.alloc(shape=(num_images, bins)),
            'double_hist': psh.alloc(shape=(num_images, bins)),
            'total_hits': psh.alloc(shape=(num_images,)),
            'double_hits': psh.alloc(shape=(num_images,)),
        }
        if return_hits:
            # NaN marks unused slots (images with < NUM_MAX_HITS hits).
            self._centroid_result['total_hits_list'] = psh.alloc(
                shape=(num_images, self.NUM_MAX_HITS, 2), fill=np.nan)
            self._centroid_result['double_hits_list'] = psh.alloc(
                shape=(num_images, self.NUM_MAX_HITS, 2), fill=np.nan)

    def _store_results(self, data, bins, Nx):
        """Attach the energy axis and the spectra/hit counts to `data`."""
        data.coords["energy"] = (
            np.linspace(self.Y_RANGE.start or 0, self.Y_RANGE.stop or Nx, bins)
            * self.ENERGY_SLOPE + self.ENERGY_INTERCEPT)
        result = self._centroid_result
        data['spectrum'] = (("trainId", "energy"), result['total_hist'])
        data['dbl_spectrum'] = (("trainId", "energy"), result['double_hist'])
        data['total_hits'] = ("trainId", result['total_hits'])
        data['double_hits'] = ("trainId", result['double_hits'])
        return data

    def _centroid_tb_map(self, _, index, data):
        """pasha kernel for `centroid`: one image with its `nrj` coordinate."""
        self._centroid_map(index, image=data.values,
                           energy=data.coords['nrj'].item())

    def _centroid_map(self, index, *, image, energy):
        """Centroid one full image and histogram its hits into slot `index`."""
        total, double = self._centroid_task(index, image, energy)

        # Check if there are results: don't do anything otherwise
        if not len(total):
            return

        self._histogram_task(index, total, double,
                             default_range=(0, image.shape[1]))

    def _centroid_task(self, index, image, energy):
        """Run `centroid_two` on the cropped image.

        Returns the hit arrays ``(total, double)`` of ``(x, y)`` positions
        relative to the X_RANGE/Y_RANGE crop. If hit lists were allocated,
        up to NUM_MAX_HITS positions are also stored for this image.
        """
        total, double = self.centroid_two(image[self.X_RANGE, self.Y_RANGE],
                                          energy)
        total, double = np.array(total), np.array(double)

        total_hits_list = self._centroid_result.get('total_hits_list')
        if total_hits_list is not None and len(total):
            total_hits_list[index][:min(len(total), self.NUM_MAX_HITS)] \
                = total[:self.NUM_MAX_HITS]

        double_hits_list = self._centroid_result.get('double_hits_list')
        if double_hits_list is not None and len(double):
            double_hits_list[index][:min(len(double), self.NUM_MAX_HITS)] \
                = double[:self.NUM_MAX_HITS]

        return total, double

    def _histogram_task(self, index, total, double, default_range):
        """Histogram the curvature-corrected hits of image `index`.

        The histogram spans Y_RANGE in full-image pixels (open ends from
        `default_range`), matching the `energy` coordinate.
        """
        _, bins = self._centroid_result['total_hist'].shape
        hist_range = (self.Y_RANGE.start or default_range[0],
                      self.Y_RANGE.stop or default_range[1])
        # Hit positions are relative to the Y_RANGE-cropped image; shift
        # them back to full-image pixels so they fall inside `hist_range`.
        # Previously every hit was lost whenever Y_RANGE.start > 0.
        y_offset = self.Y_RANGE.start or 0

        # Calculate total spectrum
        self._centroid_result['total_hist'][index], _ = np.histogram(
            total[:, 0] + y_offset - self.parabola(total[:, 1]),
            bins=bins, range=hist_range)
        self._centroid_result['total_hits'][index] = len(total)

        # Calculate double spectrum
        double_hits = len(double)
        if double_hits:
            self._centroid_result['double_hist'][index], _ = np.histogram(
                double[:, 0] + y_offset - self.parabola(double[:, 1]),
                bins=bins, range=hist_range)
        self._centroid_result['double_hits'][index] = double_hits

    def centroid_from_run(self, runNB, proposal=None, extra_fields=(),
                          drop_first=False, subset=None, bins=None,
                          return_hits=False):
        """
        A combined function of `from_run()` and `centroid()`, which uses
        `extra_data` and `pasha` to avoid bulk loading of files.

        The MaranaX images are only read inside the `pasha` workers. The
        other mnemonics (`nrj` and `extra_fields` present in the run) are
        merged into the returned dataset. With `return_hits`, up to
        NUM_MAX_HITS hit positions per train are added as
        `total_hits_list` / `double_hits_list` (NaN-padded). The analysis
        settings are recorded in `data.attrs`.
        """
        # 0.a Setup constants that we might use
        if proposal is None:
            proposal = self.PROPOSAL
        if drop_first:
            subset = slice(1, None)
        if bins is None:
            bins = self.BINS
        required = set(self.DETECTOR_FIELDS['MaranaX'])
        run_mnemo = required | set(extra_fields)

        # 0.b. Open the run with extra-data and select the relevant fields
        run = open_run(proposal, runNB)
        if subset is not None:
            run = run.select_trains(subset)

        # Filter out mnemonics that do not exist in the run files
        run_mnemo = {mnemo for mnemo in run_mnemo
                     if self._is_mnemo_in_run(mnemo, run)}
        assert required.issubset(run_mnemo)
        sources = [self._mnemo_to_prop(mnemo)[0] for mnemo in run_mnemo]
        selection = run.select(
            [source for source in sources if source in run.all_sources],
            require_all=True)

        # 0.c. Allocate output arrays
        mara_source, mara_key = self._mnemo_to_prop('MaranaX')
        num_images, _, Nx = selection[mara_source, mara_key].shape
        self._allocate_results(num_images, bins, return_hits=return_hits)

        # 1. Calculate with parallelization
        psh.map(self._centroid_ed_map, selection)

        # 2. Finalize: load the other mnemonics and attach the results.
        data = xr.merge(
            [selection.get_array(*self._mnemo_to_prop(mnemo), name=mnemo)
             for mnemo in run_mnemo - {'MaranaX'}],
            join='inner')
        self._store_results(data, bins, Nx)
        if return_hits:
            data['total_hits_list'] = (
                ('trainId', 'hits', 'coord'),
                self._centroid_result['total_hits_list'])
            data['double_hits_list'] = (
                ('trainId', 'hits', 'coord'),
                self._centroid_result['double_hits_list'])

        # Add attributes
        data.attrs.update(
            CENTROID_THRESHOLD=self.CENTROID_THRESHOLD,
            NUM_BINS=bins,
            CURVATURE_CORRECTION=[self.CURVE_A, self.CURVE_B],
            ENERGY_CALIBRATION=[self.ENERGY_SLOPE, self.ENERGY_INTERCEPT],
            Y_RANGE=[self.Y_RANGE.start or 0, self.Y_RANGE.stop or Nx]
        )

        return data

    def _centroid_ed_map(self, _, index, trainId, data):
        """pasha kernel for `centroid_from_run`: one train of extra-data."""
        mara_source, mara_key = self._mnemo_to_prop('MaranaX')
        nrj_source, nrj_key = self._mnemo_to_prop('nrj')
        self._centroid_map(
            index,
            image=data[mara_source][mara_key],
            energy=data[nrj_source][nrj_key])

    @staticmethod
    def _mnemo_to_prop(mnemo):
        """Return the ``(source, key)`` of a toolbox mnemonic."""
        prop = tb.constants.mnemonics[mnemo][0]
        return prop['source'], prop['key']

    def _is_mnemo_in_run(self, mnemo, run):
        """True if the mnemonic's source and key are present in `run`."""
        source, key = self._mnemo_to_prop(mnemo)
        if source not in run.all_sources:
            return False
        return key in run.keys_for_source(source)