"""
Beam splitting Off-axis Zone plate analysis routines.
Copyright (2021, 2022, 2023, 2024) SCS Team.
"""
import time
import datetime
import json
import warnings
import numpy as np
import xarray as xr
import dask.array as da
from scipy.optimize import minimize
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib import cm
from matplotlib.patches import Polygon
from extra_data import open_run
from extra_geom import DSSC_1MGeometry
from toolbox_scs.routines.XAS import xas
try:
import cupy as cp
_can_use_gpu = True
except ModuleNotFoundError:
_can_use_gpu = False
print('Cupy is not installed in this environment, no access to the GPU')
except ImportError:
_can_use_gpu = False
print('Not currently running on a GPU node')
# Public API of the BOZ analysis module. inspect_ff_fitting is exported since
# the deprecated inspect_plane_fitting points users to it.
__all__ = [
    'parameters',
    'get_roi_pixel_pos',
    'bad_pixel_map',
    'inspect_dark',
    'histogram_module',
    'inspect_histogram',
    'find_rois',
    'find_rois_from_params',
    'inspect_rois',
    'compute_flat_field_correction',
    'inspect_flat_field_domain',
    'inspect_plane_fitting',
    'inspect_ff_fitting',
    'plane_fitting_domain',
    'plane_fitting',
    'ff_refine_crit',
    'ff_refine_fit',
    'nl_domain',
    'nl_lut',
    'nl_crit',
    'nl_crit_sk',
    'nl_fit',
    'inspect_nl_fit',
    'snr',
    'inspect_Fnl',
    'inspect_correction',
    'inspect_correction_sk',
    'load_dssc_module',
    'average_module',
    'process_module',
    'process',
    'inspect_saturation',
]
class parameters():
    """Parameters contains all input parameters for the BOZ corrections.

    This is used in beam splitting off-axis zone plate spectroscopy analysis
    as well as during the determination of the correction parameters
    themselves, to ensure they can be reproduced.

    Inputs
    ------
    proposal: int, proposal number
    darkrun: int, run number for the dark run
    run: int, run number for the data run
    module: int, DSSC module number
    gain: float, number of ph per bin
    drop_intra_darks: drop every second DSSC frame
    """

    def __init__(self, proposal, darkrun, run, module, gain,
                 drop_intra_darks=True):
        self.proposal = proposal
        self.darkrun = darkrun
        self.run = run
        self.module = module
        self.pixel_pos = _get_pixel_pos(self.module)
        self.gain = gain
        self.drop_intra_darks = drop_intra_darks
        self.mask = None
        self.mask_idx = None
        self.mean_th = (None, None)
        self.std_th = (None, None)
        self.rois = None
        self.rois_th = None
        self.ff_type = 'plane'
        self.flat_field = None
        # np.PINF / np.NINF were removed in NumPy 2.0; np.inf works with all
        # NumPy versions
        self.flat_field_prod_th = (5.0, np.inf)
        self.flat_field_ratio_th = (-np.inf, 1.2)
        self.plane_guess_fit = None
        self.use_hex = False
        self.force_mirror = True
        self.ff_alpha = None
        self.ff_max_iter = None
        self._using_gpu = False
        self.Fnl = None
        self.nl_alpha = None
        self.sat_level = None
        self.nl_max_iter = None

        # temporary data
        self.arr_dark = None
        self.tid_dark = None
        self.arr = None
        self.tid = None

    def dask_load_persistently(self, dark_data_size_Gb=None,
                               data_size_Gb=None):
        """Load dask data array in memory.

        Inputs
        ------
        dark_data_size_Gb: float, optional size of dark to load in memory,
            in Gb
        data_size_Gb: float, optional size of data to load in memory, in Gb
        """
        self.arr_dark, self.tid_dark = load_dssc_module(
            self.proposal, self.darkrun, self.module,
            drop_intra_darks=self.drop_intra_darks,
            persist=True, data_size_Gb=dark_data_size_Gb)
        self.arr, self.tid = load_dssc_module(
            self.proposal, self.run, self.module,
            drop_intra_darks=self.drop_intra_darks,
            persist=True, data_size_Gb=data_size_Gb)

        # make sure to rechunk the arrays
        self.arr = self.arr.rechunk(('auto', -1, -1, -1))
        self.arr_dark = self.arr_dark.rechunk(('auto', -1, -1, -1))

    def use_gpu(self):
        """Move the mask and the loaded data and dark arrays to the GPU.

        Prints a warning when the device memory reported by
        ``mem_info[1]`` is below 30 GB. Returns early, after that check, if
        the data were already moved to the GPU.
        """
        assert _can_use_gpu, 'Failed to import cupy'
        gpu_mem_gb = cp.cuda.Device().mem_info[1] / 2**30
        if gpu_mem_gb < 30:
            print(f'Warning: GPU memory ({gpu_mem_gb}GB) may be insufficient')
        if self._using_gpu:
            return
        assert (
            self.arr is not None and
            self.arr_dark is not None
        ), "Must load data before switching to GPU"
        if self.mask is not None:
            self.mask = cp.array(self.mask)
        # moving full data to GPU
        limit = 2**30
        self.arr = da.array(
            cp.array(self.arr.compute())
        ).rechunk(('auto', -1, -1, -1), block_size_limit=limit)
        self.arr_dark = da.array(
            cp.array(self.arr_dark.compute())
        ).rechunk(('auto', -1, -1, -1), block_size_limit=limit)
        self._using_gpu = True

    def set_mask(self, arr):
        """Set mask of bad pixels.

        Inputs
        ------
        arr: either a boolean array of a DSSC module image (False for bad
            pixels), a list of bad pixel indices, or None to clear the mask
        """
        if arr is None:
            # e.g. when loading parameters saved before a mask was defined;
            # keep mask and mask_idx consistent
            self.mask_idx = None
            self.mask = None
            return
        if type(arr) is not list:
            self.mask_idx = np.argwhere(arr == False).tolist()  # noqa: E712
            self.mask = arr
        else:
            self.mask_idx = arr
            mask = np.ones((128, 512), dtype=bool)
            for k in self.mask_idx:
                mask[k[0], k[1]] = False
            self.mask = mask
        if self._using_gpu:
            self.mask = cp.array(self.mask)

    def get_mask(self):
        """Get the boolean array bad pixel of a DSSC module."""
        return self.mask

    def get_mask_idx(self):
        """Get the list of bad pixel indices."""
        return self.mask_idx

    def flat_field_guess(self, guess=None):
        """Set the flat-field guess parameter for the fit and returns it.

        Inputs
        ------
        guess: a list of 8 floats, the 4 first to define the plane
            ax+by+cz+d=0 for 'n' beam and the 4 last for the 'p' beam
            in case mirror symmetry is disabled
        """
        if guess is not None:
            self.plane_guess_fit = guess
            return self.plane_guess_fit

        if self.plane_guess_fit is None:
            if self.use_hex:
                self.plane_guess_fit = [
                    -20, 0.0, 1.5, -0.5, 20, 0, 1.5, -0.5]
            else:
                self.plane_guess_fit = [
                    -0.2, -0.1, 1, -0.54, 0.2, -0.1, 1, -0.54]

        return self.plane_guess_fit

    def set_flat_field(self, ff_params, ff_type='plane',
                       prod_th=None, ratio_th=None):
        """Set the flat-field plane definition.

        Inputs
        ------
        ff_params: list or array of parameters, or None to clear them
        ff_type: string identifying the type of flat field normalization,
            default is 'plane'.
        prod_th, ratio_th: optional (low, high) thresholds on product and
            ratio, kept unchanged when None
        """
        self.ff_type = ff_type
        if ff_params is None:
            # parameters saved before any flat field was determined
            self.flat_field = None
        elif isinstance(ff_params, list):
            self.flat_field = ff_params
        elif hasattr(ff_params, 'tolist'):
            self.flat_field = ff_params.tolist()
        else:
            self.flat_field = list(ff_params)
        if prod_th is not None:
            self.flat_field_prod_th = prod_th
        if ratio_th is not None:
            self.flat_field_ratio_th = ratio_th

    def get_flat_field(self):
        """Get the flat-field plane definition, or None if not set."""
        if self.flat_field is None:
            return None
        else:
            return np.array(self.flat_field)

    def set_Fnl(self, Fnl):
        """Set the non-linear correction function (None clears it)."""
        if Fnl is None:
            # parameters saved before the non-linear fit was done
            self.Fnl = None
        elif isinstance(Fnl, list):
            self.Fnl = Fnl
        else:
            self.Fnl = Fnl.tolist()

    def get_Fnl(self):
        """Get the non-linear correction function, on the GPU if in use."""
        if self.Fnl is None:
            return None
        else:
            if self._using_gpu:
                return cp.array(self.Fnl)
            else:
                return np.array(self.Fnl)

    def save(self, path='./'):
        """Save the parameters as a JSON file.

        Inputs
        ------
        path: str, directory where to save the file, default to './'
        """
        import os

        v = {}
        v['proposal'] = self.proposal
        v['darkrun'] = self.darkrun
        v['run'] = self.run
        v['module'] = self.module
        v['gain'] = self.gain
        v['drop_intra_darks'] = self.drop_intra_darks

        v['mask'] = self.mask_idx
        v['mean_th'] = self.mean_th
        v['std_th'] = self.std_th

        v['rois'] = self.rois
        v['rois_th'] = self.rois_th

        v['ff_type'] = self.ff_type
        v['flat_field'] = self.flat_field
        v['flat_field_prod_th'] = self.flat_field_prod_th
        v['flat_field_ratio_th'] = self.flat_field_ratio_th
        v['plane_guess_fit'] = self.plane_guess_fit
        v['use_hex'] = self.use_hex
        v['force_mirror'] = self.force_mirror
        v['ff_alpha'] = self.ff_alpha
        v['ff_max_iter'] = self.ff_max_iter

        v['Fnl'] = self.Fnl
        v['nl_alpha'] = self.nl_alpha
        v['sat_level'] = self.sat_level
        v['nl_max_iter'] = self.nl_max_iter

        fname = f'parameters_p{self.proposal}_d{self.darkrun}_r{self.run}.json'
        # os.path.join inserts the separator when `path` has no trailing one
        full_path = os.path.join(path, fname)

        with open(full_path, 'w') as f:
            json.dump(v, f)
        print(full_path)

    @classmethod
    def load(cls, fname):
        """Load parameters from a JSON file.

        Inputs
        ------
        fname: string, name of the JSON file to load
        """
        with open(fname, 'r') as f:
            v = json.load(f)
        c = cls(v['proposal'], v['darkrun'], v['run'], v['module'], v['gain'],
                v['drop_intra_darks'])

        c.mean_th = v['mean_th']
        c.std_th = v['std_th']
        c.set_mask(v['mask'])
        c.rois = v['rois']
        c.rois_th = v['rois_th']
        # files written before polyline support have no 'ff_type'
        ff_type = v.get('ff_type', 'plane')
        c.set_flat_field(v['flat_field'], ff_type,
                         v['flat_field_prod_th'], v['flat_field_ratio_th'])
        c.plane_guess_fit = v['plane_guess_fit']
        c.use_hex = v['use_hex']
        c.force_mirror = v['force_mirror']
        c.ff_alpha = v['ff_alpha']
        c.ff_max_iter = v['ff_max_iter']

        c.set_Fnl(v['Fnl'])
        c.nl_alpha = v['nl_alpha']
        c.sat_level = v['sat_level']
        c.nl_max_iter = v['nl_max_iter']

        return c

    def __str__(self):
        f = f'proposal:{self.proposal} darkrun:{self.darkrun} run:{self.run}'
        f += f' module:{self.module} gain:{self.gain} ph/bin\n'
        f += f'drop intra darks:{self.drop_intra_darks}\n'

        if self.mask_idx is not None:
            f += f'mean threshold:{self.mean_th} std threshold:{self.std_th}\n'
            f += f'mask:(#{len(self.mask_idx)}) {self.mask_idx}\n'
        else:
            f += 'mask:None\n'

        f += f'rois threshold: {self.rois_th}\n'
        f += f'rois: {self.rois}\n'

        f += f'flat-field type: {self.ff_type}\n'
        f += f'flat-field p: {self.flat_field} '
        f += f'prod:{self.flat_field_prod_th} '
        f += f'ratio:{self.flat_field_ratio_th}\n'
        f += f'plane guess fit: {self.plane_guess_fit}\n'
        f += f'use hexagons: {self.use_hex}\n'
        f += f'enforce mirror symmetry: {self.force_mirror}\n'
        f += f'ff alpha: {self.ff_alpha}, max. iter.: {self.ff_max_iter}\n'

        if self.Fnl is not None:
            f += f'dFnl: {np.array(self.Fnl) - np.arange(2**9)}\n'
            f += f'nl alpha:{self.nl_alpha}, sat. level:{self.sat_level}, '
            f += f' nl max. iter.:{self.nl_max_iter}'
        else:
            f += 'Fnl: None'

        return f
def ensure_on_host(arr):
    """Return ``arr`` with its data located in host (CPU) memory.

    Objects exposing ``__cuda_array_interface__`` (e.g. CuPy arrays) are
    copied back with ``.get()``; dask arrays get the same conversion mapped
    lazily over their blocks; anything else is returned untouched.
    """
    # duck-typing on the CUDA array interface avoids importing CuPy here
    is_device_array = hasattr(arr, "__cuda_array_interface__")
    if is_device_array:
        return arr.get()
    if isinstance(arr, da.Array):
        return arr.map_blocks(ensure_on_host)
    return arr
# Hexagonal pixel related functions
def _get_pixel_pos(module):
    """Return the pixel positions of a DSSC module on the hexagonal lattice.

    The geometry is built from fixed dummy quadrant positions; only the
    first two coordinates of each pixel position are returned.
    """
    quad_positions = [(-130, 5), (-130, -125), (5, -125), (5, 5)]
    geometry = DSSC_1MGeometry.from_quad_positions(quad_positions)
    all_positions = geometry.get_pixel_positions()
    # keep the X, Y components of the requested module only
    return all_positions[module][:, :, :2]
def get_roi_pixel_pos(roi, params):
    """Compute fake or real pixel position of an roi from roi center.

    With ``params.use_hex`` the hexagonal-lattice positions stored in
    ``params.pixel_pos`` are used, giving 2-d arrays over the roi. Otherwise
    a regular grid with 0.01 spacing is built: X is a 1-d row and Y a column
    vector, so that they broadcast to the roi shape. In both cases the
    positions are shifted so that their mean is at (0, 0).

    Inputs:
    -------
    roi: dictionnary with 'xl', 'xh', 'yl', 'yh' keys
    params: parameters

    Returns:
    --------
    X, Y: arrays of pixel position.
    """
    if params.use_hex:
        # DSSC pixel position on hexagonal lattice
        X = params.pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], 0]
        Y = params.pixel_pos[roi['yl']:roi['yh'], roi['xl']:roi['xh'], 1]
    else:
        nY, nX = roi['yh'] - roi['yl'], roi['xh'] - roi['xl']
        X = np.arange(nX)/100
        Y = np.arange(nY)[:, np.newaxis]/100

    # center of ROI is put to 0,0. Subtract out of place: in the hexagonal
    # case X and Y are views into params.pixel_pos and an in-place update
    # would corrupt the stored lattice positions.
    X = X - np.mean(X)
    Y = Y - np.mean(Y)

    return X, Y
def _get_pixel_corners(module):
    """Return the pixel corner coordinates of one DSSC module.

    The distortion array stores the corners in (z, y, x) order; z is
    dropped and the two remaining components are reversed to (x, y).
    """
    quad_positions = [(-130, 5), (-130, -125), (5, -125), (5, 5)]
    geometry = DSSC_1MGeometry.from_quad_positions(quad_positions)
    distortion = geometry.to_distortion_array(allow_negative_xy=True)
    # each module occupies 128 consecutive rows of the distortion array
    first_row, end_row = module * 128, (module + 1) * 128
    yx_corners = distortion[first_row:end_row, :, :, 1:]
    return yx_corners[:, :, :, ::-1]
def _get_pixel_hexagons(module):
    """Compute DSSC pixel hexagons for plotting.

    Parameters:
    -----------
    module: either an integer (Python or NumPy) module number, or an array of
        corners indexed as [y, x, corner, coordinate] to get hexagons from

    Returns:
    --------
    a 1-d list of hexagons (row-major over y then x) where corners position
    are in mm
    """
    # accept NumPy integers too, e.g. module numbers taken from an array
    if isinstance(module, (int, np.integer)):
        corners = _get_pixel_corners(module)
    else:
        corners = module

    # 1e3 factor: convert corner positions to mm
    return [Polygon(1e3*corners[y, x, :, :])
            for y in range(corners.shape[0])
            for x in range(corners.shape[1])]
def _add_colorbar(im, ax, loc='right', size='5%', pad=0.05):
    """Attach a colorbar in a new axes sized to match the plotted image.

    Inputs
    ------
    im: image plotted
    ax: axes on which the image was plotted
    loc: string, default 'right', location of the colorbar
    size: string, default '5%', proportion of the colorbar with respect to
        the plotted image
    pad: float, default 0.05, pad width between plot and colorbar

    Returns
    -------
    the colorbar created on the figure owning ``ax``
    """
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    colorbar_ax = make_axes_locatable(ax).append_axes(loc, size=size, pad=pad)
    return ax.figure.colorbar(im, cax=colorbar_ax)
# dark related functions
def bad_pixel_map(params):
    """Compute the bad pixels map.

    A pixel is kept (True) when its dark mean and dark standard deviation,
    taken over the first two axes of ``params.arr_dark``, lie within
    ``params.mean_th`` and ``params.std_th``. A None threshold is not
    applied.

    Inputs
    ------
    params: parameters

    Returns
    -------
    boolean pixel map, True for good pixels and False for bad ones
    """
    assert params.arr_dark is not None, "Data not loaded"

    # compute mean and std
    dark_mean = params.arr_dark.mean(axis=(0, 1)).compute()
    dark_std = params.arr_dark.std(axis=(0, 1)).compute()

    mean_low, mean_high = params.mean_th
    std_low, std_high = params.std_th

    mask = np.ones_like(dark_mean)
    if mean_low is not None:
        mask *= dark_mean >= mean_low
    if mean_high is not None:
        mask *= dark_mean <= mean_high
    if std_low is not None:
        mask *= dark_std >= std_low
    if std_high is not None:
        # upper bound: pixels noisier than std_high are bad
        mask *= dark_std <= std_high

    print(f'# bad pixel: {int(mask.size - mask.sum())}')

    return mask.astype(bool)
def inspect_dark(arr, mean_th=(None, None), std_th=(None, None)):
    """Inspect dark run data and plot diagnostic.

    The dark mean and standard deviation over the first two axes are shown
    as images (color range from the 2nd to the 98th percentile) next to
    their histograms on a log scale.

    Inputs
    ------
    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
    mean_th: tuple of thresholds (low, high), default (None, None), drawn as
        dashed vertical lines on the dark mean histogram
    std_th: tuple of thresholds (low, high), default (None, None), drawn as
        dashed vertical lines on the dark std histogram

    Returns
    -------
    fig: matplotlib figure
    """
    dark_mean = ensure_on_host(arr.mean(axis=(0, 1)).compute())
    dark_std = ensure_on_host(arr.std(axis=(0, 1)).compute())

    fig = plt.figure(figsize=(7, 2.7))
    gs = fig.add_gridspec(2, 4)

    # create every axes first: image on the 3 right columns, histogram on
    # the left column, one row per statistic
    row_axes = []
    for row in range(2):
        ax_img = fig.add_subplot(gs[row, 1:])
        ax_img.set_xticklabels([])
        ax_img.set_yticklabels([])
        ax_hist = fig.add_subplot(gs[row, 0])
        row_axes.append((ax_img, ax_hist))

    # nbins None means one bin per unit over the histogram range
    panels = [
        (dark_mean, mean_th, 'dark mean', None),
        (dark_std, std_th, 'dark std', 50),
    ]
    for (ax_img, ax_hist), (image, thresholds, label, nbins) in zip(
            row_axes, panels):
        values = image.flatten()
        low = np.percentile(values, 2)
        high = np.percentile(values, 98)

        mesh = ax_img.pcolormesh(image, vmin=low, vmax=high)
        ax_img.invert_yaxis()
        ax_img.set_aspect('equal')
        cbar = _add_colorbar(mesh, ax=ax_img, size='2%')
        cbar.ax.set_ylabel(label)

        if nbins is None:
            nbins = int(high*2 - low/2 + 1)
        ax_hist.hist(values, bins=nbins, range=(low/2, high*2))
        for th in (thresholds[0], thresholds[1]):
            if th is not None:
                ax_hist.axvline(th, c='k', alpha=0.5, ls='--')
        ax_hist.set_yscale('log')

    return fig
# histogram related functions
def histogram_module(arr, mask=None):
    """Compute a histogram of the 9 bits raw pixel values over a module.

    Inputs
    ------
    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
    mask: optional bad pixel mask, used as per-pixel weights

    Returns
    -------
    histogram with at least 512 bins (one per raw value)
    """
    values = arr.ravel()
    if mask is None:
        return da.bincount(values, minlength=512).compute()

    # replicate the 2-d module mask over the trainId and pulseId axes so
    # every recorded value is weighted by its pixel mask entry
    n_trains, n_pulses = arr.shape[0], arr.shape[1]
    weights = da.array(mask[None, None, :, :])
    weights = da.repeat(weights, n_pulses, axis=1)
    weights = da.repeat(weights, n_trains, axis=0)
    weights = weights.rechunk(arr.chunks)
    return da.bincount(values, weights.ravel(), minlength=512).compute()
def inspect_histogram(arr, arr_dark=None, mask=None, extra_lines=False):
    """Compute and plot a histogram of the 9 bits raw pixel values.

    Each histogram is normalized by its total count and shown on a log
    scale.

    Inputs
    ------
    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
    arr_dark: optional dask array of reshaped dssc dark data
        (trainId, pulseId, x, y)
    mask: optional bad pixel mask
    extra_lines: boolean, default False, plot extra lines at period values

    Returns
    -------
    (h, hd): histogram of arr, arr_dark (hd is None without arr_dark)
    figure
    """
    from matplotlib.ticker import MultipleLocator

    f = plt.figure(figsize=(6, 3))
    ax = plt.gca()
    raw_values = np.arange(2**9)
    line_style = dict(marker='o', ms=3, markerfacecolor='none', lw=1)

    h = ensure_on_host(histogram_module(arr, mask=mask))
    ax.plot(raw_values, h/np.sum(h), **line_style)

    hd = None
    if arr_dark is not None:
        hd = ensure_on_host(histogram_module(arr_dark, mask=mask))
        ax.plot(raw_values, hd/np.sum(hd), c='k', alpha=.5, **line_style)

    if extra_lines:
        # periodic markers: every 8 from 2 (black), every 16 from 3 (green),
        # every 32 from 7 (red), plus a single line at 271
        for k in range(50, 271):
            if (k - 2) % 8 == 0:
                ax.axvline(k, c='k', alpha=0.5, ls='--')
            if (k - 3) % 16 == 0:
                ax.axvline(k, c='g', alpha=0.3, ls='--')
            if (k - 7) % 32 == 0:
                ax.axvline(k, c='r', alpha=0.3, ls='--')
        ax.axvline(271, c='C1', alpha=0.5, ls='--')

    ax.set_xlim([0, 2**9-1])
    ax.set_yscale('log')
    ax.xaxis.set_minor_locator(MultipleLocator(10))
    ax.set_xlabel('DSSC pixel value')
    ax.set_ylabel('count frequency')

    return (h, hd), f
# ROI related functions
def find_rois(data_mean, threshold, extended=False):
    """Find rois from 3 beams configuration.

    The horizontal projection over the first 256 columns and the vertical
    projection of ``data_mean`` are normalized to their maximum. They are
    then thresholded to locate the 'n', '0' and 'p' beams (left to right).

    Inputs
    ------
    data_mean: dark corrected average image
    threshold: threshold value to find beams, relative to the projection
        maximum
    extended: boolean, True to define additional ASICS based rois:
        baseline rois b0-b3 above and b8-b11 below the beams (one per
        64-pixel wide ASIC) and the beam rois split at ASIC boundaries

    Returns
    -------
    rois: dictionnary of rois
    """
    # compute vertical and horizontal projection
    pX = data_mean.mean(axis=0)
    pX = pX[:256]  # half the ladder since there is a gap in the middle
    pY = data_mean.mean(axis=1)

    pX = pX/np.max(pX)
    pY = pY/np.max(pY)

    # along X. Clamp at 0: a beam touching the edge would otherwise give -1,
    # which wraps around when used as a slice start.
    lowX = max(0, int(np.argmax(pX > threshold) - 1))  # 1st occurrence
    highX = int(pX.shape[0] -
                np.argmax(pX[::-1] > threshold))  # last occ. returned
    midX = int(0.5*(lowX+highX))
    leftX2 = int(np.argmax(pX[lowX+5:midX-5] < threshold)) + lowX + 5
    midX2 = int(np.argmax(pX[midX+5:highX-5] < threshold)) + midX + 5
    midX1 = int(midX - 5 - np.argmax(pX[midX-5:lowX+5:-1] < threshold))
    rightX1 = int(highX - 5 - np.argmax(pX[highX-5:midX+5:-1] < threshold))

    # along Y
    lowY = max(0, int(np.argmax(pY > threshold) - 1))  # 1st occurrence
    highY = int(pY.shape[0]
                - np.argmax(pY[::-1] > threshold))  # last occ. returned
    n_rows = data_mean.shape[0]

    # define rois
    rois = {}
    # beam roi
    rois['n'] = {'xl': lowX, 'xh': leftX2, 'yl': lowY, 'yh': highY}
    rois['0'] = {'xl': midX1, 'xh': midX2, 'yl': lowY, 'yh': highY}
    rois['p'] = {'xl': rightX1, 'xh': highX, 'yl': lowY, 'yh': highY}

    # saturation roi
    rois['sat'] = {'xl': lowX, 'xh': highX, 'yl': lowY, 'yh': highY}

    if extended:
        # split positions inside ASICs 1 and 2: the right edge of the 'n'
        # beam and the left edge of the 'p' beam. These names were
        # previously undefined (NameError).
        # NOTE(review): confirm these are the intended split positions.
        leftX, rightX = leftX2, rightX1

        # baseline correction rois
        for k in [0, 1, 2, 3]:
            rois[f'b{k}'] = {'xl': k*64, 'xh': (k+1)*64, 'yl': 0, 'yh': lowY}
        for k in [8, 9, 10, 11]:
            rois[f'b{k}'] = {'xl': (k-8)*64, 'xh': (k+1-8)*64,
                             'yl': highY, 'yh': n_rows}

        # ASICs splitted beam roi
        rois['0X'] = {'xl': lowX, 'xh': 1*64, 'yl': lowY, 'yh': 64}
        rois['1X1'] = {'xl': 64, 'xh': leftX, 'yl': lowY, 'yh': 64}
        rois['1X2'] = {'xl': leftX, 'xh': 2*64, 'yl': lowY, 'yh': 64}
        rois['2X1'] = {'xl': 2*64, 'xh': rightX, 'yl': lowY, 'yh': 64}
        rois['2X2'] = {'xl': rightX, 'xh': 3*64, 'yl': lowY, 'yh': 64}
        rois['3X'] = {'xl': 3*64, 'xh': highX, 'yl': lowY, 'yh': 64}

        rois['8X'] = {'xl': lowX, 'xh': 1*64, 'yl': 64, 'yh': highY}
        rois['9X1'] = {'xl': 64, 'xh': leftX, 'yl': 64, 'yh': highY}
        rois['9X2'] = {'xl': leftX, 'xh': 2*64, 'yl': 64, 'yh': highY}
        rois['10X1'] = {'xl': 2*64, 'xh': rightX, 'yl': 64, 'yh': highY}
        rois['10X2'] = {'xl': rightX, 'xh': 3*64, 'yl': 64, 'yh': highY}
        rois['11X'] = {'xl': 3*64, 'xh': highX, 'yl': 64, 'yh': highY}

    return rois
def find_rois_from_params(params):
    """Find rois from 3 beams configuration using the loaded run data.

    The dark average is computed first and passed as ``dark`` when averaging
    the data. The result is averaged over pulseId and searched for beams
    with ``params.rois_th``.

    Inputs
    ------
    params: parameters

    Returns
    -------
    rois: dictionnary of rois
    """
    assert params.arr_dark is not None, "Data not loaded"
    dark_avg = average_module(params.arr_dark).compute()

    assert params.arr is not None, "Data not loaded"
    module_avg = average_module(params.arr, dark=dark_avg).compute()

    # mean over pulseId
    return find_rois(module_avg.mean(axis=0), params.rois_th)
def inspect_rois(data_mean, rois, threshold=None, allrois=False):
    """Plot the mean module image with the 'n', '0' and 'p' rois.

    The main panel shows the first 256 columns of the module with the three
    beam rois shaded. Side panels show the normalized horizontal and
    vertical projections, the optional threshold and the roi boundaries.

    Inputs
    ------
    data_mean: mean module image
    rois: dictionnary of rois, must contain 'n', '0' and 'p'
    threshold: float, default None, threshold value used to detect beams
        boundaries, drawn on the projections when given
    allrois: boolean, default False; not used by the current
        implementation, only the main rois ['n', '0', 'p'] are drawn

    Returns
    -------
    matplotlib figure
    """
    from matplotlib.patches import Rectangle

    beam_colors = [('n', 'b'), ('0', 'g'), ('p', 'r')]

    # compute vertical and horizontal projection, half the ladder along X
    # since there is a gap in the middle
    left_half = data_mean[:, :256]
    proj_x = data_mean.mean(axis=0)[:256]
    proj_y = data_mean.mean(axis=1)
    proj_x = proj_x/np.max(proj_x)
    proj_y = proj_y/np.max(proj_y)

    fig = plt.figure(figsize=(5, 3))
    grid = plt.GridSpec(2, 2, width_ratios=(1, 4), height_ratios=(2, 1),
                        wspace=0.05, hspace=0.05, figure=fig)
    main_ax = fig.add_subplot(grid[0, 1])
    y_ax = fig.add_subplot(grid[0, 0], xticklabels=[], sharey=main_ax)
    x_ax = fig.add_subplot(grid[1, 1], yticklabels=[], sharex=main_ax)

    xs = np.arange(len(proj_x))
    ys = np.arange(len(proj_y))
    main_ax.pcolormesh(xs, ys, np.flipud(left_half),
                       cmap='Greys_r',
                       vmin=0,
                       vmax=np.percentile(left_half, 99))
    main_ax.set_aspect('equal')

    for name, color in beam_colors:
        roi = rois[name]
        # the image is flipped vertically, hence the 128 - yh origin
        main_ax.add_patch(Rectangle((roi['xl'], 128 - roi['yh']),
                                    roi['xh'] - roi['xl'],
                                    roi['yh'] - roi['yl'],
                                    alpha=0.3, color=color))

    x_ax.plot(xs, proj_x)
    x_ax.invert_yaxis()
    if threshold is not None:
        x_ax.axhline(threshold, c='k', alpha=.5, ls='--')
    for name, color in beam_colors:
        x_ax.axvline(rois[name]['xl'], c=color, alpha=.3)
        x_ax.axvline(rois[name]['xh'], c=color, alpha=.3)

    y_ax.plot(proj_y, np.arange(len(proj_y) - 1, -1, -1))
    y_ax.invert_xaxis()
    if threshold is not None:
        y_ax.axvline(threshold, c='k', alpha=.5, ls='--')
    y_ax.axhline(127 - rois['p']['yl'], c='r', alpha=.5)
    y_ax.axhline(127 - rois['p']['yh'], c='r', alpha=.5)

    return fig
# Flat-field related functions
def _plane_flat_field(p, roi, params):
    """Evaluate the plane ``p`` over the pixels of ``roi``.

    Parameters
    ----------
    p: sequence of the 4 coefficients a, b, c, d of the plane
        a*x + b*y + c*z + d = 0
    roi: a dictionnary roi['yh', 'yl', 'xh', 'xl']
    params: parameters, used to get the roi pixel positions

    Returns
    -------
    z = -(a*x + b*y + d)/c evaluated at the roi pixel positions
    """
    coef_x, coef_y, coef_z, offset = p
    pos_x, pos_y = get_roi_pixel_pos(roi, params)
    return -(coef_x*pos_x + coef_y*pos_y + offset)/coef_z
def compute_flat_field_correction(rois, params, plot=False):
    """Compute the flat-field correction of the type set in ``params``.

    Dispatches on ``params.ff_type`` to the 'plane' or 'polyline'
    implementation.

    Inputs
    ------
    rois: dictionnary of beam rois['n', '0', 'p']
    params: parameters
    plot: boolean, default False, diagnostic plot

    Returns
    -------
    the 2D flat-field correction returned by the selected implementation

    Raises
    ------
    ValueError: if ``params.ff_type`` is neither 'plane' nor 'polyline'
    """
    ff_type = params.ff_type
    if ff_type == 'plane':
        return compute_plane_flat_field_correction(rois, params, plot)
    if ff_type == 'polyline':
        return compute_polyline_flat_field_correction(rois, params, plot)
    raise ValueError(f'Unknown flat field type {ff_type}')
def compute_plane_flat_field_correction(rois, params, plot=False):
    """Compute the plane-field correction on beam rois.

    The 'n' roi is filled with the plane given by the first 4 flat-field
    parameters. The 'p' roi gets either the mirror of that plane (x
    coefficient negated) when ``params.force_mirror`` is set, or the plane
    given by the last 4 parameters. Every other pixel is 1.

    Inputs
    ------
    rois: dictionnary of beam rois['n', '0', 'p']
    params: parameters
    plot: boolean, default False, diagnostic plot

    Returns
    -------
    numpy 2D array of the flat-field correction evaluated over one DSSC ladder
    (2 sensors)

    Raises
    ------
    ValueError: if no flat-field parameters are set in ``params``
    """
    plane = params.get_flat_field()
    if plane is None:
        raise ValueError('Flat-field parameters are not set, '
                         'use parameters.set_flat_field() first')

    flat_field = np.ones((128, 512))

    r = rois['n']
    flat_field[r['yl']:r['yh'], r['xl']:r['xh']] = \
        _plane_flat_field(plane[:4], r, params)

    r = rois['p']
    if params.force_mirror:
        a, b, c, d = plane[:4]
        p_plane = [-a, b, c, d]
    else:
        p_plane = plane[4:]
    flat_field[r['yl']:r['yh'], r['xl']:r['xh']] = \
        _plane_flat_field(p_plane, r, params)

    if plot:
        f, ax = plt.subplots(1, 1, figsize=(6, 2))
        img = ax.pcolormesh(
            np.flipud(flat_field[:, :256]), cmap='Greys_r')
        f.colorbar(img, ax=[ax], label='amplitude')
        ax.set_xlabel('px')
        ax.set_ylabel('px')
        ax.set_aspect('equal')

    return flat_field
def initialize_polyline_ff_correction(avg, rois, params, plot=False):
    """Initialize the polyline flat field correction.

    A degree-6 polynomial is fitted to the horizontal projection of the
    ratio ((n + p) / 2) / 0 between beam rois. A second one is fitted to the
    vertical projection of that ratio divided by the horizontal fit. The
    scaled coefficients are stored in ``params`` as a 'polyline' flat field.

    Inputs
    ------
    avg: 2D array, average module image
    rois: dictionnary of ROIs.
    params: parameters, updated in place with the flat-field coefficients
    plot: boolean, plot initialized polyline versus data projection

    Returns
    -------
    fig: handle to figure or None
    """
    def beam(name):
        r = rois[name]
        return avg[r['yl']:r['yh'], r['xl']:r['xh']]

    # normalization: mean of the 'n' and 'p' beams over the '0' beam
    ratio = 0.5*(beam('n') + beam('p')) / beam('0')

    h_proj = ratio.mean(axis=0)
    xs = np.arange(len(h_proj))
    h_coef = np.polyfit(xs, h_proj, 6)
    h_poly = np.poly1d(h_coef)

    v_proj = (ratio/h_poly(xs)).mean(axis=1)
    ys = np.arange(len(v_proj))
    v_coef = np.polyfit(ys, v_proj, 6)

    fig = None
    if plot:
        fig, axs = plt.subplots(2, 1, figsize=(4, 6))
        axs[0].plot(xs, h_proj, label='data (n+p)/2x0')
        axs[0].plot(xs, h_poly(xs), label='poly')
        axs[0].legend()
        axs[0].set_xlabel('x (px)')
        axs[0].set_ylabel('H projection')
        axs[1].plot(ys, v_proj, label='data (n+p)/2x0')
        axs[1].plot(ys, np.poly1d(v_coef)(ys), label='poly')
        axs[1].legend()
        axs[1].set_xlabel('y (px)')
        axs[1].set_ylabel('V projection')

    def scaled(coef):
        # the coefficient of order k is divided by 10**-k (highest order
        # first) so that the stored values have comparable magnitudes
        return coef/np.logspace(-(coef.shape[0]-1), 0, coef.shape[0])

    ff = np.array([scaled(h_coef), scaled(v_coef)])
    params.set_flat_field(ff.flatten())
    params.ff_type = 'polyline'

    return fig
def compute_polyline_flat_field_correction(rois, params, plot=False):
    """Compute the 1D polyline field correction on beam rois.

    The flat-field parameters hold the scaled coefficients of a horizontal
    polynomial H and a vertical polynomial V. Over the 'n' and 'p' rois the
    correction is V(y) * H(x). The same map is used for both beams, not its
    mirror, and every other pixel is 1.

    Inputs
    ------
    rois: dictionnary of beam rois['n', '0', 'p']
    params: parameters
    plot: boolean, default False, diagnostic plot

    Returns
    -------
    numpy 2D array of the flat-field correction evaluated over one DSSC ladder
    (2 sensors)

    Raises
    ------
    ValueError: if no flat-field parameters are set in ``params``
    AssertionError: if the 'n' and 'p' rois differ in width or height
    """
    ff_params = params.get_flat_field()
    if ff_params is None:
        raise ValueError('Flat-field parameters are not set, '
                         'use parameters.set_flat_field() first')

    z = np.asarray(ff_params).reshape((2, -1))
    H_z = z[0, :]
    V_z = z[1, :]

    # undo the coefficient scaling (order k stored divided by 10**-k)
    coeffs = np.logspace(-(H_z.shape[0]-1), 0, H_z.shape[0])
    H_p = np.poly1d(H_z*coeffs)
    coeffs = np.logspace(-(V_z.shape[0]-1), 0, V_z.shape[0])
    V_p = np.poly1d(V_z*coeffs)

    n = rois['n']
    p = rois['p']
    width_n = n['xh'] - n['xl']
    width_p = p['xh'] - p['xl']
    assert width_n == width_p, (
        f"For polyline flat field normalization, both 'n' and 'p' ROIs "
        f"must have the same width {width_n} and {width_p}px"
    )
    height_n = n['yh'] - n['yl']
    height_p = p['yh'] - p['yl']
    assert height_n == height_p, (
        f"For polyline flat field normalization, both 'n' and 'p' ROIs "
        f"must have the same height {height_n} and {height_p}px"
    )

    x = np.arange(width_n)
    y = np.arange(height_n)
    norm = V_p(y)[:, np.newaxis]*H_p(x)

    flat_field = np.ones((128, 512))
    flat_field[n['yl']:n['yh'], n['xl']:n['xh']] *= norm
    flat_field[p['yl']:p['yh'], p['xl']:p['xh']] *= norm  # not the mirror

    if plot:
        f, ax = plt.subplots(1, 1, figsize=(6, 2))
        img = ax.pcolormesh(
            np.flipud(flat_field[:, :256]), cmap='Greys_r')
        f.colorbar(img, ax=[ax], label='amplitude')
        ax.set_xlabel('px')
        ax.set_ylabel('px')
        ax.set_aspect('equal')

    return flat_field
def inspect_flat_field_domain(avg, rois, prod_th, ratio_th, vmin=None,
                              vmax=None):
    """Extract beams roi from average image and compute the ratio.

    Each of the 'n', '0' and 'p' beams is aligned on the '0' roi by rolling
    the image so that the roi centers coincide. The rows show the aligned
    images, then their product with the '0' beam (with ``prod_th``
    contours), then their ratio to it (with ``ratio_th`` contours).

    Inputs
    ------
    avg: module average image with no saturated shots for the flat-field
        determination
    rois: dictionnary or ROIs
    prod_th, ratio_th: tuple of floats for low and high threshold on
        product and ratio
    vmin: imshow vmin level, default None will use 5 percentile value
    vmax: imshow vmax level, default None will use 99.8 percentile value

    Returns
    -------
    fig: matplotlib figure plotted
    domain: a tuple (n_m, p_m) of domain for the 'n' and 'p' order
    """
    if vmin is None:
        vmin = np.percentile(avg, 5)
    if vmax is None:
        vmax = np.percentile(avg, 99.8)

    fig, axs = plt.subplots(3, 3, sharex=True, figsize=(6, 9))

    beams = ['n', '0', 'p']
    centers = {}
    for r in beams:
        roi = rois[r]
        centers[r] = np.array([(roi['yl'] + roi['yh'])//2,
                               (roi['xl'] + roi['xh'])//2])

    ref = rois['0']
    img_rois = {}
    for k, r in enumerate(beams):
        # roll along both image axes; without axis=(0, 1), np.roll flattens
        # the image and applies the summed (dy + dx) shift to it
        shift = tuple(centers['0'] - centers[r])
        img_rois[r] = np.roll(avg, shift, axis=(0, 1))[
            ref['yl']:ref['yh'], ref['xl']:ref['xh']]
        im = axs[0, k].imshow(img_rois[r],
                              vmin=vmin,
                              vmax=vmax)

    _, n_m, _, p_m = plane_fitting_domain(avg, rois, prod_th, ratio_th)

    # matplotlib.cm.get_cmap was removed in matplotlib 3.9. Given a Colormap
    # it ignored its lut argument, so plain 'cool' gives the same colors.
    contour_cmap = 'cool'
    prod_vmin, prod_vmax, ratio_vmin, ratio_vmax = [None]*4
    for k, r in enumerate(beams):
        v = img_rois[r]*img_rois['0']
        if prod_vmin is None:
            prod_vmin = np.percentile(v, .5)
            prod_vmax = np.percentile(v, 20)  # we look for low intensity region
        im2 = axs[1, k].imshow(v, vmin=prod_vmin, vmax=prod_vmax,
                               cmap='magma')
        axs[1, k].contour(v, prod_th, cmap=contour_cmap)

        v = img_rois[r]/img_rois['0']
        if ratio_vmin is None:
            ratio_vmin = np.percentile(v, 5)
            ratio_vmax = np.percentile(v, 99.8)
        im3 = axs[2, k].imshow(v, vmin=ratio_vmin, vmax=ratio_vmax,
                               cmap='RdBu_r')
        axs[2, k].contour(v, ratio_th, cmap=contour_cmap)

    cbar = fig.colorbar(im, ax=axs[0, :], orientation="horizontal")
    cbar.ax.set_xlabel('data mean')
    cbar = fig.colorbar(im2, ax=axs[1, :], orientation="horizontal")
    cbar.ax.set_xlabel('product')
    cbar = fig.colorbar(im3, ax=axs[2, :], orientation="horizontal")
    cbar.ax.set_xlabel('ratio')

    domain = (n_m, p_m)

    return fig, domain
def inspect_plane_fitting(avg, rois, domain=None, vmin=None, vmax=None):
    """Deprecated alias of inspect_ff_fitting.

    Emits a FutureWarning, which is shown by default to end users, pointing
    at the caller. Then forwards all arguments to inspect_ff_fitting.
    """
    warnings.warn(
        "This method is deprecated, use inspect_ff_fitting instead",
        FutureWarning, stacklevel=2)
    return inspect_ff_fitting(avg, rois, domain, vmin, vmax)
def inspect_ff_fitting(avg, rois, domain=None, vmin=None, vmax=None):
    """Extract beams roi from average image and compute the ratio.

    Each of the 'n', '0' and 'p' beams is aligned on the '0' roi by rolling
    the image so that the roi centers coincide. The first row shows the
    aligned images. The second row shows their ratio to the '0' beam, with
    the optional fitting domain contoured on the 'n' and 'p' panels.

    Inputs
    ------
    avg: module average image with no saturated shots for the flat-field
        determination
    rois: dictionnary of rois
    domain: list of domain mask for the -1st and +1st order
    vmin: imshow vmin level, default None will use 5 percentile value
    vmax: imshow vmax level, default None will use 99.8 percentile value

    Returns
    -------
    fig: matplotlib figure plotted
    """
    if vmin is None:
        vmin = np.percentile(avg, 5)
    if vmax is None:
        vmax = np.percentile(avg, 99.8)

    fig, axs = plt.subplots(2, 3, sharex=True, figsize=(6, 6))

    beams = ['n', '0', 'p']
    centers = {}
    for r in beams:
        roi = rois[r]
        centers[r] = np.array([(roi['yl'] + roi['yh'])//2,
                               (roi['xl'] + roi['xh'])//2])

    ref = rois['0']
    img_rois = {}
    for k, r in enumerate(beams):
        # roll along both image axes; without axis=(0, 1), np.roll flattens
        # the image and applies the summed (dy + dx) shift to it
        shift = tuple(centers['0'] - centers[r])
        img_rois[r] = np.roll(avg, shift, axis=(0, 1))[
            ref['yl']:ref['yh'], ref['xl']:ref['xh']]
        im = axs[0, k].imshow(img_rois[r],
                              vmin=vmin,
                              vmax=vmax)

    for k, r in enumerate(beams):
        v = img_rois[r]/img_rois['0']
        im2 = axs[1, k].imshow(v, vmin=0.2, vmax=1.1, cmap='RdBu_r')

    if domain is not None:
        n_m, p_m = domain
        axs[1, 0].contour(n_m)
        axs[1, 2].contour(p_m)

    cbar = fig.colorbar(im, ax=axs[0, :], orientation="horizontal")
    cbar.ax.set_xlabel('data mean')
    cbar = fig.colorbar(im2, ax=axs[1, :], orientation="horizontal")
    cbar.ax.set_xlabel('ratio')

    return fig
def inspect_ff_fitting_sk(avg, rois, ff, domain=None, vmin=None, vmax=None):
    """Extract beams roi from average image and compute the ratio.

    The 'n' and 'p' beams are averaged into a single reference (n+p)/2 and
    compared to the '0' beam.
    - Column 0: the data.
    - Column 1: the flat-field normalization.
    - Column 2: the data divided by the normalization.

    Inputs
    ------
    avg: module average image with no saturated shots for the flat-field
        determination
    rois: dictionnary of rois, with at least the 'n', '0' and 'p' keys
    ff: 2D array, flat field normalization
    domain: list of domain mask for the -1st and +1st order; accepted for
        signature compatibility with inspect_ff_fitting, currently unused
    vmin: imshow vmin level of the data images, default None will use 5
        percentile value
    vmax: imshow vmax level of the data images, default None will use 99.8
        percentile value

    Returns
    -------
    fig: matplotlib figure plotted
    """
    if vmin is None:
        vmin = np.percentile(avg, 5)
    if vmax is None:
        vmax = np.percentile(avg, 99.8)

    def _crop(img, name):
        """Return the part of img covered by roi `name`."""
        r = rois[name]
        return img[r['yl']:r['yh'], r['xl']:r['xh']]

    mref = 0.5*(_crop(avg, 'n') + _crop(avg, 'p'))
    mid = _crop(avg, '0')
    np_norm = 0.5*(_crop(ff, 'n') + _crop(ff, 'p'))
    mid_norm = _crop(ff, '0')

    data_kw = dict(vmin=vmin, vmax=vmax)
    diff_kw = dict(cmap='RdBu_r', vmin=-1, vmax=1)
    # (row, column, image, title, imshow keywords)
    panels = [
        (0, 0, mref, '(n+p)/2', data_kw),
        (1, 0, mid, '0', data_kw),
        (2, 0, mid/mref - 1, '2x0/(n+p) - 1', diff_kw),
        (0, 1, np_norm, 'norm: (n+p)/2', {}),
        (1, 1, mid_norm, 'norm: 0', {}),
        (2, 1, mid_norm/np_norm - 1, 'norm: 2x0/(n+p) - 1', diff_kw),
        (0, 2, mref/np_norm, '(n+p)/2 /norm', {}),
        (1, 2, mid/mid_norm, '0 /norm', {}),
        (2, 2, (mid/mid_norm)/(mref/np_norm) - 1, '2x0/(n+p) - 1 /norm',
         diff_kw),
    ]

    fig, axs = plt.subplots(3, 3, sharex=True, sharey=True,
                            figsize=(8, 4))
    for row, col, img, title, kw in panels:
        im = axs[row, col].imshow(img, **kw)
        axs[row, col].set_title(title)
        fig.colorbar(im, ax=axs[row, col])
    return fig
def plane_fitting_domain(avg, rois, prod_th, ratio_th):
    """Extract beams roi, compute their ratio and the domain.

    The '0' beam is shifted onto each of the 'n' and 'p' rois. The pixel-wise
    ratio and product with that order are then used to define where the
    plane fitting is valid.

    Inputs
    ------
    avg: 2D module average image with no saturated shots for the flat-field
        determination
    rois: dictionnary or rois containing the 3 beams ['n', '0', 'p'] with '0'
        as the reference beam in the middle
    prod_th: float tuple, low and hight threshold level to determine the plane
        fitting domain on the product image of the orders
    ratio_th: float tuple, low and high threshold level to determine the plane
        fitting domain on the ratio image of the orders

    Returns
    -------
    n: img ratio 'n'/'0', non-finite values replaced by 0
    n_m: boolean mask where prod_th[0] < 'n'*'0' < prod_th[1] and
        ratio_th[0] < 'n'/'0' < ratio_th[1], False where the ratio is not
        finite
    p: img ratio 'p'/'0', non-finite values replaced by 0
    p_m: boolean mask, same definition as n_m for the 'p' order
    """
    def _center(roi):
        """Return the (y, x) pixel center of a roi."""
        return np.array([(roi['yl'] + roi['yh'])//2,
                         (roi['xl'] + roi['xh'])//2])

    ref_center = _center(rois['0'])

    def _ratio_and_domain(order):
        """Return (ratio, mask) of beam `order` against the '0' beam."""
        roi = rois[order]
        window = (slice(roi['yl'], roi['yh']), slice(roi['xl'], roi['xh']))
        num = avg[window]
        # move the '0' beam onto this roi. axis=(0, 1) is required: without
        # it np.roll flattens the image and shifts by the sum of the offsets
        shift = tuple(_center(roi) - ref_center)
        denom = np.roll(avg, shift, axis=(0, 1))[window]
        # zero pixels produce inf/nan; they are excluded just below
        with np.errstate(divide='ignore', invalid='ignore'):
            ratio = num/denom
            prod = num*denom
            mask = ((prod > prod_th[0]) & (prod < prod_th[1]) &
                    (ratio > ratio_th[0]) & (ratio < ratio_th[1]))
        bad = ~np.isfinite(ratio)
        mask[bad] = False
        ratio[bad] = 0
        return ratio, mask

    n, n_m = _ratio_and_domain('n')
    p, p_m = _ratio_and_domain('p')
    return n, n_m, p, p_m
def plane_fitting(params):
    """Fit the plane flat-field normalization.

    The ratio images 'n'/'0' and 'p'/'0' are computed from the dark
    corrected, non-saturated module average. Then, for each order, a plane
    is fitted to the (x, y, ratio) points of its domain by minimizing the
    summed squared point-to-plane distances.

    Inputs
    ------
    params: parameters

    Returns
    -------
    res: the minimization result. The fitted vector res.x = [a, b, c, d]
        defines the plane as a*x + b*y + c*z + d = 0
    """
    assert params.arr_dark is not None, "Data not loaded"
    dark = average_module(params.arr_dark).compute()
    assert params.arr is not None, "Data not loaded"
    pulse_avg = average_module(params.arr, dark=dark, ret='mean',
                               mask=params.mask, sat_roi=params.rois['sat'],
                               sat_level=params.sat_level).compute()
    image = pulse_avg.mean(axis=0)  # mean over pulseId
    n, n_m, p, p_m = plane_fitting_domain(image, params.rois,
                                          params.flat_field_prod_th,
                                          params.flat_field_ratio_th)
    # the pixel grids do not depend on the fitted parameters
    Xn, Yn = get_roi_pixel_pos(params.rois['n'], params)
    Xp, Yp = get_roi_pixel_pos(params.rois['p'], params)

    def _crit(x):
        """Fitting criteria for the plane field normalization.

        Inputs
        ------
        x: 2 vector [a, b, c, d] concatenated defining the plane as
            a*x + b*y + c*z + d = 0
        """
        a_n, b_n, c_n, d_n, a_p, b_p, c_p, d_p = x
        norm_n = a_n**2 + b_n**2 + c_n**2
        dist_n = np.sum(n_m*(a_n*Xn + b_n*Yn + c_n*n + d_n)**2)/norm_n
        if params.force_mirror:
            # 'p' plane is the mirror of the 'n' plane along x
            dist_p = np.sum(p_m*(-a_n*Xp + b_n*Yp + c_n*p + d_n)**2)/norm_n
        else:
            norm_p = a_p**2 + b_p**2 + c_p**2
            dist_p = np.sum(p_m*(a_p*Xp + b_p*Yp + c_p*p + d_p)**2)/norm_p
        return 1e3*(dist_p + dist_n)

    return minimize(_crit, params.flat_field_guess())
def ff_refine_crit(p, alpha, params, arr_dark, arr, tid, rois,
                   mask, sat_level=511):
    """Criteria for the ff_refine_fit.

    The flat-field defined by p is set in params and the data is processed
    with it. The binned ratios 0/n, 0/p and p/n are then scored on their
    spread (weighted by alpha) and on their mean deviation from 1 (weighted
    by 1-alpha). Negative flat-field values add a large penalty.

    Inputs
    ------
    p: ff plane
    alpha: float, weight of the spread term versus the mean term
    params: parameters
    arr_dark: dark data
    arr: data
    tid: train id of arr data
    rois: ['n', '0', 'p', 'sat'] rois
    mask: mask fo good pixels
    sat_level: integer, default 511, at which level pixel begin to saturate

    Returns
    -------
    sum of standard deviation on binned 0th order intensity
    """
    params.set_flat_field(p)
    ff = compute_flat_field_correction(rois, params)
    penalty = 1e6 if np.any(ff < 0.0) else 0.0

    data = process(None, arr_dark, arr, tid, rois, mask, ff,
                   sat_level, params._using_gpu)
    # keep only non-saturated shots
    unsaturated = data.where(data['sat_sat'] == False, drop=True)
    spec_n, spec_p, spec_d = [
        xas(unsaturated, 40, Iokey=io, Itkey=it, nrjkey='0',
            fluorescence=True)
        for io, it in (('0', 'n'), ('0', 'p'), ('p', 'n'))]

    err_sigma = (np.nansum(spec_n['sigmaA']) + np.nansum(spec_p['sigmaA'])
                 + np.nansum(spec_d['sigmaA']))
    err_mean = ((1.0 - np.nanmean(spec_n['muA']))**2 +
                (1.0 - np.nanmean(spec_p['muA']))**2 +
                (1.0 - np.nanmean(spec_d['muA']))**2)
    return penalty + 1e3*(alpha*err_sigma + (1-alpha)*err_mean)
def ff_refine_crit_sk(p, alpha, params, arr_dark, arr, tid, rois,
                      mask, sat_level=511):
    """Criteria for the ff_refine_fit, combining 'n' and 'p' as reference.

    The flat-field defined by p (of type params.ff_type) is set in params and
    the data is processed with it. The binned ratio '0'/'np_mean_sk' is then
    scored on its spread (weighted by alpha) and on its mean deviation from 1
    (weighted by 1-alpha). Negative flat-field values add a large penalty.

    Inputs
    ------
    p: ff plane
    alpha: float, weight of the spread term versus the mean term
    params: parameters
    arr_dark: dark data
    arr: data
    tid: train id of arr data
    rois: ['n', '0', 'p', 'sat'] rois
    mask: mask fo good pixels
    sat_level: integer, default 511, at which level pixel begin to saturate

    Returns
    -------
    sum of standard deviation on binned 0th order intensity
    """
    params.set_flat_field(p, params.ff_type)
    ff = compute_flat_field_correction(rois, params)
    penalty = 1e6 if np.any(ff < 0.0) else 0.0

    data = process(None, arr_dark, arr, tid, rois, mask, ff,
                   sat_level, params._using_gpu)
    # keep only non-saturated shots
    unsaturated = data.where(data['sat_sat'] == False, drop=True)
    spectrum = xas(unsaturated, 40, Iokey='np_mean_sk', Itkey='0',
                   nrjkey='0', fluorescence=True)

    err_sigma = np.nansum(spectrum['sigmaA'])
    err_mean = (1.0 - np.nanmean(spectrum['muA']))**2
    return penalty + 1e3*(alpha*err_sigma + (1-alpha)*err_mean)
def ff_refine_fit(params, crit=ff_refine_crit):
    """Refine the flat-field fit by minimizing data spread.

    Starts from the current flat-field of params, or from its guess when
    none was fitted yet. At every iteration the criteria is evaluated for
    alpha=0, params.ff_alpha and alpha=1, and the values are printed.

    Inputs
    ------
    params: parameters
    crit: function, default ff_refine_crit, criteria function

    Returns
    -------
    res: scipy minimize result. res.x is the optimized parameters
    fitrres: iteration index arrays of criteria results for
        [alpha=0, alpha, alpha=1]
    """
    assert params.arr is not None, "Data not loaded"
    assert params.arr_dark is not None, "Data not loaded"
    # only the beam and saturation rois are needed
    fitrois = {k: params.rois[k] for k in ['n', '0', 'p', 'sat']}

    p0 = params.get_flat_field()
    if p0 is None:  # flat field was not yet fitted
        p0 = params.flat_field_guess()

    fixed_p = (params.ff_alpha, params, params.arr_dark, params.arr,
               params.tid, fitrois, params.get_mask(), params.sat_level)
    history = []
    start = None

    def fit_callback(x):
        nonlocal start
        if start is None:
            start = time.monotonic()
        elapsed = datetime.timedelta(seconds=time.monotonic() - start)
        iteration = len(history)
        # alpha is the first fixed argument of the criteria
        args = list(fixed_p)
        j_alpha = crit(x, *args)
        args[0] = 0
        j_0 = crit(x, *args)
        args[0] = 1
        j_1 = crit(x, *args)
        history.append([j_0, j_alpha, j_1])
        print(f'{iteration}: {elapsed} '
              f'(reg. term: {j_0}, {j_alpha}, err. term: {j_1}), {x}')
        return False

    fit_callback(p0)
    res = minimize(crit, p0, fixed_p,
                   options={'disp': True, 'maxiter': params.ff_max_iter},
                   callback=fit_callback)
    return res, history
# non-linearity related functions
def nl_domain(N, low, high):
    """Create the input domain where the non-linear correction defined.

    Inputs
    ------
    N: integer, number of control points or intervals
    low: input values below or equal to low will not be corrected
    high: input values higher or equal to high will not be corrected

    Returns
    -------
    array of 2**9 integer values: 0 for uncorrected inputs, and the control
    point index 1..N for inputs strictly between low and high, split into N
    segments
    """
    inputs = np.arange(2**9)
    inside = (inputs > low) & (inputs < high)
    segment = inputs.copy()
    # the small epsilon keeps the last value strictly below N + 1 so that
    # the integer truncation gives indices 1..N
    segment[inside] = np.linspace(1, N + 1 - 1e-5, high - low - 1)
    segment[~inside] = 0
    return segment
def nl_lut(domain, dy):
    """Compute the non-linear correction.

    Inputs
    ------
    domain: input domain where dy is defined. For zero no correction is
        defined. For non-zero value x, dy[x] is applied.
    dy: a vector of deviation from linearity on control point homogeneously
        dispersed over 9 bits.

    Returns
    -------
    F_INL: non linear correction function given as a lookup table with
        9 bits integer input
    """
    corrections = np.ravel(dy)
    # index 0 of the domain means "no correction": prepend a zero offset
    offsets = np.zeros(corrections.size + 1, dtype=corrections.dtype)
    offsets[1:] = corrections
    return np.arange(2**9) + offsets[domain]
def nl_crit(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field,
            sat_level=511, use_gpu=False):
    """Criteria for the non linear correction.

    Inputs
    ------
    p: vector of dy non linear correction
    domain: domain over which the non linear correction is defined
    alpha: float, coefficient scaling the cost of the correction function
        in the criterion
    arr_dark: dark data
    arr: data
    tid: train id of arr data
    rois: ['n', '0', 'p', 'sat'] rois
    mask: mask fo good pixels
    flat_field: zone plate flat-field correction
    sat_level: integer, default 511, at which level pixel begin to saturate
    use_gpu: boolean, default False, process the data on the GPU

    Returns
    -------
    (1.0 - alpha)*err1 + alpha*err2, where err1 is 1e8 times the mean over
    the 'n'/'0' and 'p'/'0' ratios of the squared weighted standard deviation,
    and err2 is the sum of the square of the deviation from the ideal
    detector response.
    """
    Fmodel = nl_lut(domain, p)
    lut = cp.asarray(Fmodel) if use_gpu else Fmodel
    data = process(lut, arr_dark, arr, tid, rois, mask, flat_field,
                   sat_level, use_gpu)
    # keep only non-saturated shots
    good = data.where(data['sat_sat'] == False, drop=True)
    reference = good['0'].values.flatten()
    err_n, err_p = [
        1e8*snr(good[k].values.flatten(), reference,
                methods=['weighted'])['weighted']['s']**2
        for k in ('n', 'p')]
    err_a = np.sum((Fmodel - np.arange(2**9))**2)
    return (1.0 - alpha)*0.5*(err_n + err_p) + alpha*err_a
def nl_crit_sk(p, domain, alpha, arr_dark, arr, tid, rois, mask, flat_field,
               sat_level=511, use_gpu=False):
    """Non linear correction criteria, combining 'n' and 'p' as reference.

    Inputs
    ------
    p: vector of dy non linear correction
    domain: domain over which the non linear correction is defined
    alpha: float, coefficient scaling the cost of the correction function
        in the criterion
    arr_dark: dark data
    arr: data
    tid: train id of arr data
    rois: ['n', '0', 'p', 'sat'] rois
    mask: mask fo good pixels
    flat_field: zone plate flat-field correction
    sat_level: integer, default 511, at which level pixel begin to saturate
    use_gpu: boolean, default False, process the data on the GPU

    Returns
    -------
    (1.0 - alpha)*err1 + alpha*err2, where err1 is 1e8 times the squared
    weighted standard deviation of the 'np_mean_sk'/'0' ratio and err2 is the
    sum of the square of the deviation from the ideal detector response.
    """
    Fmodel = nl_lut(domain, p)
    lut = cp.asarray(Fmodel) if use_gpu else Fmodel
    data = process(lut, arr_dark, arr, tid, rois, mask, flat_field,
                   sat_level, use_gpu)
    # keep only non-saturated shots
    good = data.where(data['sat_sat'] == False, drop=True)
    stats = snr(good['np_mean_sk'].values.flatten(),
                good['0'].values.flatten(), methods=['weighted'])
    err = 1e8*stats['weighted']['s']**2
    err_a = np.sum((Fmodel - np.arange(2**9))**2)
    return (1.0 - alpha)*err + alpha*err_a
def nl_fit(params, domain, ff=None, crit=None):
    """Fit non linearities correction function.

    At every iteration the criteria is evaluated for alpha=0,
    params.nl_alpha and alpha=1, and the values are printed.

    Inputs
    ------
    params: parameters
    domain: array of index, e.g. from nl_domain; 0 marks uncorrected inputs
        and the other values the control points
    ff: array, default None, flat field correction; computed from params
        when None
    crit: function, default None, criteria function; nl_crit when None

    Returns
    -------
    res: scipy minimize result. res.x is the optimized parameters
    fitrres: iteration index arrays of criteria results for
        [alpha=0, alpha, alpha=1]
    """
    assert params.arr is not None, "Data not loaded"
    assert params.arr_dark is not None, "Data not loaded"
    # only the beam and saturation rois are needed
    fitrois = {k: params.rois[k] for k in ['n', '0', 'p', 'sat']}

    # one free parameter per control point, the 0 value of the domain
    # ("no correction") is not a parameter
    n_points = np.unique(domain).shape[0] - 1
    p0 = np.array([0] * n_points)
    if ff is None:
        ff = compute_flat_field_correction(params.rois, params)
    if crit is None:
        crit = nl_crit

    fixed_p = (domain, params.nl_alpha, params.arr_dark, params.arr,
               params.tid, fitrois, params.get_mask(), ff, params.sat_level,
               params._using_gpu)
    history = []
    start = None

    def fit_callback(x):
        nonlocal start
        if start is None:
            start = time.monotonic()
        elapsed = datetime.timedelta(seconds=time.monotonic() - start)
        iteration = len(history)
        # alpha is the second fixed argument of the criteria
        args = list(fixed_p)
        j_alpha = crit(x, *args)
        args[1] = 0
        j_0 = crit(x, *args)
        args[1] = 1
        j_1 = crit(x, *args)
        history.append([j_0, j_alpha, j_1])
        print(f'{iteration}: {elapsed} '
              f'({j_0}, {j_alpha}, {j_1}), {x}')
        return False

    fit_callback(p0)
    res = minimize(crit, p0, fixed_p,
                   options={'disp': True, 'maxiter': params.nl_max_iter},
                   callback=fit_callback)
    return res, history
def inspect_nl_fit(res_fit):
    """Plot the progress of the fit.

    Inputs
    ------
    res_fit: list of per-iteration [J(alpha=0), J(alpha), J(alpha=1)]
        criteria values, as returned by nl_fit

    Returns
    -------
    matplotlib figure
    """
    history = np.array(res_fit)
    fig = plt.figure(figsize=(6, 4))
    ax_snr = fig.gca()
    ax_cost = ax_snr.twinx()
    # alpha=0 is 1e8*s**2 of the ratio, turned back into 1/s
    ax_snr.plot(1.0/np.sqrt(1e-8*history[:, 0]), c='C0')
    # alpha=1 is the cost of the correction function alone
    ax_cost.plot(history[:, 2], c='C1', ls='-.')
    ax_snr.set_xlabel('# iteration')
    ax_snr.set_ylabel('SNR', color='C0')
    ax_cost.set_ylabel('correction cost', color='C1')
    ax_snr.set_yscale('log')
    ax_cost.set_yscale('log')
    return fig
def snr(sig, ref, methods=None, verbose=False):
    """ Compute mean, std and SNR from transmitted and I0 signals.

    Samples where sig, ref or sig/ref is not finite are ignored.

    Inputs
    ------
    sig: 1D signal samples (array, list or tuple)
    ref: 1D reference samples (array, list or tuple)
    methods: None by default or list of strings to select which methods to use.
        Possible values are 'direct', 'weighted', 'diff'. In case of None, all
        methods will be calculated.
    verbose: booleand, if True prints calculated values

    Returns
    -------
    dictionnary of [methods][value] where value is 'mu' for mean and 's' for
    standard deviation.
    """
    if methods is None:
        methods = ['direct', 'weighted', 'diff']
    # plain sequences cannot be boolean-masked; arrays are used as given
    if isinstance(sig, (list, tuple)):
        sig = np.asarray(sig)
    if isinstance(ref, (list, tuple)):
        ref = np.asarray(ref)
    # zero references give inf/nan ratios, which are masked out just below
    with np.errstate(divide='ignore', invalid='ignore'):
        x = sig/ref
    mask = np.isfinite(x) & np.isfinite(sig) & np.isfinite(ref)
    sig = sig[mask]
    ref = ref[mask]
    x = x[mask]
    w = ref
    res = {}
    # direct mean and std
    if 'direct' in methods:
        mu = np.mean(x)
        s = np.std(x)
        if verbose:
            print(f'mu: {mu}, s: {s}, snr: {mu/s}')
        res['direct'] = {'mu': mu, 's': s}
    # weighted mean and std, with the reference as (reliability) weights
    if 'weighted' in methods:
        wmu = np.sum(sig)/np.sum(ref)
        v1 = np.sum(w)
        v2 = np.sum(w**2)
        ws = np.sqrt(np.sum(w*(x - wmu)**2)/(v1 - v2/v1))
        if verbose:
            print(f'weighted mu: {wmu}, s: {ws}, snr: {wmu/ws}')
        res['weighted'] = {'mu': wmu, 's': ws}
    # noise from consecutive differences
    if 'diff' in methods:
        dmu = np.mean(x)
        ds = np.std(np.diff(x))/np.sqrt(2)
        if verbose:
            print(f'diff mu: {dmu}, s: {ds}, snr: {dmu/ds}')
        res['diff'] = {'mu': dmu, 's': ds}
    return res
def inspect_Fnl(Fnl):
    """Plot the correction function Fnl.

    Inputs
    ------
    Fnl: non linear correction function lookup table over the 2**9 inputs

    Returns
    -------
    matplotlib figure
    """
    inputs = np.arange(2**9)
    fig = plt.figure(figsize=(6, 4))
    ax = fig.gca()
    ax.plot(inputs, Fnl - inputs)
    ax.set_xlabel('input value')
    ax.set_ylabel('output correction F(x)-x')
    ax.set_xlim([0, 511])
    return fig
def inspect_correction(params, gain=None):
    """Comparison plot of the different corrections.

    Columns are the correction levels: raw, flat-field, and flat-field plus
    non-linear. Rows are the beam ratios n/0, p/0 and n/p. Each panel is a 2D
    histogram of the ratio versus the reference intensity. Non-saturated
    shots are drawn in blue and saturated shots in red.

    Inputs
    ------
    params: parameters
    gain: float, default None, DSSC gain in ph/bin

    Returns
    -------
    matplotlib figure
    """
    assert params.arr is not None, "Data not loaded"
    assert params.arr_dark is not None, "Data not loaded"
    # only the beam and saturation rois are needed
    fitrois = {k: params.rois[k] for k in ['n', '0', 'p', 'sat']}
    ff = compute_flat_field_correction(params.rois, params)
    # non linearities, identity lookup table if not fitted yet
    Fnl = params.get_Fnl()
    if Fnl is None:
        Fnl = np.arange(2**9)
    xp = np if not params._using_gpu else cp
    # lookup tables must live on the same device as the data they index
    # (as nl_crit does with cp.asarray)
    Fnl = xp.asarray(Fnl)
    identity = xp.arange(2**9)
    mask = params.get_mask()

    def _process(lut, flat_field):
        """Process the run with the given lookup table and flat-field."""
        return process(lut, params.arr_dark, params.arr, params.tid,
                       fitrois, mask, flat_field, params.sat_level,
                       params._using_gpu)

    # compute all levels of correction
    data = _process(identity, xp.ones_like(ff))
    data_ff = _process(identity, ff)
    data_ff_nl = _process(Fnl, ff)

    # for conversion to nb of photons
    g = 1 if gain is None else gain
    scale = 1e-6
    # intensity bins fixed from the uncorrected reference beam, shared by
    # all panels
    upper = g*scale*np.percentile(data['0'].values.flatten(), 99.9)
    photon_scale = np.linspace(0, upper, 150)

    f, axs = plt.subplots(3, 3, figsize=(8, 6), sharex=True)
    for k, d in enumerate([data, data_ff, data_ff_nl]):
        good_d = d.where(d['sat_sat'] == False, drop=True)
        sat_d = d.where(d['sat_sat'], drop=True)
        for l, (n, r) in enumerate([('n', '0'), ('p', '0'), ('n', 'p')]):
            ax = axs[l, k]
            snr_v = snr(good_d[n].values.flatten(),
                        good_d[r].values.flatten(), verbose=True)
            m = snr_v['direct']['mu']
            ratio_bins = np.linspace(0.95, 1.05, 150)*m
            _, _, _, img = ax.hist2d(
                g*scale*good_d[r].values.flatten(),
                good_d[n].values.flatten()/good_d[r].values.flatten(),
                [photon_scale, ratio_bins],
                cmap='Blues',
                norm=LogNorm(vmin=0.2, vmax=200),
            )
            ax.hist2d(
                g*scale*sat_d[r].values.flatten(),
                sat_d[n].values.flatten()/sat_d[r].values.flatten(),
                [photon_scale, ratio_bins],
                cmap='Reds',
                norm=LogNorm(vmin=0.2, vmax=200),
            )
            v = snr_v['direct']['mu']/snr_v['direct']['s']
            ax.text(0.4, 0.15, f'SNR: {v:.0f}', transform=ax.transAxes)
            v = snr_v['weighted']['mu']/snr_v['weighted']['s']
            ax.text(0.4, 0.05, r'SNR$_\mathrm{w}$: ' + f'{v:.0f}',
                    transform=ax.transAxes)
            ax.set_ylim([0.95*m, 1.05*m])
    for k in range(3):
        if gain:
            axs[2, k].set_xlabel('photons (10$^6$)')
        else:
            axs[2, k].set_xlabel('ADU (10$^6$)')
    f.colorbar(img, ax=axs, label='events')
    axs[0, 0].set_title('raw')
    axs[0, 1].set_title('flat-field')
    axs[0, 2].set_title('non-linear')
    axs[0, 0].set_ylabel(r'-1$^\mathrm{st}$/0$^\mathrm{th}$ order')
    axs[1, 0].set_ylabel(r'1$^\mathrm{st}$/0$^\mathrm{th}$ order')
    axs[2, 0].set_ylabel(r'-1$^\mathrm{st}$/1$^\mathrm{st}$ order')
    return f
def inspect_correction_sk(params, ff, gain=None):
    """Comparison plot of the different corrections, combining 'n' and 'p'.

    Columns are the correction levels: raw, flat-field, and flat-field plus
    non-linear. Each panel is a 2D histogram of the 'np_mean_sk'/'0' ratio
    versus the '0' intensity. Non-saturated shots are drawn in blue and
    saturated shots in red.

    Inputs
    ------
    params: parameters
    ff: 2D array, flat field correction used for the 'flat-field' and
        'non-linear' columns
    gain: float, default None, DSSC gain in ph/bin

    Returns
    -------
    matplotlib figure
    """
    assert params.arr is not None, "Data not loaded"
    assert params.arr_dark is not None, "Data not loaded"
    # only the beam and saturation rois are needed
    fitrois = {k: params.rois[k] for k in ['n', '0', 'p', 'sat']}
    # non linearities, identity lookup table if not fitted yet
    Fnl = params.get_Fnl()
    if Fnl is None:
        Fnl = np.arange(2**9)
    xp = np if not params._using_gpu else cp
    # lookup tables must live on the same device as the data they index
    # (as nl_crit_sk does with cp.asarray)
    Fnl = xp.asarray(Fnl)
    identity = xp.arange(2**9)
    mask = params.get_mask()

    def _process(lut, flat_field):
        """Process the run with the given lookup table and flat-field."""
        return process(lut, params.arr_dark, params.arr, params.tid,
                       fitrois, mask, flat_field, params.sat_level,
                       params._using_gpu)

    # compute all levels of correction
    data = _process(identity, xp.ones_like(ff))
    data_ff = _process(identity, ff)
    data_ff_nl = _process(Fnl, ff)

    # for conversion to nb of photons
    g = 1 if gain is None else gain
    scale = 1e-6
    # intensity bins fixed from the uncorrected reference beam, shared by
    # all panels
    upper = g*scale*np.percentile(data['0'].values.flatten(), 99.9)
    photon_scale = np.linspace(0, upper, 150)

    f, axs = plt.subplots(1, 3, figsize=(8, 2.5), sharex=True)
    for k, d in enumerate([data, data_ff, data_ff_nl]):
        ax = axs[k]
        good_d = d.where(d['sat_sat'] == False, drop=True)
        sat_d = d.where(d['sat_sat'], drop=True)
        snr_v = snr(good_d['np_mean_sk'].values.flatten(),
                    good_d['0'].values.flatten(), verbose=True)
        m = snr_v['direct']['mu']
        ratio_bins = np.linspace(0.95, 1.05, 150)*m
        _, _, _, img = ax.hist2d(
            g*scale*good_d['0'].values.flatten(),
            good_d['np_mean_sk'].values.flatten()/good_d['0'].values.flatten(),
            [photon_scale, ratio_bins],
            cmap='Blues',
            norm=LogNorm(vmin=0.2, vmax=200),
        )
        ax.hist2d(
            g*scale*sat_d['0'].values.flatten(),
            sat_d['np_mean_sk'].values.flatten()/sat_d['0'].values.flatten(),
            [photon_scale, ratio_bins],
            cmap='Reds',
            norm=LogNorm(vmin=0.2, vmax=200),
        )
        v = snr_v['direct']['mu']/snr_v['direct']['s']
        ax.text(0.4, 0.15, f'SNR: {v:.0f}', transform=ax.transAxes)
        v = snr_v['weighted']['mu']/snr_v['weighted']['s']
        ax.text(0.4, 0.05, r'SNR$_\mathrm{w}$: ' + f'{v:.0f}',
                transform=ax.transAxes)
        ax.set_ylim([0.95*m, 1.05*m])
    for k in range(3):
        if gain:
            axs[k].set_xlabel('photons (10$^6$)')
        else:
            axs[k].set_xlabel('ADU (10$^6$)')
    f.colorbar(img, ax=axs, label='events')
    axs[0].set_title('raw')
    axs[1].set_title('flat-field')
    axs[2].set_title('non-linear')
    axs[0].set_ylabel(r'np_mean/0')
    return f
# data processing related functions
def load_dssc_module(proposalNB, runNB, moduleNB=15,
                     subset=slice(None), drop_intra_darks=True,
                     persist=False, data_size_Gb=None):
    """Load single module dssc data as dask array.

    Inputs
    ------
    proposalNB: proposal number
    runNB: run number
    moduleNB: default 15, module number
    subset: default slice(None), subset of trains to load
    drop_intra_darks: boolean, default True, remove intra darks from the data
    persist: default False, load all data persistently in memory
    data_size_Gb: float, if persist is True, can optionaly restrict
        the amount of data loaded for dark data and run data in Gb

    Returns
    -------
    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
    tid: array of train id number
    """
    run = open_run(proposal=proposalNB, run=runNB)
    selection = run[f'SCS_DET_DSSC1M-1/DET/{moduleNB}CH0:xtdf',
                    'image.data'][subset]
    arr = selection.dask_array()
    # a value of 256 is spuriously stored as 0, restore it
    arr[arr == 0] = 256

    counts = selection.data_counts()
    # trains without pulses can happen in burst mode acquisition: skip them
    counts = counts[counts > 0]
    tid = counts.index.values
    pulses = np.unique(counts)
    assert pulses.shape[0] == 1, "number of pulses changed during the run"
    n_pulses = pulses[0]

    # reshape in trainId, pulseId, 2d-image
    arr = arr.reshape(-1, n_pulses, arr.shape[2], arr.shape[3])
    if drop_intra_darks:
        arr = arr[:, ::2, :, :]

    if persist:
        if data_size_Gb is not None:
            # keep only as many trains as fit in data_size_Gb
            bytes_per_train = arr.shape[1]*128*512*2
            keep = slice(0, int(data_size_Gb*1024**3/bytes_per_train))
            arr = arr[keep]
            tid = tid[keep]
        arr = arr.persist()
    return arr, tid
def average_module(arr, dark=None, ret='mean',
                   mask=None, sat_roi=None, sat_level=300, F_INL=None):
    """Compute the average or std over a module.

    The reduction is over the first axis (trainId). With sat_roi and
    ret='mean', frames with a pixel at or above sat_level inside sat_roi
    get zero weight in the average. Saturation is checked after the
    non-linear correction and the mask, but before the dark subtraction.

    Inputs
    ------
    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
    dark: default None, dark to be substracted
    ret: string, either 'mean' to compute the mean or 'std' to compute the
        standard deviation
    mask: default None, mask of bad pixels to ignore
    sat_roi: roi over which to check for pixel with values larger than
        sat_level to drop the image from the average (not used for 'std')
    sat_level: int, minimum pixel value for a pixel to be considered saturated
    F_INL: default None, non linear correction function given as a
        lookup table with 9 bits integer input

    Returns
    -------
    average or standard deviation image

    Raises
    ------
    ValueError: if ret is neither 'mean' nor 'std'
    """
    if ret not in ('mean', 'std'):
        raise ValueError(f'ret={ret} not supported')
    # F_INL
    if F_INL is not None:
        narr = arr.map_blocks(lambda x: F_INL[x], dtype=F_INL.dtype)
    else:
        narr = arr
    if mask is not None:
        # NOTE(review): masked pixels become 0 here but are still dark
        # subtracted below, unless dark is masked as well - confirm intended
        narr = narr*mask
    not_sat = None
    if sat_roi is not None:
        roi = narr[:, :, sat_roi["yl"]:sat_roi["yh"],
                   sat_roi["xl"]:sat_roi["xh"]]
        frame_ok = da.all(roi < sat_level, axis=(2, 3), keepdims=True)
        # per-frame flag expanded to the full frame shape, whatever the
        # module image size
        not_sat = da.broadcast_to(frame_ok, narr.shape)
    if dark is not None:
        narr = narr - dark
    if ret == 'std':
        return narr.std(axis=0)
    if not_sat is not None:
        return da.average(narr, axis=0, weights=not_sat)
    return narr.mean(axis=0)
def process_module(arr, tid, dark, rois, mask=None, sat_level=511,
                   flat_field=None, F_INL=None, use_gpu=False):
    """Process one module and extract roi intensity.

    Inputs
    ------
    arr: dask array of reshaped dssc data (trainId, pulseId, x, y)
    tid: array of train id number
    dark: pulse resolved dark image to remove
    rois: dictionnary of rois, must contain 'n' and 'p'
    mask: default None, mask of ignored pixels
    sat_level: integer, default 511, at which level pixel begin to saturate
    flat_field: default None, flat-field correction
    F_INL: default None, non-linear correction function given as a
        lookup table with 9 bits integer input
    use_gpu: boolean, default False, move the flat-field to the GPU

    Returns
    -------
    dataset of extracted pulse and train resolved roi intensities. For each
    roi name it holds the summed intensity, a '<name>_sat' flag (any pixel
    at or above sat_level, checked before dark subtraction) and a
    '<name>_area' pixel count. It also holds the combined 'np_mean_sk',
    'np_mean_sk_sat' and 'np_mean_area' entries from the 'n' and 'p' rois.
    """
    if F_INL is None:
        narr = arr
    else:
        narr = arr.map_blocks(lambda x: F_INL[x], dtype=F_INL.dtype)
    if mask is not None:
        narr = narr*mask
    # TODO: flat-field should not be applied on intra darks
    if use_gpu and flat_field is not None:
        flat_field = cp.asarray(flat_field)

    sums = {}
    saturated = {}
    corrected = {}
    norms = {}
    for name, roi in rois.items():
        ys = slice(roi['yl'], roi['yh'])
        xs = slice(roi['xl'], roi['xh'])
        raw = narr[:, :, ys, xs]
        saturated[name] = da.any(raw >= sat_level, axis=(2, 3))
        corrected[name] = raw - dark[:, ys, xs]
        if flat_field is None:
            norms[name] = 1.0
            sums[name] = corrected[name].sum(axis=(2, 3))
        else:
            norms[name] = flat_field[ys, xs]
            sums[name] = (corrected[name]/norms[name]).sum(axis=(2, 3))
    # 'n' and 'p' combined pixel-wise, normalized by the summed flat-field
    np_mean = (corrected['n'] + corrected['p'])/(norms['n'] + norms['p'])
    sums['np_mean_sk'] = np_mean.sum(axis=(2, 3))

    dims = ['trainId', 'pulseId']
    coords = {'trainId': tid, 'pulseId': np.arange(0, narr.shape[1])}

    def _pulse_resolved(values):
        """Wrap a (trainId, pulseId) array as a host-side DataArray."""
        return xr.DataArray(ensure_on_host(values), coords=coords, dims=dims)

    res = xr.Dataset()
    for name in rois:
        res[name] = _pulse_resolved(sums[name])
        res[name + '_sat'] = _pulse_resolved(saturated[name])
    res['np_mean_sk'] = _pulse_resolved(sums['np_mean_sk'])
    res['np_mean_sk_sat'] = res['n_sat'] + res['p_sat']
    for name, roi in rois.items():
        area = (roi['yh'] - roi['yl'])*(roi['xh'] - roi['xl'])
        res[name + '_area'] = xr.DataArray(np.array([area]))
    res['np_mean_area'] = res['n_area'] + res['p_area']
    return res
def process(Fmodel, arr_dark, arr, tid, rois, mask, flat_field, sat_level=511,
            use_gpu=False):
    """Process dark and run data with corrections.

    The same lookup table is applied to the dark run (averaged into a pulse
    resolved dark) and to the data run before roi extraction.

    Inputs
    ------
    Fmodel: correction lookup table, or None for no non-linear correction
    arr_dark: dark data
    arr: data
    tid: train id of arr data
    rois: ['n', '0', 'p', 'sat'] rois
    mask: mask of good pixels
    flat_field: zone plate flat-field correction
    sat_level: integer, default 511, at which level pixel begin to saturate
    use_gpu: boolean, default False, process the data on the GPU

    Returns
    -------
    roi extracted intensities
    """
    dark = average_module(arr_dark, F_INL=Fmodel).compute()
    extracted = process_module(arr, tid, dark, rois, mask,
                               sat_level=sat_level, flat_field=flat_field,
                               F_INL=Fmodel, use_gpu=use_gpu)
    return extracted.compute()
def inspect_saturation(data, gain, Nbins=200):
    """Plot roi integrated histogram of the data with saturation

    For each of the 'n', '0' and 'p' rois, the density of non-saturated
    shots is drawn as a line and the saturated shots as a shaded band
    stacked on top.

    Inputs
    ------
    data: xarray of roi integrated DSSC data
    gain: nominal DSSC gain in ph/bin
    Nbins: number of bins for the histogram, by default 200

    Returns
    -------
    f: handle to the matplotlib figure
    h: xarray of the histogram data
    """
    d = data.where(data['sat_sat'] == False, drop=True)
    s = data.where(data['sat_sat'] == True, drop=True)
    # percentage of saturated shots; Dataset.sizes maps dimension names to
    # lengths (mapping access through Dataset.dims is deprecated)
    N_nonsat = d['n'].count()
    N_all = data.sizes['trainId'] * data.sizes['pulseId']
    sat_percent = ((N_all - N_nonsat)/N_all).values*100.0
    # find the bin ranges over all orders, always including 0
    low = 0
    high = 0
    scale = 1e-6
    for k in ['n', '0', 'p']:
        v = data[k].values.ravel()*scale*gain
        low = min(low, np.nanmin(v))
        high = max(high, np.nanmax(v))
    # compute bins edges, center and width
    bins = np.linspace(low, high, Nbins+1)
    bins_c = 0.5*(bins[:-1] + bins[1:])
    w = bins[1] - bins[0]
    fig, ax = plt.subplots(figsize=(6, 4))
    h = {}
    for kk, k in enumerate(['n', '0', 'p']):
        v_d = d[k].values.ravel()*scale*gain
        v_s = s[k].values.ravel()*scale*gain
        h[k+'_nosat'], _ = np.histogram(v_d, bins)
        h[k+'_sat'], _ = np.histogram(v_s, bins)
        # compute density normalization on all data
        norm = w*(np.sum(h[k+'_nosat']) + np.sum(h[k+'_sat']))
        ax.fill_between(bins_c, h[k+'_sat']/norm + h[k+'_nosat']/norm,
                        h[k+'_nosat']/norm, facecolor=f"C{kk}",
                        edgecolor='none', alpha=0.2)
        ax.plot(bins_c, h[k+'_nosat']/norm, label=k,
                c=f'C{kk}', alpha=0.4)
    ax.text(0.6, 0.9, f"saturation: {sat_percent:.2f}%",
            color='r', alpha=0.5, transform=ax.transAxes)
    ax.legend()
    ax.set_xlabel(r'10$^6$ ph')
    ax.set_ylabel('density')
    # save data as xarray dataset
    dv = {k: {"dims": "N", "data": h[k]} for k in h}
    ds = {
        "coords": {"N": {"dims": "N", "data": bins_c,
                         "attrs": {"units": f"{scale:g} ph"}}},
        "attrs": {"saturation (%)": sat_percent},
        "dims": "N",
        "data_vars": dv}
    return fig, xr.Dataset.from_dict(ds)