Source code for toolbox_scs.routines.boz

"""
Beam splitting Off-axis Zone plate analysis routines.

Copyright (2021, 2022, 2023, 2024) SCS Team.
"""

import time
import datetime
import json
import warnings

import numpy as np
import xarray as xr
import dask.array as da
from scipy.optimize import minimize

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib import cm
from matplotlib.patches import Polygon

from extra_data import open_run
from extra_geom import DSSC_1MGeometry

from toolbox_scs.routines.XAS import xas

try:
    import cupy as cp
    _can_use_gpu = True
except ModuleNotFoundError:
    _can_use_gpu = False
    print('Cupy is not installed in this environment, no access to the GPU')
except ImportError:
    _can_use_gpu = False
    print('Not currently running on a GPU node')


# Public API of the module. inspect_ff_fitting is the replacement advertised
# by the deprecated inspect_plane_fitting, so it is exported alongside it.
__all__ = [
    'parameters',
    'get_roi_pixel_pos',
    'bad_pixel_map',
    'inspect_dark',
    'histogram_module',
    'inspect_histogram',
    'find_rois',
    'find_rois_from_params',
    'inspect_rois',
    'compute_flat_field_correction',
    'initialize_polyline_ff_correction',
    'inspect_flat_field_domain',
    'inspect_plane_fitting',
    'inspect_ff_fitting',
    'inspect_ff_fitting_sk',
    'plane_fitting_domain',
    'plane_fitting',
    'ff_refine_crit',
    'ff_refine_fit',
    'nl_domain',
    'nl_lut',
    'nl_crit',
    'nl_crit_sk',
    'nl_fit',
    'inspect_nl_fit',
    'snr',
    'inspect_Fnl',
    'inspect_correction',
    'inspect_correction_sk',
    'load_dssc_module',
    'average_module',
    'process_module',
    'process',
    'inspect_saturation'
]


class parameters():
    """Parameters contains all input parameters for the BOZ corrections.

    This is used in beam splitting off-axis zone plate spectroscopy analysis
    as well as during the determination of correction parameters themselves
    to ensure they can be reproduced.

    Inputs
    ------
    proposal: int, proposal number
    darkrun: int, run number for the dark run
    run: int, run number for the data run
    module: int, DSSC module number
    gain: float, number of ph per bin
    drop_intra_darks: drop every second DSSC frame
    """

    def __init__(self, proposal, darkrun, run, module, gain,
                 drop_intra_darks=True):
        self.proposal = proposal
        self.darkrun = darkrun
        self.run = run
        self.module = module
        self.pixel_pos = _get_pixel_pos(self.module)
        self.gain = gain
        self.drop_intra_darks = drop_intra_darks

        self.mask = None
        self.mask_idx = None
        self.mean_th = (None, None)
        self.std_th = (None, None)

        self.rois = None
        self.rois_th = None

        self.ff_type = 'plane'
        self.flat_field = None
        # np.PINF / np.NINF no longer exist in NumPy >= 2.0; np.inf is the
        # portable spelling of the same values
        self.flat_field_prod_th = (5.0, np.inf)
        self.flat_field_ratio_th = (-np.inf, 1.2)
        self.plane_guess_fit = None
        self.use_hex = False
        self.force_mirror = True
        self.ff_alpha = None
        self.ff_max_iter = None

        self._using_gpu = False

        self.Fnl = None
        self.nl_alpha = None
        self.sat_level = None
        self.nl_max_iter = None

        # temporary data
        self.arr_dark = None
        self.tid_dark = None
        self.arr = None
        self.tid = None

    def dask_load_persistently(self, dark_data_size_Gb=None,
                               data_size_Gb=None):
        """Load dask data array in memory.

        Inputs
        ------
        dark_data_size_Gb: float, optional size of dark to load in memory,
            in Gb
        data_size_Gb: float, optional size of data to load in memory, in Gb
        """
        self.arr_dark, self.tid_dark = load_dssc_module(
            self.proposal, self.darkrun, self.module,
            drop_intra_darks=self.drop_intra_darks, persist=True,
            data_size_Gb=dark_data_size_Gb)
        self.arr, self.tid = load_dssc_module(
            self.proposal, self.run, self.module,
            drop_intra_darks=self.drop_intra_darks, persist=True,
            data_size_Gb=data_size_Gb)

        # make sure to rechunk the arrays
        self.arr = self.arr.rechunk(('auto', -1, -1, -1))
        self.arr_dark = self.arr_dark.rechunk(('auto', -1, -1, -1))

    def use_gpu(self):
        """Move the bad pixel mask and the loaded data arrays to the GPU.

        The data is computed, copied to the GPU with CuPy and wrapped back
        into dask arrays chunked along the first axis (blocks of at most
        2**30 bytes). Calling it again once on the GPU only repeats the
        memory check.
        """
        assert _can_use_gpu, 'Failed to import cupy'
        gpu_mem_gb = cp.cuda.Device().mem_info[1] / 2**30
        if gpu_mem_gb < 30:
            print(f'Warning: GPU memory ({gpu_mem_gb}GB) may be insufficient')
        if self._using_gpu:
            return
        assert (
            self.arr is not None and
            self.arr_dark is not None
        ), "Must load data before switching to GPU"
        if self.mask is not None:
            self.mask = cp.array(self.mask)

        # moving full data to GPU
        limit = 2**30
        self.arr = da.array(
            cp.array(self.arr.compute())
        ).rechunk(('auto', -1, -1, -1), block_size_limit=limit)
        self.arr_dark = da.array(
            cp.array(self.arr_dark.compute())
        ).rechunk(('auto', -1, -1, -1), block_size_limit=limit)

        self._using_gpu = True

    def set_mask(self, arr):
        """Set mask of bad pixels.

        Inputs
        ------
        arr: either a boolean array of a DSSC module image (False for bad
            pixels), a list of bad pixel [row, column] indices, or None to
            clear the mask
        """
        if arr is None:
            # e.g. parameters saved before any mask was defined
            self.mask_idx = None
            self.mask = None
            return

        if isinstance(arr, list):
            self.mask_idx = arr
            mask = np.ones((128, 512), dtype=bool)
            for k in self.mask_idx:
                mask[k[0], k[1]] = False
            self.mask = mask
        else:
            self.mask_idx = np.argwhere(arr == False).tolist()  # noqa: E712
            self.mask = arr

        if self._using_gpu:
            self.mask = cp.array(self.mask)

    def get_mask(self):
        """Get the boolean array bad pixel of a DSSC module."""
        return self.mask

    def get_mask_idx(self):
        """Get the list of bad pixel indices."""
        return self.mask_idx

    def flat_field_guess(self, guess=None):
        """Set the flat-field guess parameter for the fit and returns it.

        Inputs
        ------
        guess: a list of 8 floats, the 4 first to define the plane
            ax+by+cz+d=0 for 'n' beam and the 4 last for the 'p' beam
            in case mirror symmetry is disabled
        """
        if guess is not None:
            self.plane_guess_fit = guess
            return self.plane_guess_fit

        if self.plane_guess_fit is None:
            if self.use_hex:
                self.plane_guess_fit = [
                    -20, 0.0, 1.5, -0.5, 20, 0, 1.5, -0.5]
            else:
                self.plane_guess_fit = [
                    -0.2, -0.1, 1, -0.54, 0.2, -0.1, 1, -0.54]

        return self.plane_guess_fit

    def set_flat_field(self, ff_params, ff_type='plane',
                       prod_th=None, ratio_th=None):
        """Set the flat-field plane definition.

        Inputs
        ------
        ff_params: list or array of parameters, or None to clear it
        ff_type: string identifying the type of flat field normalization,
            default is 'plane'.
        prod_th, ratio_th: optional (low, high) thresholds, kept unchanged
            when None
        """
        self.ff_type = ff_type
        if ff_params is None or isinstance(ff_params, list):
            self.flat_field = ff_params
        elif hasattr(ff_params, 'tolist'):
            self.flat_field = ff_params.tolist()
        else:
            self.flat_field = list(ff_params)
        if prod_th is not None:
            self.flat_field_prod_th = prod_th
        if ratio_th is not None:
            self.flat_field_ratio_th = ratio_th

    def get_flat_field(self):
        """Get the flat-field plane definition."""
        if self.flat_field is None:
            return None
        else:
            return np.array(self.flat_field)

    def set_Fnl(self, Fnl):
        """Set the non-linear correction function (None to clear it)."""
        if Fnl is None or isinstance(Fnl, list):
            self.Fnl = Fnl
        elif hasattr(Fnl, 'tolist'):
            self.Fnl = Fnl.tolist()
        else:
            self.Fnl = list(Fnl)

    def get_Fnl(self):
        """Get the non-linear correction function."""
        if self.Fnl is None:
            return None
        else:
            if self._using_gpu:
                return cp.array(self.Fnl)
            else:
                return np.array(self.Fnl)

    def save(self, path='./'):
        """Save the parameters as a JSON file.

        Inputs
        ------
        path: str, directory where to save the file, default to './'
        """
        import os

        v = {}
        v['proposal'] = self.proposal
        v['darkrun'] = self.darkrun
        v['run'] = self.run
        v['module'] = self.module
        v['gain'] = self.gain
        v['drop_intra_darks'] = self.drop_intra_darks

        v['mask'] = self.mask_idx
        v['mean_th'] = self.mean_th
        v['std_th'] = self.std_th

        v['rois'] = self.rois
        v['rois_th'] = self.rois_th

        v['ff_type'] = self.ff_type
        v['flat_field'] = self.flat_field
        v['flat_field_prod_th'] = self.flat_field_prod_th
        v['flat_field_ratio_th'] = self.flat_field_ratio_th
        v['plane_guess_fit'] = self.plane_guess_fit
        v['use_hex'] = self.use_hex
        v['force_mirror'] = self.force_mirror
        v['ff_alpha'] = self.ff_alpha
        v['ff_max_iter'] = self.ff_max_iter

        v['Fnl'] = self.Fnl
        v['nl_alpha'] = self.nl_alpha
        v['sat_level'] = self.sat_level
        v['nl_max_iter'] = self.nl_max_iter

        fname = f'parameters_p{self.proposal}_d{self.darkrun}_r{self.run}.json'
        # join so that a directory given without trailing separator works
        fpath = os.path.join(path, fname)
        with open(fpath, 'w') as f:
            json.dump(v, f)
        print(fpath)

    @classmethod
    def load(cls, fname):
        """Load parameters from a JSON file.

        Inputs
        ------
        fname: string, name a the JSON file to load
        """
        with open(fname, 'r') as f:
            v = json.load(f)
        c = cls(v['proposal'], v['darkrun'], v['run'], v['module'], v['gain'],
                v['drop_intra_darks'])

        c.mean_th = v['mean_th']
        c.std_th = v['std_th']
        c.set_mask(v['mask'])

        c.rois = v['rois']
        c.rois_th = v['rois_th']

        # files written before 'ff_type' existed used plane flat fields
        c.set_flat_field(v['flat_field'], v.get('ff_type', 'plane'),
                         v['flat_field_prod_th'], v['flat_field_ratio_th'])
        c.plane_guess_fit = v['plane_guess_fit']
        c.use_hex = v['use_hex']
        c.force_mirror = v['force_mirror']
        c.ff_alpha = v['ff_alpha']
        c.ff_max_iter = v['ff_max_iter']

        c.set_Fnl(v['Fnl'])
        c.nl_alpha = v['nl_alpha']
        c.sat_level = v['sat_level']
        c.nl_max_iter = v['nl_max_iter']

        return c

    def __str__(self):
        f = f'proposal:{self.proposal} darkrun:{self.darkrun} run:{self.run}'
        f += f' module:{self.module} gain:{self.gain} ph/bin\n'

        f += f'drop intra darks:{self.drop_intra_darks}\n'

        if self.mask_idx is not None:
            f += f'mean threshold:{self.mean_th} std threshold:{self.std_th}\n'
            f += f'mask:(#{len(self.mask_idx)}) {self.mask_idx}\n'
        else:
            f += 'mask:None\n'

        f += f'rois threshold: {self.rois_th}\n'
        f += f'rois: {self.rois}\n'

        f += f'flat-field type: {self.ff_type}\n'
        f += f'flat-field p: {self.flat_field} '
        f += f'prod:{self.flat_field_prod_th} '
        f += f'ratio:{self.flat_field_ratio_th}\n'
        f += f'plane guess fit: {self.plane_guess_fit}\n'
        f += f'use hexagons: {self.use_hex}\n'
        f += f'enforce mirror symmetry: {self.force_mirror}\n'
        f += f'ff alpha: {self.ff_alpha}, max. iter.: {self.ff_max_iter}\n'

        if self.Fnl is not None:
            f += f'dFnl: {np.array(self.Fnl) - np.arange(2**9)}\n'
            f += f'nl alpha:{self.nl_alpha}, sat. level:{self.sat_level}, '
            f += f' nl max. iter.:{self.nl_max_iter}'
        else:
            f += 'Fnl: None'

        return f
def ensure_on_host(arr):
    """Return ``arr`` with its data in host (CPU) memory.

    Objects exposing ``__cuda_array_interface__`` (e.g. CuPy arrays) are
    copied back with ``.get()``; dask arrays get this same conversion mapped
    over their blocks; anything else is returned unchanged.
    """
    # duck-typing on the CUDA interface avoids having to import CuPy here
    on_gpu = hasattr(arr, "__cuda_array_interface__")
    if on_gpu:
        return arr.get()
    if isinstance(arr, (da.Array,)):
        return arr.map_blocks(ensure_on_host)
    return arr


# Hexagonal pixels related function
def _get_pixel_pos(module):
    """Compute the pixel position on hexagonal lattice of DSSC module.

    The positions come from a DSSC_1MGeometry built from fixed dummy
    quadrant positions; only the first two coordinates are returned.
    """
    quadrant_positions = [(-130, 5), (-130, -125), (5, -125), (5, 5)]
    geometry = DSSC_1MGeometry.from_quad_positions(quadrant_positions)
    # keep only the X, Y position of the pixels of the requested module
    all_positions = geometry.get_pixel_positions()
    return all_positions[module][:, :, :2]
def get_roi_pixel_pos(roi, params):
    """Compute fake or real pixel position of an roi from roi center.

    Inputs:
    -------
    roi: dictionnary with 'yl', 'yh', 'xl', 'xh' keys
    params: parameters, uses ``use_hex`` and, when set, ``pixel_pos``

    Returns:
    --------
    X, Y: pixel positions with their mean subtracted (roi center at 0, 0).
        With ``params.use_hex`` both are 2-d arrays of the roi shape taken
        from the hexagonal lattice positions. Otherwise X is a 1-d array of
        length nX and Y a (nY, 1) column on a square grid of pitch 0.01,
        which broadcast together to the roi shape.
    """
    if params.use_hex:
        # DSSC pixel position on hexagonal lattice
        window = params.pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh']]
        X = window[..., 0]
        Y = window[..., 1]
    else:
        nY, nX = roi['yh'] - roi['yl'], roi['xh'] - roi['xl']
        X = np.arange(nX)/100
        Y = np.arange(nY)[:, np.newaxis]/100

    # center of ROI is put to 0,0. Done out-of-place: in the hexagonal case
    # X and Y are views into params.pixel_pos, which must not be modified.
    X = X - np.mean(X)
    Y = Y - np.mean(Y)

    return X, Y
def _get_pixel_corners(module):
    """Compute the pixel corners of DSSC module."""
    quadrant_positions = [(-130, 5), (-130, -125), (5, -125), (5, 5)]
    geometry = DSSC_1MGeometry.from_quad_positions(quadrant_positions)
    distortion = geometry.to_distortion_array(allow_negative_xy=True)
    # corners are in z, y, x order: keep the 128 rows of the module, drop z
    # and reverse the last axis to get x, y
    first_row = module*128
    module_corners = distortion[first_row:first_row + 128, :, :, 1:]
    return module_corners[:, :, :, ::-1]


def _get_pixel_hexagons(module):
    """Compute DSSC pixel hexagons for plotting.

    Parameters:
    -----------
    module: either int, for the module number or a 2-d array of corners to
        get hexagons from

    Returns:
    --------
    a 1-d list of hexagons where corners position are in mm
    """
    if type(module) is int:
        corners = _get_pixel_corners(module)
    else:
        corners = module

    n_rows, n_cols = corners.shape[0], corners.shape[1]
    # corners are scaled by 1e3 (conversion to mm) before building patches
    return [Polygon(1e3*corners[row, col, :, :])
            for row in range(n_rows)
            for col in range(n_cols)]


def _add_colorbar(im, ax, loc='right', size='5%', pad=0.05):
    """Add a colorbar on a new axes so it matches the plot size.

    Inputs
    ------
    im: image plotted
    ax: axes on which the image was plotted
    loc: string, default 'right', location of the colorbar
    size: string, default '5%', proportion of the colorbar with respect to
        the plotted image
    pad: float, default 0.05, pad width between plot and colorbar

    Returns
    -------
    the colorbar
    """
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    colorbar_ax = make_axes_locatable(ax).append_axes(loc, size=size, pad=pad)
    return ax.figure.colorbar(im, cax=colorbar_ax)


# dark related functions
def bad_pixel_map(params):
    """Compute the bad pixels map.

    A pixel is kept (True) when the mean and the standard deviation of its
    dark signal, taken over axes 0 and 1 of ``params.arr_dark``, lie inside
    ``params.mean_th`` and ``params.std_th`` respectively; a bound set to
    None is not applied.

    Inputs
    ------
    params: parameters

    Returns
    -------
    bad pixel map: boolean array, False for bad pixels
    """
    assert params.arr_dark is not None, "Data not loaded"

    # compute mean and std
    dark_mean = params.arr_dark.mean(axis=(0, 1)).compute()
    dark_std = params.arr_dark.std(axis=(0, 1)).compute()

    mask = np.ones_like(dark_mean)
    if params.mean_th[0] is not None:
        mask *= dark_mean >= params.mean_th[0]
    if params.mean_th[1] is not None:
        mask *= dark_mean <= params.mean_th[1]
    if params.std_th[0] is not None:
        mask *= dark_std >= params.std_th[0]
    if params.std_th[1] is not None:
        # upper bound: reject pixels noisier than std_th[1]
        mask *= dark_std <= params.std_th[1]

    print(f'# bad pixel: {int(mask.size - mask.sum())}')

    return mask.astype(bool)
def inspect_dark(arr, mean_th=(None, None), std_th=(None, None)):
    """Inspect dark run data and plot diagnostic.

    The top row shows the per-pixel dark mean (image and histogram), the
    bottom row the per-pixel dark standard deviation, both computed over
    axes 0 and 1 of ``arr``. Thresholds that are not None are drawn as
    dashed vertical lines on the histograms.

    Inputs
    ------
    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
    mean_th: tuple of threshold (low, high), default (None, None), to
        compute a mask of good pixels for which the mean dark value lie
        inside this range
    std_th: tuple of threshold (low, high), default (None, None), to
        compute a mask of good pixels for which the dark std value lie
        inside this range

    Returns
    -------
    fig: matplotlib figure
    """
    # per-pixel statistics, brought back to host memory for plotting
    dark_mean = ensure_on_host(arr.mean(axis=(0, 1)).compute())
    dark_std = ensure_on_host(arr.std(axis=(0, 1)).compute())

    fig = plt.figure(figsize=(7, 2.7))
    grid = fig.add_gridspec(2, 4)
    image_axes = []
    hist_axes = []
    for row in range(2):
        ax_img = fig.add_subplot(grid[row, 1:])
        ax_img.set_xticklabels([])
        ax_img.set_yticklabels([])
        image_axes.append(ax_img)
        hist_axes.append(fig.add_subplot(grid[row, 0]))

    # (values, colorbar label, thresholds, adaptive number of bins)
    panels = (
        (dark_mean, 'dark mean', mean_th, True),
        (dark_std, 'dark std', std_th, False),
    )
    for ax_img, ax_hist, (values, label, thresholds, adaptive) in zip(
            image_axes, hist_axes, panels):
        flat = values.flatten()
        low = np.percentile(flat, 2)
        high = np.percentile(flat, 98)

        mesh = ax_img.pcolormesh(values, vmin=low, vmax=high)
        ax_img.invert_yaxis()
        ax_img.set_aspect('equal')
        colorbar = _add_colorbar(mesh, ax=ax_img, size='2%')
        colorbar.ax.set_ylabel(label)

        # the dark mean histogram uses ~unit-width bins, the std one 50 bins
        n_bins = int(high*2 - low/2 + 1) if adaptive else 50
        ax_hist.hist(flat, bins=n_bins, range=(low/2, high*2))
        for value in (thresholds[0], thresholds[1]):
            if value is not None:
                ax_hist.axvline(value, c='k', alpha=0.5, ls='--')
        ax_hist.set_yscale('log')

    return fig
# histogram related functions
def histogram_module(arr, mask=None):
    """Compute a histogram of the 9 bits raw pixel values over a module.

    With a mask, every pixel value is weighted by the mask value of its
    pixel, so pixels where the mask is False do not contribute.

    Inputs
    ------
    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
    mask: optional bad pixel mask

    Returns
    -------
    histogram: counts for each of the 512 possible values
    """
    flat = arr.ravel()
    if mask is None:
        return da.bincount(flat, minlength=512).compute()

    # broadcast the module mask to the full data shape, chunked like arr
    weights = da.array(mask[None, None, :, :])
    weights = da.repeat(weights, arr.shape[1], axis=1)
    weights = da.repeat(weights, arr.shape[0], axis=0)
    weights = weights.rechunk(arr.chunks)
    return da.bincount(flat, weights.ravel(), minlength=512).compute()
def inspect_histogram(arr, arr_dark=None, mask=None, extra_lines=False):
    """Compute and plot a histogram of the 9 bits raw pixel values.

    Inputs
    ------
    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
    arr_dark: dask array of reshaped dssc dark data
        (trainId, pulseId, x, y), default None
    mask: optional bad pixel mask
    extra_lines: boolean, default False, plot extra lines at period values

    Returns
    -------
    (h, hd): histogram of arr, arr_dark (hd is None without arr_dark)
    figure
    """
    from matplotlib.ticker import MultipleLocator

    pixel_values = np.arange(2**9)
    marker_style = dict(marker='o', ms=3, markerfacecolor='none', lw=1)

    fig = plt.figure(figsize=(6, 3))
    ax = plt.gca()

    h = ensure_on_host(histogram_module(arr, mask=mask))
    ax.plot(pixel_values, h/np.sum(h), **marker_style)

    hd = None
    if arr_dark is not None:
        hd = ensure_on_host(histogram_module(arr_dark, mask=mask))
        ax.plot(pixel_values, hd/np.sum(hd), c='k', alpha=.5, **marker_style)

    if extra_lines:
        # (offset, period, color, alpha) of the periodic marker lines
        periodic = ((2, 8, 'k', 0.5), (3, 16, 'g', 0.3), (7, 32, 'r', 0.3))
        for k in range(50, 271):
            for offset, period, color, alpha in periodic:
                if (k - offset) % period == 0:
                    ax.axvline(k, c=color, alpha=alpha, ls='--')
        ax.axvline(271, c='C1', alpha=0.5, ls='--')

    ax.set_xlim([0, 2**9-1])
    ax.set_yscale('log')
    ax.xaxis.set_minor_locator(MultipleLocator(10))
    ax.set_xlabel('DSSC pixel value')
    ax.set_ylabel('count frequency')

    return (h, hd), fig
# rois related function
def find_rois(data_mean, threshold, extended=False):
    """Find rois from 3 beams configuration.

    The beams are located from the horizontal and vertical projections of
    ``data_mean``, each normalized to its maximum: beam edges are where the
    projection crosses ``threshold``. Only the first 256 columns (half the
    ladder) are searched along X.

    Inputs
    ------
    data_mean: dark corrected average image
    threshold: threshold value to find beams
    extended: boolean, True to define additional ASICS based rois

    Returns
    -------
    rois: dictionnary of rois, each a dict with 'xl', 'xh', 'yl', 'yh' keys
    """
    # compute vertical and horizontal projection
    pX = data_mean.mean(axis=0)
    pX = pX[:256]  # half the ladder since there is a gap in the middle
    pY = data_mean.mean(axis=1)

    pX = pX/np.max(pX)
    pY = pY/np.max(pY)

    # along X
    # 1st occurrence minus one, clipped at 0 so that a beam touching the
    # edge does not give a negative (wrapping) slice index
    lowX = max(int(np.argmax(pX > threshold) - 1), 0)
    highX = int(pX.shape[0] - np.argmax(pX[::-1] > threshold))  # last occ.
    midX = int(0.5*(lowX+highX))
    leftX2 = int(np.argmax(pX[lowX+5:midX-5] < threshold)) + lowX + 5
    midX2 = int(np.argmax(pX[midX+5:highX-5] < threshold)) + midX + 5
    midX1 = int(midX - 5 - np.argmax(pX[midX-5:lowX+5:-1] < threshold))
    rightX1 = int(highX - 5 - np.argmax(pX[highX-5:midX+5:-1] < threshold))

    # along Y
    lowY = max(int(np.argmax(pY > threshold) - 1), 0)  # 1st occurrence
    highY = int(pY.shape[0] - np.argmax(pY[::-1] > threshold))  # last occ.

    # define rois
    rois = {}
    # beam roi
    rois['n'] = {'xl': lowX, 'xh': leftX2, 'yl': lowY, 'yh': highY}
    rois['0'] = {'xl': midX1, 'xh': midX2, 'yl': lowY, 'yh': highY}
    rois['p'] = {'xl': rightX1, 'xh': highX, 'yl': lowY, 'yh': highY}

    # saturation roi
    rois['sat'] = {'xl': lowX, 'xh': highX, 'yl': lowY, 'yh': highY}

    if extended:
        # split points between beams: middle of the dark gap between 'n'
        # and '0', and between '0' and 'p'
        leftX = (leftX2 + midX1)//2
        rightX = (midX2 + rightX1)//2

        # baseline correction rois
        for k in [0, 1, 2, 3]:
            rois[f'b{k}'] = {'xl': k*64, 'xh': (k+1)*64,
                             'yl': 0, 'yh': lowY}
        for k in [8, 9, 10, 11]:
            rois[f'b{k}'] = {'xl': (k-8)*64, 'xh': (k+1-8)*64,
                             'yl': highY, 'yh': 128}

        # ASICs splitted beam roi
        rois['0X'] = {'xl': lowX, 'xh': 1*64, 'yl': lowY, 'yh': 64}
        rois['1X1'] = {'xl': 64, 'xh': leftX, 'yl': lowY, 'yh': 64}
        rois['1X2'] = {'xl': leftX, 'xh': 2*64, 'yl': lowY, 'yh': 64}
        rois['2X1'] = {'xl': 2*64, 'xh': rightX, 'yl': lowY, 'yh': 64}
        rois['2X2'] = {'xl': rightX, 'xh': 3*64, 'yl': lowY, 'yh': 64}
        rois['3X'] = {'xl': 3*64, 'xh': highX, 'yl': lowY, 'yh': 64}

        rois['8X'] = {'xl': lowX, 'xh': 1*64, 'yl': 64, 'yh': highY}
        rois['9X1'] = {'xl': 64, 'xh': leftX, 'yl': 64, 'yh': highY}
        rois['9X2'] = {'xl': leftX, 'xh': 2*64, 'yl': 64, 'yh': highY}
        rois['10X1'] = {'xl': 2*64, 'xh': rightX, 'yl': 64, 'yh': highY}
        rois['10X2'] = {'xl': rightX, 'xh': 3*64, 'yl': 64, 'yh': highY}
        rois['11X'] = {'xl': 3*64, 'xh': highX, 'yl': 64, 'yh': highY}

    return rois
def find_rois_from_params(params):
    """Find rois from 3 beams configuration.

    The dark-corrected module average of ``params.arr`` is averaged over
    pulses and handed to find_rois together with ``params.rois_th``.

    Inputs
    ------
    params: parameters

    Returns
    -------
    rois: dictionnary of rois
    """
    assert params.arr_dark is not None, "Data not loaded"
    dark_avg = average_module(params.arr_dark).compute()

    assert params.arr is not None, "Data not loaded"
    # mean over pulseId of the dark corrected average
    image = average_module(params.arr, dark=dark_avg).compute().mean(axis=0)

    return find_rois(image, params.rois_th)
def inspect_rois(data_mean, rois, threshold=None, allrois=False):
    """Plot the mean module image with the beam rois and its projections.

    The first half of the ladder (256 columns) is shown with the 'n', '0'
    and 'p' rois as colored rectangles, next to its normalized horizontal
    and vertical projections.

    Inputs
    ------
    data_mean: mean module image
    rois: dictionnary of rois, must contain 'n', '0' and 'p'
    threshold: float, default None, threshold value used to detect beams
        boundaries
    allrois: boolean, default False, plot all rois defined in rois (the
        extra ones as black outlines) or only the main ones
        (['n', '0', 'p'])

    Returns
    -------
    matplotlib figure
    """
    from matplotlib.patches import Rectangle

    # compute vertical and horizontal projection
    pX = data_mean.mean(axis=0)
    pX = pX[:256]  # half the ladder since there is a gap in the middle
    pY = data_mean.mean(axis=1)
    pX = pX/np.max(pX)
    pY = pY/np.max(pY)

    # Set up the axes with gridspec
    fig = plt.figure(figsize=(5, 3))
    grid = plt.GridSpec(2, 2, width_ratios=(1, 4), height_ratios=(2, 1),
                        wspace=0.05, hspace=0.05, figure=fig)
    main_ax = fig.add_subplot(grid[0, 1])
    y = fig.add_subplot(grid[0, 0], xticklabels=[], sharey=main_ax)
    x = fig.add_subplot(grid[1, 1], yticklabels=[], sharex=main_ax)

    # image of the first half of the ladder on the main axes
    Xs = np.arange(len(pX))
    Ys = np.arange(len(pY))
    main_ax.pcolormesh(Xs, Ys, np.flipud(data_mean[:, :256]), cmap='Greys_r',
                       vmin=0, vmax=np.percentile(data_mean[:, :256], 99))
    main_ax.set_aspect('equal')

    def _roi_rectangle(roi, **kwargs):
        # the image is flipped vertically, hence the 128 - yh origin
        return Rectangle((roi['xl'], 128-roi['yh']),
                         roi['xh'] - roi['xl'], roi['yh'] - roi['yl'],
                         **kwargs)

    main_rois = (('n', 'b'), ('0', 'g'), ('p', 'r'))
    for name, color in main_rois:
        main_ax.add_patch(_roi_rectangle(rois[name], alpha=0.3, color=color))
    if allrois:
        for name, roi in rois.items():
            if name in ('n', '0', 'p'):
                continue
            main_ax.add_patch(_roi_rectangle(roi, fill=False, edgecolor='k',
                                             alpha=0.5, lw=0.5))

    x.plot(Xs, pX)
    x.invert_yaxis()
    if threshold is not None:
        x.axhline(threshold, c='k', alpha=.5, ls='--')
    for name, color in main_rois:
        x.axvline(rois[name]['xl'], c=color, alpha=.3)
        x.axvline(rois[name]['xh'], c=color, alpha=.3)

    y.plot(pY, np.arange(len(pY)-1, -1, -1))
    y.invert_xaxis()
    if threshold is not None:
        y.axvline(threshold, c='k', alpha=.5, ls='--')
    y.axhline(127-rois['p']['yl'], c='r', alpha=.5)
    y.axhline(127-rois['p']['yh'], c='r', alpha=.5)

    return fig
# Flat-field related functions
def _plane_flat_field(p, roi, params):
    """Compute the p plane over the given roi.

    Solves a*X + b*Y + c*Z + d = 0 for Z on the pixel positions of the roi.

    Parameters
    ----------
    p: a vector of a, b, c, d plane parameter with the plane given by
        ax+ by + cz + d = 0
    roi: a dictionnary roi['yh', 'yl', 'xh', 'xl']
    params: parameters

    Returns
    -------
    the plane field given by p evaluated on the roi extend.
    """
    coef_x, coef_y, coef_z, offset = p
    pos_x, pos_y = get_roi_pixel_pos(roi, params)
    return -(coef_x*pos_x + coef_y*pos_y + offset)/coef_z
def compute_flat_field_correction(rois, params, plot=False):
    """Compute the flat-field correction for the configured model.

    Dispatches on ``params.ff_type``: 'plane' uses
    compute_plane_flat_field_correction and 'polyline' uses
    compute_polyline_flat_field_correction.

    Inputs
    ------
    rois: dictionnary of beam rois['n', '0', 'p']
    params: parameters
    plot: boolean, default False, diagnostic plot

    Returns
    -------
    numpy 2D array of the flat-field correction

    Raises
    ------
    ValueError: when ``params.ff_type`` is neither 'plane' nor 'polyline'
    """
    ff_type = params.ff_type
    if ff_type == 'plane':
        return compute_plane_flat_field_correction(rois, params, plot)
    if ff_type == 'polyline':
        return compute_polyline_flat_field_correction(rois, params, plot)
    raise ValueError(f'Uknown flat field type {params.ff_type}')
def _plot_flat_field(flat_field):
    """Plot the first half (256 columns) of a flat-field correction map."""
    f, ax = plt.subplots(1, 1, figsize=(6, 2))
    img = ax.pcolormesh(np.flipud(flat_field[:, :256]), cmap='Greys_r')
    f.colorbar(img, ax=[ax], label='amplitude')
    ax.set_xlabel('px')
    ax.set_ylabel('px')
    ax.set_aspect('equal')


def compute_plane_flat_field_correction(rois, params, plot=False):
    """Compute the plane-field correction on beam rois.

    The 'n' roi gets the plane given by the first 4 flat-field parameters.
    The 'p' roi gets its mirror image (X coefficient negated) when
    ``params.force_mirror`` is set, or the plane given by the last 4
    parameters otherwise. Pixels outside these rois are set to 1.

    Inputs
    ------
    rois: dictionnary of beam rois['n', '0', 'p']
    params: parameters
    plot: boolean, default False, diagnostic plot

    Returns
    -------
    numpy 2D array of the flat-field correction evaluated over one DSSC
    ladder (2 sensors)
    """
    flat_field = np.ones((128, 512))
    plane = params.get_flat_field()
    force_mirror = params.force_mirror

    r = rois['n']
    flat_field[r['yl']:r['yh'], r['xl']:r['xh']] = \
        _plane_flat_field(plane[:4], r, params)

    r = rois['p']
    if force_mirror:
        a, b, c, d = plane[:4]
        p_plane = [-a, b, c, d]
    else:
        p_plane = plane[4:]
    flat_field[r['yl']:r['yh'], r['xl']:r['xh']] = \
        _plane_flat_field(p_plane, r, params)

    if plot:
        _plot_flat_field(flat_field)

    return flat_field


def initialize_polyline_ff_correction(avg, rois, params, plot=False):
    """Initialize the polyline flat field correction.

    Degree 6 polynomials are fitted to the horizontal and vertical
    projections of the ratio ((n+p)/2)/0 computed on the beam rois of
    ``avg`` (the three rois must have the same shape). Their scaled
    coefficients are stored as the flat field of ``params`` and
    ``params.ff_type`` is set to 'polyline'.

    Inputs
    ------
    avg: 2D array, average module image
    rois: dictionnary of ROIs.
    params: parameters, updated in place
    plot: boolean, default False, plot initialized polyline versus data
        projection

    Returns
    -------
    fig: handle to figure or None
    """
    refn = avg[rois['n']['yl']:rois['n']['yh'],
               rois['n']['xl']:rois['n']['xh']]
    refp = avg[rois['p']['yl']:rois['p']['yh'],
               rois['p']['xl']:rois['p']['xh']]
    mid = avg[rois['0']['yl']:rois['0']['yh'],
              rois['0']['xl']:rois['0']['xh']]
    mref = 0.5*(refn + refp)
    inv_signal = mref/mid

    # horizontal polynomial, fitted on the projection along Y
    H_projection = inv_signal.mean(axis=0)
    x = np.arange(0, len(H_projection))
    H_z = np.polyfit(x, H_projection, 6)
    H_p = np.poly1d(H_z)

    # vertical polynomial, fitted once the horizontal one is divided out
    V_projection = (inv_signal/H_p(x)).mean(axis=1)
    y = np.arange(0, len(V_projection))
    V_z = np.polyfit(y, V_projection, 6)

    if plot:
        fig, axs = plt.subplots(2, 1, figsize=(4, 6))
        axs[0].plot(x, H_projection, label='data (n+p)/2x0')
        axs[0].plot(x, H_p(x), label='poly')
        axs[0].legend()
        axs[0].set_xlabel('x (px)')
        axs[0].set_ylabel('H projection')

        axs[1].plot(y, V_projection, label='data (n+p)/2x0')
        V_p = np.poly1d(V_z)
        axs[1].plot(y, V_p(y), label='poly')
        axs[1].legend()
        axs[1].set_xlabel('y (px)')
        axs[1].set_ylabel('V projection')
    else:
        fig = None

    # scaling on polynom coefficients for better fitting
    ff = np.array([H_z/np.logspace(-(H_z.shape[0]-1), 0, H_z.shape[0]),
                   V_z/np.logspace(-(V_z.shape[0]-1), 0, V_z.shape[0])])
    params.set_flat_field(ff.flatten())
    params.ff_type = 'polyline'

    return fig


def compute_polyline_flat_field_correction(rois, params, plot=False):
    """Compute the 1D polyline field correction on beam rois.

    The normalization is the product of a vertical and a horizontal
    polynomial; the same (not mirrored) map is applied to the 'n' and 'p'
    rois. Pixels outside these rois are set to 1.

    Inputs
    ------
    rois: dictionnary of beam rois['n', '0', 'p']
    params: parameters
    plot: boolean, default False, diagnostic plot

    Returns
    -------
    numpy 2D array of the flat-field correction evaluated over one DSSC
    ladder (2 sensors)
    """
    flat_field = np.ones((128, 512))
    z = np.array(params.get_flat_field()).reshape((2, -1))
    H_z = z[0, :]
    V_z = z[1, :]
    # undo the coefficient scaling applied by
    # initialize_polyline_ff_correction
    H_p = np.poly1d(H_z*np.logspace(-(H_z.shape[0]-1), 0, H_z.shape[0]))
    V_p = np.poly1d(V_z*np.logspace(-(V_z.shape[0]-1), 0, V_z.shape[0]))

    n = rois['n']
    p = rois['p']

    wn = n['xh']-n['xl']
    wp = p['xh']-p['xl']
    assert wn == wp, (
        f"For polyline flat field normalization, both 'n' and 'p' ROIs "
        f"must have the same width {wn} and {wp}px"
    )
    x = np.arange(wn)
    y = np.arange(n['yh']-n['yl'])
    norm = V_p(y)[:, np.newaxis]*H_p(x)

    flat_field[n['yl']:n['yh'], n['xl']:n['xh']] *= norm
    flat_field[p['yl']:p['yh'], p['xl']:p['xh']] *= norm  # not the mirror

    if plot:
        _plot_flat_field(flat_field)

    return flat_field
def inspect_flat_field_domain(avg, rois, prod_th, ratio_th, vmin=None,
                              vmax=None):
    """Extract beams roi from average image and compute the ratio.

    Each beam roi is shifted onto the '0' roi. Rows of the figure show the
    shifted images, their product with the '0' image and their ratio to it,
    with ``prod_th`` and ``ratio_th`` drawn as contours.

    Inputs
    ------
    avg: module average image with no saturated shots for the flat-field
        determination
    rois: dictionnary of ROIs
    prod_th, ratio_th: tuple of floats for low and high threshold on
        product and ratio
    vmin: imshow vmin level, default None will use 5 percentile value
    vmax: imshow vmax level, default None will use 99.8 percentile value

    Returns
    -------
    fig: matplotlib figure plotted
    domain: a tuple (n_m, p_m) of domain for the 'n' and 'p' order
    """
    if vmin is None:
        vmin = np.percentile(avg, 5)
    if vmax is None:
        vmax = np.percentile(avg, 99.8)

    fig, axs = plt.subplots(3, 3, sharex=True, figsize=(6, 9))

    img_rois = {}
    centers = {}

    for r in ['n', '0', 'p']:
        roi = rois[r]
        centers[r] = np.array([(roi['yl'] + roi['yh'])//2,
                               (roi['xl'] + roi['xh'])//2])

    d = '0'
    roi = rois[d]
    for k, r in enumerate(['n', '0', 'p']):
        # shift rows and columns so that roi r lands on roi '0'; without an
        # explicit axis, np.roll would shift the flattened image instead
        shift = tuple(centers[d] - centers[r])
        img_rois[r] = np.roll(avg, shift, axis=(0, 1))[
            roi['yl']:roi['yh'], roi['xl']:roi['xh']]
        im = axs[0, k].imshow(img_rois[r], vmin=vmin, vmax=vmax)

    _, n_m, _, p_m = plane_fitting_domain(avg, rois, prod_th, ratio_th)

    # cm.get_cmap(cm.cool, 2) returned cm.cool itself (the lut argument is
    # ignored for a Colormap) and no longer exists in recent matplotlib, so
    # the colormap is passed directly
    contour_cmap = cm.cool

    prod_vmin, prod_vmax, ratio_vmin, ratio_vmax = [None]*4
    for k, r in enumerate(['n', '0', 'p']):
        v = img_rois[r]*img_rois['0']
        if prod_vmin is None:
            prod_vmin = np.percentile(v, .5)
            prod_vmax = np.percentile(v, 20)  # we look for low intensity region
        im2 = axs[1, k].imshow(v, vmin=prod_vmin, vmax=prod_vmax,
                               cmap='magma')
        axs[1, k].contour(v, prod_th, cmap=contour_cmap)

        v = img_rois[r]/img_rois['0']
        if ratio_vmin is None:
            ratio_vmin = np.percentile(v, 5)
            ratio_vmax = np.percentile(v, 99.8)
        im3 = axs[2, k].imshow(v, vmin=ratio_vmin, vmax=ratio_vmax,
                               cmap='RdBu_r')
        axs[2, k].contour(v, ratio_th, cmap=contour_cmap)

    cbar = fig.colorbar(im, ax=axs[0, :], orientation="horizontal")
    cbar.ax.set_xlabel('data mean')
    cbar = fig.colorbar(im2, ax=axs[1, :], orientation="horizontal")
    cbar.ax.set_xlabel('product')
    cbar = fig.colorbar(im3, ax=axs[2, :], orientation="horizontal")
    cbar.ax.set_xlabel('ratio')

    domain = (n_m, p_m)

    return fig, domain
def inspect_plane_fitting(avg, rois, domain=None, vmin=None, vmax=None):
    """Deprecated alias of inspect_ff_fitting, kept for backward compatibility.

    Emits a FutureWarning pointing at the caller, then forwards all
    arguments to inspect_ff_fitting and returns its figure.
    """
    warnings.warn("This method is deprecated, use inspect_ff_fitting instead",
                  FutureWarning, stacklevel=2)
    return inspect_ff_fitting(avg, rois, domain, vmin, vmax)
def inspect_ff_fitting(avg, rois, domain=None, vmin=None, vmax=None):
    """Extract beams roi from average image and compute the ratio.

    The first row shows the 'n', '0' and 'p' orders shifted onto the '0'
    roi. The second row shows their ratio to the '0' order, with the domain
    masks drawn as contours on the 'n' and 'p' panels.

    Inputs
    ------
    avg: module average image with no saturated shots for the flat-field
        determination
    rois: dictionary of rois
    domain: tuple (n_m, p_m) of domain masks for the -1st and +1st order
    vmin: imshow vmin level, default None will use 5 percentile value
    vmax: imshow vmax level, default None will use 99.8 percentile value

    Returns
    -------
    fig: matplotlib figure plotted
    """
    if vmin is None:
        vmin = np.percentile(avg, 5)
    if vmax is None:
        vmax = np.percentile(avg, 99.8)

    fig, axs = plt.subplots(2, 3, sharex=True, figsize=(6, 6))

    centers = {}
    for r in ['n', '0', 'p']:
        roi = rois[r]
        centers[r] = np.array([(roi['yl'] + roi['yh'])//2,
                               (roi['xl'] + roi['xh'])//2])

    # shift every order onto the '0' roi so they can be compared pixel-wise
    d = '0'
    roi = rois[d]
    img_rois = {}
    for k, r in enumerate(['n', '0', 'p']):
        img_rois[r] = np.roll(avg, tuple(centers[d] - centers[r]))[
            roi['yl']:roi['yh'], roi['xl']:roi['xh']]
        im = axs[0, k].imshow(img_rois[r], vmin=vmin, vmax=vmax)

    for k, r in enumerate(['n', '0', 'p']):
        v = img_rois[r]/img_rois['0']
        im2 = axs[1, k].imshow(v, vmin=0.2, vmax=1.1, cmap='RdBu_r')

    if domain is not None:
        n_m, p_m = domain
        axs[1, 0].contour(n_m)
        axs[1, 2].contour(p_m)

    cbar = fig.colorbar(im, ax=axs[0, :], orientation="horizontal")
    cbar.ax.set_xlabel('data mean')
    cbar = fig.colorbar(im2, ax=axs[1, :], orientation="horizontal")
    cbar.ax.set_xlabel('ratio')

    return fig


def inspect_ff_fitting_sk(avg, rois, ff, domain=None, vmin=None, vmax=None):
    """Compare the '0' order with the mean of the 'n' and 'p' orders.

    The columns show the raw data, the flat field itself, and the data
    divided by the flat field. The rows show (n+p)/2, the '0' order, and the
    relative difference 2x0/(n+p) - 1.

    Inputs
    ------
    avg: module average image with no saturated shots for the flat-field
        determination
    rois: dictionary of rois; the 'n', '0' and 'p' rois are expected to
        have the same size since their crops are combined pixel-wise
    ff: 2D array, flat field normalization
    domain: unused, kept for signature compatibility with inspect_ff_fitting
    vmin: imshow vmin level of the raw (n+p)/2 and 0 panels, default None
        will use 5 percentile value
    vmax: imshow vmax level of the raw (n+p)/2 and 0 panels, default None
        will use 99.8 percentile value

    Returns
    -------
    fig: matplotlib figure plotted
    """
    if vmin is None:
        vmin = np.percentile(avg, 5)
    if vmax is None:
        vmax = np.percentile(avg, 99.8)

    def _crop(img, roi):
        """Return the part of the 2D image img inside roi."""
        return img[roi['yl']:roi['yh'], roi['xl']:roi['xh']]

    refn = _crop(avg, rois['n'])
    refp = _crop(avg, rois['p'])
    mid = _crop(avg, rois['0'])
    mref = 0.5*(refn + refp)

    ffn = _crop(ff, rois['n'])
    ffp = _crop(ff, rois['p'])
    np_norm = 0.5*(ffn + ffp)
    mid_norm = _crop(ff, rois['0'])

    fig, axs = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(8, 4))

    # raw data: vmin/vmax were computed but previously never applied; the
    # shared scale also makes (n+p)/2 and 0 directly comparable
    im = axs[0, 0].imshow(mref, vmin=vmin, vmax=vmax)
    axs[0, 0].set_title('(n+p)/2')
    fig.colorbar(im, ax=axs[0, 0])
    im = axs[1, 0].imshow(mid, vmin=vmin, vmax=vmax)
    axs[1, 0].set_title('0')
    fig.colorbar(im, ax=axs[1, 0])
    im = axs[2, 0].imshow(mid/mref-1, cmap='RdBu_r', vmin=-1, vmax=1)
    axs[2, 0].set_title('2x0/(n+p) - 1')
    fig.colorbar(im, ax=axs[2, 0])

    # flat field
    im = axs[0, 1].imshow(np_norm)
    axs[0, 1].set_title('norm: (n+p)/2')
    fig.colorbar(im, ax=axs[0, 1])
    im = axs[1, 1].imshow(mid_norm)
    axs[1, 1].set_title('norm: 0')
    fig.colorbar(im, ax=axs[1, 1])
    im = axs[2, 1].imshow(mid_norm/np_norm-1, cmap='RdBu_r', vmin=-1, vmax=1)
    axs[2, 1].set_title('norm: 2x0/(n+p) - 1')
    fig.colorbar(im, ax=axs[2, 1])

    # flat-field corrected data
    im = axs[0, 2].imshow(mref/np_norm)
    axs[0, 2].set_title('(n+p)/2 /norm')
    fig.colorbar(im, ax=axs[0, 2])
    im = axs[1, 2].imshow(mid/mid_norm)
    axs[1, 2].set_title('0 /norm')
    fig.colorbar(im, ax=axs[1, 2])
    im = axs[2, 2].imshow((mid/mid_norm)/(mref/np_norm)-1,
                          cmap='RdBu_r', vmin=-1, vmax=1)
    axs[2, 2].set_title('2x0/(n+p) - 1 /norm')
    fig.colorbar(im, ax=axs[2, 2])

    return fig
def plane_fitting_domain(avg, rois, prod_th, ratio_th):
    """Extract beams roi, compute their ratio and the domain.

    Inputs
    ------
    avg: module average image with no saturated shots for the flat-field
        determination
    rois: dictionary of rois containing the 3 beams ['n', '0', 'p'] with
        '0' as the reference beam in the middle
    prod_th: float tuple, low and high threshold level to determine the
        plane fitting domain on the product image of the orders
    ratio_th: float tuple, low and high threshold level to determine the
        plane fitting domain on the ratio image of the orders

    Returns
    -------
    n: img ratio 'n'/'0', with non-finite values set to 0
    n_m: boolean mask where prod_th[0] < 'n'*'0' < prod_th[1] and
        ratio_th[0] < 'n'/'0' < ratio_th[1] and the ratio is finite
    p: img ratio 'p'/'0', with non-finite values set to 0
    p_m: boolean mask where prod_th[0] < 'p'*'0' < prod_th[1] and
        ratio_th[0] < 'p'/'0' < ratio_th[1] and the ratio is finite
    """
    centers = {}
    for k in ['n', '0', 'p']:
        r = rois[k]
        centers[k] = np.array([(r['yl'] + r['yh'])//2,
                               (r['xl'] + r['xh'])//2])

    def _ratio_and_domain(k, d='0'):
        """Ratio of order k over order d on roi k, and its fitting domain."""
        r = rois[k]
        num = avg[r['yl']:r['yh'], r['xl']:r['xh']]
        # shift the reference order d onto roi k
        denom = np.roll(avg, tuple(centers[k] - centers[d]))[
            r['yl']:r['yh'], r['xl']:r['xh']]
        # zeros in denom give inf/nan, which are zeroed and masked below
        with np.errstate(divide='ignore', invalid='ignore'):
            ratio = num/denom
        prod = num*denom
        mask = ((prod > prod_th[0]) & (prod < prod_th[1])
                & (ratio > ratio_th[0]) & (ratio < ratio_th[1]))
        bad = ~np.isfinite(ratio)
        mask[bad] = False
        ratio[bad] = 0
        return ratio, mask

    n, n_m = _ratio_and_domain('n')
    p, p_m = _ratio_and_domain('p')
    return n, n_m, p, p_m
def plane_fitting(params):
    """Fit the plane flat-field normalization.

    The dark-corrected, non-saturated module average is used to build the
    'n'/'0' and 'p'/'0' ratio images and their fitting domains with
    plane_fitting_domain. The criterion minimized is 1e3 times the sum of
    the squared point-to-plane distances over both domains.

    Inputs
    ------
    params: parameters

    Returns
    -------
    res: the minimization result. The fitted vector res.x =
        [a_n, b_n, c_n, d_n, a_p, b_p, c_p, d_p] defines the planes
        a*x + b*y + c*z + d = 0 for the 'n' and 'p' orders. When
        params.force_mirror is set, the 'p' plane is the mirror (a -> -a)
        of the 'n' one and the '_p' entries are not used.
    """
    assert params.arr_dark is not None, "Data not loaded"
    dark = average_module(params.arr_dark).compute()

    assert params.arr is not None, "Data not loaded"
    data = average_module(params.arr, dark=dark,
                          ret='mean', mask=params.mask,
                          sat_roi=params.rois['sat'],
                          sat_level=params.sat_level).compute()
    data_mean = data.mean(axis=0)  # mean over pulseId

    n, n_m, p, p_m = plane_fitting_domain(data_mean, params.rois,
                                          params.flat_field_prod_th,
                                          params.flat_field_ratio_th)

    # pixel positions do not depend on the fitted parameters: compute them
    # once instead of at every criterion evaluation
    X_n, Y_n = get_roi_pixel_pos(params.rois['n'], params)
    X_p, Y_p = get_roi_pixel_pos(params.rois['p'], params)

    def _crit(x):
        """Fitting criteria for the plane field normalization.

        Inputs
        ------
        x: 2 vector [a, b, c, d] concatenated defining the plane as
            a*x + b*y + c*z + d = 0
        """
        a_n, b_n, c_n, d_n, a_p, b_p, c_p, d_p = x

        num_n = a_n**2 + b_n**2 + c_n**2
        d0_2 = np.sum(n_m*(a_n*X_n + b_n*Y_n + c_n*n + d_n)**2)/num_n

        if params.force_mirror:
            d2_2 = np.sum(p_m*(-a_n*X_p + b_n*Y_p + c_n*p + d_n)**2)/num_n
        else:
            num_p = a_p**2 + b_p**2 + c_p**2
            d2_2 = np.sum(p_m*(a_p*X_p + b_p*Y_p + c_p*p + d_p)**2)/num_p

        return 1e3*(d2_2 + d0_2)

    res = minimize(_crit, params.flat_field_guess())

    return res
def ff_refine_crit(p, alpha, params, arr_dark, arr, tid, rois, mask,
                   sat_level=511):
    """Criteria for the ff_refine_fit.

    Inputs
    ------
    p: ff plane
    alpha: float, weight of the spread term err_sigma against the mean
        term err_mean
    params: parameters
    arr_dark: dark data
    arr: data
    tid: train id of arr data
    rois: ['n', '0', 'p', 'sat'] rois
    mask: mask of good pixels
    sat_level: integer, default 511, at which level pixel begin to saturate

    Returns
    -------
    bad + 1e3*(alpha*err_sigma + (1 - alpha)*err_mean), where
    - bad is a 1e6 penalty when the flat field has negative values,
    - err_sigma is the sum of the binned 'sigmaA' of three xas() results,
      computed on non-saturated shots with (Iokey, Itkey) = ('0', 'n'),
      ('0', 'p') and ('p', 'n'),
    - err_mean is the sum of the squared deviations of their mean 'muA'
      from 1.
    """
    # Keep the current flat-field type, as ff_refine_crit_sk does. Without
    # it, set_flat_field falls back to its own default type (presumably
    # 'plane' -- confirm against parameters.set_flat_field).
    params.set_flat_field(p, params.ff_type)
    ff = compute_flat_field_correction(rois, params)
    if np.any(ff < 0.0):
        bad = 1e6
    else:
        bad = 0.0

    data = process(None, arr_dark, arr, tid, rois, mask, ff, sat_level,
                   params._using_gpu)

    # drop saturated shots (element-wise comparison on a DataArray)
    d = data.where(data['sat_sat'] == False, drop=True)  # noqa: E712

    rn = xas(d, 40, Iokey='0', Itkey='n', nrjkey='0', fluorescence=True)
    rp = xas(d, 40, Iokey='0', Itkey='p', nrjkey='0', fluorescence=True)
    rd = xas(d, 40, Iokey='p', Itkey='n', nrjkey='0', fluorescence=True)

    err_sigma = (np.nansum(rn['sigmaA']) + np.nansum(rp['sigmaA'])
                 + np.nansum(rd['sigmaA']))
    err_mean = ((1.0 - np.nanmean(rn['muA']))**2
                + (1.0 - np.nanmean(rp['muA']))**2
                + (1.0 - np.nanmean(rd['muA']))**2)

    return bad + 1e3*(alpha*err_sigma + (1-alpha)*err_mean)
def ff_refine_crit_sk(p, alpha, params, arr_dark, arr, tid, rois, mask,
                      sat_level=511):
    """Criteria for the ff_refine_fit, combining 'n' and 'p' as reference.

    Inputs
    ------
    p: ff plane
    alpha: float, weight of the spread term err_sigma against the mean
        term err_mean
    params: parameters
    arr_dark: dark data
    arr: data
    tid: train id of arr data
    rois: ['n', '0', 'p', 'sat'] rois
    mask: mask of good pixels
    sat_level: integer, default 511, at which level pixel begin to saturate

    Returns
    -------
    penalty + 1e3*(alpha*err_sigma + (1 - alpha)*err_mean), where
    - penalty is 1e6 if the flat field has negative values (else 0),
    - err_sigma is the sum of the binned 'sigmaA' of the xas() result
      computed on non-saturated shots with Iokey='np_mean_sk' and
      Itkey='0',
    - err_mean is the squared deviation of its mean 'muA' from 1.
    """
    params.set_flat_field(p, params.ff_type)
    ff = compute_flat_field_correction(rois, params)
    penalty = 1e6 if np.any(ff < 0.0) else 0.0

    processed = process(None, arr_dark, arr, tid, rois, mask, ff, sat_level,
                        params._using_gpu)

    # keep only shots without saturation
    unsaturated = processed.where(processed['sat_sat'] == False, drop=True)
    spectrum = xas(unsaturated, 40, Iokey='np_mean_sk', Itkey='0',
                   nrjkey='0', fluorescence=True)

    err_sigma = np.nansum(spectrum['sigmaA'])
    err_mean = (1.0 - np.nanmean(spectrum['muA']))**2

    return penalty + 1e3*(alpha*err_sigma + (1-alpha)*err_mean)
def ff_refine_fit(params, crit=ff_refine_crit):
    """Refine the flat-field fit by minimizing data spread.

    The starting point is the current flat field, or
    params.flat_field_guess() if none was fitted yet. After every iteration
    the criterion is evaluated with alpha=params.ff_alpha, alpha=0 and
    alpha=1, and a progress line is printed.

    Inputs
    ------
    params: parameters
    crit: criterion function, default ff_refine_crit

    Returns
    -------
    res: scipy minimize result. res.x is the optimized parameters
    fitrres: list, one [crit(alpha=0), crit(alpha), crit(alpha=1)] entry per
        iteration
    """
    assert params.arr is not None, "Data not loaded"
    assert params.arr_dark is not None, "Data not loaded"

    # only the rois needed by the criterion
    fitrois = {name: params.rois[name] for name in ('n', '0', 'p', 'sat')}

    p0 = params.get_flat_field()
    if p0 is None:
        # flat field was not yet fitted
        p0 = params.flat_field_guess()

    fixed_p = (params.ff_alpha, params, params.arr_dark, params.arr,
               params.tid, fitrois, params.get_mask(), params.sat_level)

    history = []
    t_start = None

    def fit_callback(x):
        nonlocal t_start
        if t_start is None:
            # clock starts at the first call
            t_start = time.monotonic()
        elapsed = datetime.timedelta(seconds=time.monotonic() - t_start)
        iteration = len(history)

        args = list(fixed_p)
        Jalpha = crit(x, *args)
        args[0] = 0
        J0 = crit(x, *args)
        args[0] = 1
        J1 = crit(x, *args)
        history.append([J0, Jalpha, J1])

        print(f'{iteration}: {elapsed} '
              f'(reg. term: {J0}, {Jalpha}, err. term: {J1}), {x}')
        return False

    fit_callback(p0)

    res = minimize(crit, p0, fixed_p,
                   options={'disp': True, 'maxiter': params.ff_max_iter},
                   callback=fit_callback)

    return res, history
# non-linearity related functions
def nl_domain(N, low, high):
    """Create the input domain where the non-linear correction defined.

    Inputs strictly between low and high are spread evenly over the
    segment indices 1..N; every other input gets index 0 (no correction).

    Inputs
    ------
    N: integer, number of control points or intervals
    low: input values below or equal to low will not be corrected
    high: input values higher or equal to high will not be corrected

    Returns
    -------
    array of 2**9 integer values with N segments
    """
    inputs = np.arange(2**9)
    inside = (inputs > low) & (inputs < high)

    segments = np.zeros_like(inputs)
    # the 1e-5 offset keeps the top value just below N+1, so that the
    # truncation to integer on assignment never produces N+1
    segments[inside] = np.linspace(1, N + 1 - 1e-5, high - low - 1)
    return segments
def nl_lut(domain, dy):
    """Compute the non-linear correction.

    Inputs
    ------
    domain: input domain where dy is defined. For zero no correction is
        defined. For non-zero value x, dy[x] is applied.
    dy: a vector of deviation from linearity on control point homogeneously
        dispersed over 9 bits.

    Returns
    -------
    F_INL: default None, non linear correction function given as a
        lookup table with 9 bits integer input
    """
    # slot 0 of the offsets is the "no correction" value
    offsets = np.insert(dy, 0, 0)
    return np.arange(2**9) + offsets[domain]
def nl_crit(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field,
            sat_level=511, use_gpu=False):
    """Criteria for the non linear correction.

    Inputs
    ------
    p: vector of dy non linear correction
    domain: domain over which the non linear correction is defined
    alpha: float, coefficient scaling the cost of the correction function
        in the criterion
    arr_dark: dark data
    arr: data
    tid: train id of arr data
    rois: ['n', '0', 'p', 'sat'] rois
    mask: mask of good pixels
    flat_field: zone plate flat-field correction
    sat_level: integer, default 511, at which level pixel begin to saturate
    use_gpu: boolean, default False, move the lookup table to the GPU

    Returns
    -------
    (1.0 - alpha)*0.5*(err_1 + err_2) + alpha*err_a, where err_1 (err_2) is
    1e8 times the squared weighted standard deviation of the 'n'/'0'
    ('p'/'0') ratio from snr() on non-saturated shots, and err_a is the sum
    of the squared deviations of the lookup table from the identity.
    """
    Fmodel = nl_lut(domain, p)
    lut = cp.asarray(Fmodel) if use_gpu else Fmodel
    processed = process(lut, arr_dark, arr, tid, rois, mask, flat_field,
                        sat_level, use_gpu)

    # keep only shots without saturation
    good = processed.where(processed['sat_sat'] == False, drop=True)

    ref = good['0'].values.flatten()
    errors = []
    for order in ('n', 'p'):
        stats = snr(good[order].values.flatten(), ref, methods=['weighted'])
        errors.append(1e8*stats['weighted']['s']**2)

    err_a = np.sum((Fmodel-np.arange(2**9))**2)
    return (1.0 - alpha)*0.5*(errors[0] + errors[1]) + alpha*err_a
def nl_crit_sk(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field,
               sat_level=511, use_gpu=False):
    """Non linear correction criteria, combining 'n' and 'p' as reference.

    Inputs
    ------
    p: vector of dy non linear correction
    domain: domain over which the non linear correction is defined
    alpha: float, coefficient scaling the cost of the correction function
        in the criterion
    arr_dark: dark data
    arr: data
    tid: train id of arr data
    rois: ['n', '0', 'p', 'sat'] rois
    mask: mask of good pixels
    flat_field: zone plate flat-field correction
    sat_level: integer, default 511, at which level pixel begin to saturate
    use_gpu: boolean, default False, move the lookup table to the GPU

    Returns
    -------
    (1.0 - alpha)*err + alpha*err_a, where err is 1e8 times the squared
    weighted standard deviation of the 'np_mean_sk'/'0' ratio from snr() on
    non-saturated shots, and err_a is the sum of the squared deviations of
    the lookup table from the identity.
    """
    Fmodel = nl_lut(domain, p)
    lut = cp.asarray(Fmodel) if use_gpu else Fmodel
    processed = process(lut, arr_dark, arr, tid, rois, mask, flat_field,
                        sat_level, use_gpu)

    # keep only shots without saturation
    good = processed.where(processed['sat_sat'] == False, drop=True)

    stats = snr(good['np_mean_sk'].values.flatten(),
                good['0'].values.flatten(), methods=['weighted'])
    spread = 1e8*stats['weighted']['s']**2

    err_a = np.sum((Fmodel-np.arange(2**9))**2)
    return (1.0 - alpha)*spread + alpha*err_a
def nl_fit(params, domain, ff=None, crit=None):
    """Fit non linearities correction function.

    The fit starts from a zero correction (one value per non-zero index of
    domain). After every iteration the criterion is evaluated with
    alpha=params.nl_alpha, alpha=0 and alpha=1, and a progress line is
    printed.

    Inputs
    ------
    params: parameters
    domain: array of index
    ff: array, flat field correction, default None computes it from params
    crit: function, criteria function, default None uses nl_crit

    Returns
    -------
    res: scipy minimize result. res.x is the optimized parameters
    fitrres: list, one [crit(alpha=0), crit(alpha), crit(alpha=1)] entry per
        iteration
    """
    assert params.arr is not None, "Data not loaded"
    assert params.arr_dark is not None, "Data not loaded"

    # only the rois needed by the criterion
    fitrois = {name: params.rois[name] for name in ('n', '0', 'p', 'sat')}

    # one dy value per non-zero domain index, all starting at zero
    n_ctrl = np.unique(domain).shape[0] - 1
    p0 = np.array([0]*n_ctrl)

    if ff is None:
        ff = compute_flat_field_correction(params.rois, params)

    if crit is None:
        crit = nl_crit

    fixed_p = (domain, params.nl_alpha, params.arr_dark, params.arr,
               params.tid, fitrois, params.get_mask(), ff, params.sat_level,
               params._using_gpu)

    history = []
    t_start = None

    def fit_callback(x):
        nonlocal t_start
        if t_start is None:
            # clock starts at the first call
            t_start = time.monotonic()
        elapsed = datetime.timedelta(seconds=time.monotonic() - t_start)
        iteration = len(history)

        args = list(fixed_p)
        Jalpha = crit(x, *args)
        args[1] = 0
        J0 = crit(x, *args)
        args[1] = 1
        J1 = crit(x, *args)
        history.append([J0, Jalpha, J1])

        print(f'{iteration}: {elapsed} '
              f'({J0}, {Jalpha}, {J1}), {x}')
        return False

    fit_callback(p0)

    res = minimize(crit, p0, fixed_p,
                   options={'disp': True, 'maxiter': params.nl_max_iter},
                   callback=fit_callback)

    return res, history
def inspect_nl_fit(res_fit):
    """Plot the progress of the fit.

    Inputs
    ------
    res_fit: per-iteration criteria values; presumably the
        [J(alpha=0), J(alpha), J(alpha=1)] list returned by nl_fit

    Returns
    -------
    matplotlib figure
    """
    history = np.array(res_fit)

    fig = plt.figure(figsize=(6, 4))
    ax_snr = fig.gca()
    ax_cost = ax_snr.twinx()

    # column 0 is plotted as an SNR via 1/sqrt(1e-8*J), undoing the 1e8
    # scaling of the criterion; column 2 is plotted as the correction cost
    ax_snr.plot(1.0/np.sqrt(1e-8*history[:, 0]), c='C0')
    ax_cost.plot(history[:, 2], c='C1', ls='-.')

    ax_snr.set_xlabel('# iteration')
    ax_snr.set_ylabel('SNR', color='C0')
    ax_cost.set_ylabel('correction cost', color='C1')
    ax_snr.set_yscale('log')
    ax_cost.set_yscale('log')

    return fig
def snr(sig, ref, methods=None, verbose=False):
    """ Compute mean, std and SNR from transmitted and I0 signals.

    Samples where sig, ref or sig/ref is not finite are discarded first.

    Inputs
    ------
    sig: 1D signal samples
    ref: 1D reference samples
    methods: None by default or list of strings to select which methods to
        use. Possible values are 'direct', 'weighted', 'diff'. In case of
        None, all methods will be calculated.
    verbose: boolean, if True prints calculated values

    Returns
    -------
    dictionary of [methods][value] where value is 'mu' for mean and 's' for
    standard deviation. Keys are always in the order 'direct', 'weighted',
    'diff'.
    - 'direct': plain mean and std of sig/ref
    - 'weighted': sum(sig)/sum(ref) and the ref-weighted std of sig/ref
    - 'diff': mean of sig/ref and std of its successive differences / sqrt(2)
    """
    if methods is None:
        methods = ['direct', 'weighted', 'diff']

    ratio = sig/ref
    keep = np.isfinite(ratio) & np.isfinite(sig) & np.isfinite(ref)
    sig = sig[keep]
    ref = ref[keep]
    ratio = ratio[keep]

    res = {}

    if 'direct' in methods:
        mu, s = np.mean(ratio), np.std(ratio)
        if verbose:
            print(f'mu: {mu}, s: {s}, snr: {mu/s}')
        res['direct'] = {'mu': mu, 's': s}

    if 'weighted' in methods:
        # reference intensity used as weights; unbiased weighted variance
        w_sum = np.sum(ref)
        w_sq_sum = np.sum(ref**2)
        wmu = np.sum(sig)/w_sum
        ws = np.sqrt(np.sum(ref*(ratio - wmu)**2)/(w_sum - w_sq_sum/w_sum))
        if verbose:
            print(f'weighted mu: {wmu}, s: {ws}, snr: {wmu/ws}')
        res['weighted'] = {'mu': wmu, 's': ws}

    if 'diff' in methods:
        # successive differences remove slow drifts; sqrt(2) because the
        # difference of two samples doubles the variance
        dmu = np.mean(ratio)
        ds = np.std(np.diff(ratio))/np.sqrt(2)
        if verbose:
            print(f'diff mu: {dmu}, s: {ds}, snr: {dmu/ds}')
        res['diff'] = {'mu': dmu, 's': ds}

    return res
def inspect_Fnl(Fnl):
    """Plot the correction function Fnl as its deviation from identity.

    Inputs
    ------
    Fnl: non linear correction function lookup table, indexed by the
        9 bits input value

    Returns
    -------
    matplotlib figure
    """
    inputs = np.arange(2**9)

    fig = plt.figure(figsize=(6, 4))
    ax = fig.gca()
    ax.plot(inputs, Fnl - inputs)
    ax.set_xlabel('input value')
    ax.set_ylabel('output correction F(x)-x')
    ax.set_xlim([0, 511])

    return fig
def inspect_correction(params, gain=None):
    """Comparison plot of the different corrections.

    Columns are the raw data, the flat-field corrected data and the
    flat-field plus non-linear corrected data. Rows are the ratios 'n'/'0',
    'p'/'0' and 'n'/'p'. Each panel is a 2D histogram of the ratio against
    the reference intensity: non-saturated shots in blue, saturated shots
    in red. The panel is annotated with the direct and weighted SNR of the
    non-saturated shots.

    Inputs
    ------
    params: parameters
    gain: float, default None, DSSC gain in ph/bin

    Returns
    -------
    matplotlib figure
    """
    # load data
    assert params.arr is not None, "Data not loaded"
    assert params.arr_dark is not None, "Data not loaded"

    # we only need few rois
    fitrois = {k: params.rois[k] for k in ['n', '0', 'p', 'sat']}

    # flat field
    ff = compute_flat_field_correction(params.rois, params)

    xp = np if not params._using_gpu else cp

    # non linearities; the lookup table indexes the data blocks inside
    # process(), so it must live on the same device as the other lookup
    # tables passed below
    Fnl = params.get_Fnl()
    if Fnl is None:
        Fnl = np.arange(2**9)
    Fnl = xp.asarray(Fnl)

    # compute all levels of correction
    data = process(xp.arange(2**9), params.arr_dark, params.arr, params.tid,
                   fitrois, params.get_mask(), xp.ones_like(ff),
                   params.sat_level, params._using_gpu)
    data_ff = process(xp.arange(2**9), params.arr_dark, params.arr,
                      params.tid, fitrois, params.get_mask(), ff,
                      params.sat_level, params._using_gpu)
    data_ff_nl = process(Fnl, params.arr_dark, params.arr, params.tid,
                         fitrois, params.get_mask(), ff, params.sat_level,
                         params._using_gpu)

    # for conversion to nb of photons
    g = 1 if gain is None else gain
    scale = 1e-6

    f, axs = plt.subplots(3, 3, figsize=(8, 6), sharex=True)

    photon_scale = None
    for col, d in enumerate([data, data_ff, data_ff_nl]):
        for row, (n, r) in enumerate([('n', '0'), ('p', '0'), ('n', 'p')]):
            ax = axs[row, col]
            if photon_scale is None:
                # x bins, shared by all panels, from the raw '0' intensity
                upper = g*scale*np.percentile(d['0'].values.flatten(), 99.9)
                photon_scale = np.linspace(0, upper, 150)

            good_d = d.where(d['sat_sat'] == False, drop=True)  # noqa: E712
            sat_d = d.where(d['sat_sat'], drop=True)

            snr_v = snr(good_d[n].values.flatten(),
                        good_d[r].values.flatten(), verbose=True)
            m = snr_v['direct']['mu']
            ratio_bins = np.linspace(0.95, 1.05, 150)*m

            _, _, _, img = ax.hist2d(
                g*scale*good_d[r].values.flatten(),
                good_d[n].values.flatten()/good_d[r].values.flatten(),
                [photon_scale, ratio_bins],
                cmap='Blues',
                norm=LogNorm(vmin=0.2, vmax=200),
            )
            ax.hist2d(
                g*scale*sat_d[r].values.flatten(),
                sat_d[n].values.flatten()/sat_d[r].values.flatten(),
                [photon_scale, ratio_bins],
                cmap='Reds',
                norm=LogNorm(vmin=0.2, vmax=200),
            )

            v = snr_v['direct']['mu']/snr_v['direct']['s']
            ax.text(0.4, 0.15, f'SNR: {v:.0f}', transform=ax.transAxes)
            v = snr_v['weighted']['mu']/snr_v['weighted']['s']
            ax.text(0.4, 0.05, r'SNR$_\mathrm{w}$: ' + f'{v:.0f}',
                    transform=ax.transAxes)

            ax.set_ylim([0.95*m, 1.05*m])

    xlabel = 'photons (10$^6$)' if gain else 'ADU (10$^6$)'
    for col in range(3):
        axs[2, col].set_xlabel(xlabel)

    f.colorbar(img, ax=axs, label='events')

    axs[0, 0].set_title('raw')
    axs[0, 1].set_title('flat-field')
    axs[0, 2].set_title('non-linear')

    axs[0, 0].set_ylabel(r'-1$^\mathrm{st}$/0$^\mathrm{th}$ order')
    axs[1, 0].set_ylabel(r'1$^\mathrm{st}$/0$^\mathrm{th}$ order')
    axs[2, 0].set_ylabel(r'-1$^\mathrm{st}$/1$^\mathrm{st}$ order')

    return f
def inspect_correction_sk(params, ff, gain=None):
    """Comparison plot of the different corrections, combining 'n' and 'p'.

    Columns are the raw data, the flat-field corrected data and the
    flat-field plus non-linear corrected data. Each panel is a 2D histogram
    of 'np_mean_sk'/'0' against the '0' intensity: non-saturated shots in
    blue, saturated shots in red. The panel is annotated with the direct
    and weighted SNR of the non-saturated shots.

    Inputs
    ------
    params: parameters
    ff: 2D array, flat field correction
    gain: float, default None, DSSC gain in ph/bin

    Returns
    -------
    matplotlib figure
    """
    # load data
    assert params.arr is not None, "Data not loaded"
    assert params.arr_dark is not None, "Data not loaded"

    # we only need few rois
    fitrois = {k: params.rois[k] for k in ['n', '0', 'p', 'sat']}

    xp = np if not params._using_gpu else cp

    # non linearities; the lookup table indexes the data blocks inside
    # process(), so it must live on the same device as the other lookup
    # tables passed below
    Fnl = params.get_Fnl()
    if Fnl is None:
        Fnl = np.arange(2**9)
    Fnl = xp.asarray(Fnl)

    # compute all levels of correction
    data = process(xp.arange(2**9), params.arr_dark, params.arr, params.tid,
                   fitrois, params.get_mask(), xp.ones_like(ff),
                   params.sat_level, params._using_gpu)
    data_ff = process(xp.arange(2**9), params.arr_dark, params.arr,
                      params.tid, fitrois, params.get_mask(), ff,
                      params.sat_level, params._using_gpu)
    data_ff_nl = process(Fnl, params.arr_dark, params.arr, params.tid,
                         fitrois, params.get_mask(), ff, params.sat_level,
                         params._using_gpu)

    # for conversion to nb of photons
    g = 1 if gain is None else gain
    scale = 1e-6

    f, axs = plt.subplots(1, 3, figsize=(8, 2.5), sharex=True)

    photon_scale = None
    for k, d in enumerate([data, data_ff, data_ff_nl]):
        ax = axs[k]
        if photon_scale is None:
            # x bins, shared by all panels, from the raw '0' intensity
            upper = g*scale*np.percentile(d['0'].values.flatten(), 99.9)
            photon_scale = np.linspace(0, upper, 150)

        good_d = d.where(d['sat_sat'] == False, drop=True)  # noqa: E712
        sat_d = d.where(d['sat_sat'], drop=True)

        snr_v = snr(good_d['np_mean_sk'].values.flatten(),
                    good_d['0'].values.flatten(), verbose=True)
        m = snr_v['direct']['mu']
        ratio_bins = np.linspace(0.95, 1.05, 150)*m

        _, _, _, img = ax.hist2d(
            g*scale*good_d['0'].values.flatten(),
            good_d['np_mean_sk'].values.flatten()/good_d['0'].values.flatten(),
            [photon_scale, ratio_bins],
            cmap='Blues',
            norm=LogNorm(vmin=0.2, vmax=200),
        )
        ax.hist2d(
            g*scale*sat_d['0'].values.flatten(),
            sat_d['np_mean_sk'].values.flatten()/sat_d['0'].values.flatten(),
            [photon_scale, ratio_bins],
            cmap='Reds',
            norm=LogNorm(vmin=0.2, vmax=200),
        )

        v = snr_v['direct']['mu']/snr_v['direct']['s']
        ax.text(0.4, 0.15, f'SNR: {v:.0f}', transform=ax.transAxes)
        v = snr_v['weighted']['mu']/snr_v['weighted']['s']
        ax.text(0.4, 0.05, r'SNR$_\mathrm{w}$: ' + f'{v:.0f}',
                transform=ax.transAxes)

        ax.set_ylim([0.95*m, 1.05*m])

    xlabel = 'photons (10$^6$)' if gain else 'ADU (10$^6$)'
    for k in range(3):
        axs[k].set_xlabel(xlabel)

    f.colorbar(img, ax=axs, label='events')

    axs[0].set_title('raw')
    axs[1].set_title('flat-field')
    axs[2].set_title('non-linear')

    axs[0].set_ylabel(r'np_mean/0')

    return f
# data processing related functions
def load_dssc_module(proposalNB, runNB, moduleNB=15, subset=slice(None),
                     drop_intra_darks=True, persist=False, data_size_Gb=None):
    """Load single module dssc data as dask array.

    Inputs
    ------
    proposalNB: proposal number
    runNB: run number
    moduleNB: default 15, module number
    subset: default slice(None), subset of trains to load
    drop_intra_darks: boolean, default True, remove intra darks from the
        data (every second frame)
    persist: default False, load all data persistently in memory
    data_size_Gb: float, if persist is True, can optionally restrict the
        amount of data loaded for dark data and run data in Gb

    Returns
    -------
    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
    tid: array of train id number
    """
    run = open_run(proposal=proposalNB, run=runNB)

    # DSSC
    source = f'SCS_DET_DSSC1M-1/DET/{moduleNB}CH0:xtdf'
    key = 'image.data'

    arr = run[source, key][subset].dask_array()
    # fix 256 value becoming spuriously 0 instead
    arr[arr == 0] = 256

    ppt = run[source, key][subset].data_counts()
    # ignore train with no pulses, can happen in burst mode acquisition
    ppt = ppt[ppt > 0]
    tid = ppt.index.values

    ppt = np.unique(ppt)
    assert ppt.shape[0] == 1, "number of pulses changed during the run"
    ppt = ppt[0]

    # reshape in trainId, pulseId, 2d-image
    arr = arr.reshape(-1, ppt, arr.shape[2], arr.shape[3])

    # drop intra darks
    if drop_intra_darks:
        arr = arr[:, ::2, :, :]

    # load data in memory
    if persist:
        if data_size_Gb is not None:
            # keep only xGb of data; the size of one train is taken from
            # the array itself instead of assuming a 128x512 uint16 image
            train_bytes = (arr.shape[1]*arr.shape[2]*arr.shape[3]
                           * arr.dtype.itemsize)
            N = int(data_size_Gb*1024**3/train_bytes)
            arr = arr[:N]
            tid = tid[:N]
        arr = arr.persist()

    return arr, tid
def average_module(arr, dark=None, ret='mean', mask=None, sat_roi=None,
                   sat_level=300, F_INL=None):
    """Compute the average or std over a module.

    Inputs
    ------
    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
    dark: default None, dark to be subtracted
    ret: string, either 'mean' to compute the mean or 'std' to compute the
        standard deviation
    mask: default None, mask of bad pixels to ignore
    sat_roi: roi over which to check for pixel with values larger than
        sat_level to drop the image from the average. It is only used for
        ret='mean'; the std includes all images.
    sat_level: int, minimum pixel value for a pixel to be considered
        saturated
    F_INL: default None, non linear correction function given as a
        lookup table with 9 bits integer input

    Returns
    -------
    average or standard deviation image, as a lazy dask array

    Raises
    ------
    ValueError: if ret is neither 'mean' nor 'std'
    """
    # F_INL
    if F_INL is not None:
        narr = arr.map_blocks(lambda x: F_INL[x], dtype=F_INL.dtype)
    else:
        narr = arr

    if mask is not None:
        narr = narr*mask

    if sat_roi is not None:
        # True for images with no pixel of sat_roi at or above sat_level,
        # checked before dark subtraction
        not_sat = da.all(
            narr[:, :,
                 sat_roi["yl"]:sat_roi["yh"],
                 sat_roi["xl"]:sat_roi["xh"]] < sat_level,
            axis=(2, 3),
            keepdims=True,
        )
        # da.average needs weights of the full data shape: broadcast the
        # per-image flag over the actual image size (not a fixed 128x512)
        not_sat = da.broadcast_to(not_sat, narr.shape, chunks=narr.chunks)

    if dark is not None:
        narr = narr - dark

    if ret == 'mean':
        if sat_roi is not None:
            return da.average(narr, axis=0, weights=not_sat)
        else:
            return narr.mean(axis=0)
    elif ret == 'std':
        return narr.std(axis=0)
    else:
        raise ValueError(f'ret={ret} not supported')
def process_module(arr, tid, dark, rois, mask=None, sat_level=511,
                   flat_field=None, F_INL=None, use_gpu=False):
    """Process one module and extract roi intensity.

    Inputs
    ------
    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
    tid: array of train id number
    dark: pulse resolved dark image to remove
    rois: dictionary of rois
    mask: default None, mask of ignored pixels
    sat_level: integer, default 511, at which level pixel begin to saturate
    flat_field: default None, flat-field correction
    F_INL: default None, non-linear correction function given as a
        lookup table with 9 bits integer input
    use_gpu: boolean, default False, move the flat field to the GPU

    Returns
    -------
    dataset of extracted pulse and train resolved roi intensities. For each
    roi name r it contains:
    - r: the dark- and flat-field-corrected roi sum
    - r + '_sat': whether any roi pixel reached sat_level
    - r + '_area': the roi area in pixels
    When both 'n' and 'p' rois are given, it also contains:
    - 'np_mean_sk': the sum of (n + p)/(ff_n + ff_p)
    - 'np_mean_sk_sat'
    - 'np_mean_area'
    """
    # F_INL
    if F_INL is not None:
        narr = arr.map_blocks(lambda x: F_INL[x], dtype=F_INL.dtype)
    else:
        narr = arr

    # apply mask
    if mask is not None:
        narr = narr*mask

    def _crop(img, roi):
        """Crop the last two (image) dimensions of img to roi."""
        return img[..., roi['yl']:roi['yh'], roi['xl']:roi['xh']]

    # crop rois
    r = {}
    rd = {}
    for n, roi in rois.items():
        r[n] = _crop(narr, roi)
        rd[n] = _crop(dark, roi)

    # find saturated shots
    r_sat = {n: da.any(r[n] >= sat_level, axis=(2, 3)) for n in rois}

    # TODO: flat-field should not be applied on intra darks
    if use_gpu and flat_field is not None:
        flat_field = cp.asarray(flat_field)

    # compute dark corrected ROI values
    v = {}
    ff = {}
    for n, roi in rois.items():
        r[n] = r[n] - rd[n]
        if flat_field is not None:
            ff[n] = _crop(flat_field, roi)
            r_ff = r[n]/ff[n]
        else:
            ff[n] = 1.0
            r_ff = r[n]
        v[n] = r_ff.sum(axis=(2, 3))

    # the combined 'n'+'p' reference only exists when both rois are given
    has_np = 'n' in rois and 'p' in rois
    if has_np:
        # np_mean roi where we normalize the sum of flat_field
        np_mean = (r['n'] + r['p'])/(ff['n'] + ff['p'])
        v['np_mean_sk'] = np_mean.sum(axis=(2, 3))

    res = xr.Dataset()

    dims = ['trainId', 'pulseId']
    r_coords = {'trainId': tid, 'pulseId': np.arange(0, narr.shape[1])}
    for n in rois:
        res[n] = xr.DataArray(ensure_on_host(v[n]), coords=r_coords,
                              dims=dims)
        res[n + '_sat'] = xr.DataArray(ensure_on_host(r_sat[n]),
                                       coords=r_coords, dims=dims)

    if has_np:
        res['np_mean_sk'] = xr.DataArray(ensure_on_host(v['np_mean_sk']),
                                         coords=r_coords, dims=dims)
        # saturated if either order is saturated
        res['np_mean_sk_sat'] = res['n_sat'] | res['p_sat']

    for n, roi in rois.items():
        res[n + '_area'] = xr.DataArray(np.array([
            (roi['yh'] - roi['yl'])*(roi['xh'] - roi['xl'])]))

    if has_np:
        res['np_mean_area'] = res['n_area'] + res['p_area']

    return res
def process(Fmodel, arr_dark, arr, tid, rois, mask, flat_field,
            sat_level=511, use_gpu=False):
    """Process dark and run data with corrections.

    The same correction lookup table is applied to the dark (averaged over
    trains) and to the data before the roi intensities are extracted.

    Inputs
    ------
    Fmodel: correction lookup table, or None for no correction
    arr_dark: dark data
    arr: data
    tid: train id of arr data
    rois: ['n', '0', 'p', 'sat'] rois
    mask: mask of good pixels
    flat_field: zone plate flat-field correction
    sat_level: integer, default 511, at which level pixel begin to saturate
    use_gpu: boolean, default False, forwarded to process_module

    Returns
    -------
    roi extracted intensities, computed (not lazy)
    """
    dark = average_module(arr_dark, F_INL=Fmodel).compute()

    extracted = process_module(arr, tid, dark, rois, mask,
                               sat_level=sat_level, flat_field=flat_field,
                               F_INL=Fmodel, use_gpu=use_gpu)
    return extracted.compute()
def inspect_saturation(data, gain, Nbins=200):
    """Plot roi integrated histogram of the data with saturation.

    For each of the 'n', '0' and 'p' rois, the density of non-saturated
    shots is drawn as a line. The contribution of saturated shots is
    stacked on top of it as a shaded area. The density is normalized on
    all shots.

    Inputs
    ------
    data: xarray of roi integrated DSSC data
    gain: nominal DSSC gain in ph/bin
    Nbins: number of bins for the histogram, by default 200

    Returns
    -------
    f: handle to the matplotlib figure
    h: xarray of the histogram data
    """
    d = data.where(data['sat_sat'] == False, drop=True)  # noqa: E712
    s = data.where(data['sat_sat'] == True, drop=True)  # noqa: E712

    # percentage of saturated shots
    N_nonsat = d['n'].count()
    # Dataset.sizes maps dimension name -> length; indexing Dataset.dims by
    # name is deprecated in recent xarray versions
    N_all = data.sizes['trainId'] * data.sizes['pulseId']
    sat_percent = ((N_all - N_nonsat)/N_all).values*100.0

    # find the bin ranges (always including 0)
    sum_v = {}
    low = 0
    high = 0
    scale = 1e-6
    for k in ['n', '0', 'p']:
        v = data[k].values.ravel()*scale*gain
        sum_v[k] = np.nansum(v)
        low = min(low, np.nanmin(v))
        high = max(high, np.nanmax(v))

    # compute bins edges, center and width
    bins = np.linspace(low, high, Nbins+1)
    bins_c = 0.5*(bins[:-1] + bins[1:])
    w = bins[1] - bins[0]

    fig, ax = plt.subplots(figsize=(6, 4))
    h = {}
    for kk, k in enumerate(['n', '0', 'p']):
        v_d = d[k].values.ravel()*scale*gain
        v_s = s[k].values.ravel()*scale*gain
        h[k+'_nosat'], _ = np.histogram(v_d, bins)
        h[k+'_sat'], _ = np.histogram(v_s, bins)

        # compute density normalization on all data
        norm = w*(np.sum(h[k+'_nosat']) + np.sum(h[k+'_sat']))

        ax.fill_between(bins_c,
                        h[k+'_sat']/norm + h[k+'_nosat']/norm,
                        h[k+'_nosat']/norm,
                        facecolor=f"C{kk}", edgecolor='none', alpha=0.2)
        ax.plot(bins_c, h[k+'_nosat']/norm, label=k,
                c=f'C{kk}', alpha=0.4)

    ax.text(0.6, 0.9, f"saturation: {sat_percent:.2f}%",
            color='r', alpha=0.5, transform=ax.transAxes)
    ax.legend()
    ax.set_xlabel(r'10$^6$ ph')
    ax.set_ylabel('density')

    # save data as xarray dataset
    dv = {}
    for k in h.keys():
        dv[k] = {"dims": "N", "data": h[k]}

    ds = {
        "coords": {"N": {"dims": "N", "data": bins_c,
                         "attrs": {"units": f"{scale:g} ph"}}},
        "attrs": {"saturation (%)": sat_percent},
        "dims": "N",
        "data_vars": dv}

    return fig, xr.Dataset.from_dict(ds)