""" XGM related sub-routines
Copyright (2019) SCS Team.
(contributions preferably comply with pep8 code structure
guidelines.)
"""
import logging
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from ..misc.bunch_pattern_external import is_sase_1, is_sase_3
from ..misc.bunch_pattern import (npulses_has_changed,
get_unique_sase_pId, load_bpt)
from ..mnemonics_machinery import mnemonics_for_run
from toolbox_scs.load import get_array
__all__ = [
'calibrate_xgm',
'get_xgm',
]
log = logging.getLogger(__name__)
def get_xgm(run, mnemonics=None, merge_with=None,
            indices=slice(0, None)):
    """
    Load and/or compute XGM data. Sources can be loaded on the
    fly via the mnemonics argument, or processed from an existing dataset
    (merge_with). The bunch pattern table is used to assign the pulse
    id coordinates if the number of pulses has changed during the run.

    Parameters
    ----------
    run: extra_data.DataCollection
        DataCollection containing the xgm data.
    mnemonics: str or list of str
        mnemonics for XGM, e.g. "SCS_SA3" or ["XTD10_XGM", "SCS_XGM"].
        If None, defaults to "SCS_SA3" in case no merge_with dataset
        is provided.
    merge_with: xarray Dataset
        If provided, the resulting Dataset will be merged with this
        one. The XGM variables of merge_with (if any) will also be
        computed and merged.
    indices: slice, list, 1D array
        Pulse indices of the XGM array in case bunch pattern is missing.

    Returns
    -------
    xarray Dataset with pulse-resolved XGM variables aligned,
    merged with Dataset *merge_with* if provided.

    Example
    -------
    >>> import toolbox_scs as tb
    >>> run, ds = tb.load(2212, 213, 'SCS_SA3')
    >>> ds['SCS_SA3']
    """
    xgm_mnemos = ['XTD10_SA', 'XTD10_XGM', 'SCS_SA', 'SCS_XGM']
    # Apply the default documented above: with neither mnemonics nor
    # merge_with given, load "SCS_SA3" instead of silently returning None.
    if mnemonics is None and merge_with is None:
        mnemonics = ['SCS_SA3']
    # Collect the XGM mnemonics to process. Those already present in
    # merge_with are skipped here and picked up in the loop below.
    m2 = []
    if mnemonics is not None:
        mnemonics = [mnemonics] if isinstance(mnemonics, str) else mnemonics
        for m in mnemonics:
            if any(k in m for k in xgm_mnemos):
                if merge_with is not None and m in merge_with:
                    continue
                m2.append(m)
    if merge_with is not None:
        # also re-process the unaligned XGM variables of merge_with
        # (identified by their raw 'XGMbunchId' dimension)
        in_mw = []
        for m, da in merge_with.items():
            if any(k in m for k in xgm_mnemos) and 'XGMbunchId' in da.dims:
                in_mw.append(m)
        m2 += in_mw
    if len(m2) == 0:
        log.info('no XGM mnemonics to process. Skipping.')
        return merge_with
    mnemonics = list(set(m2))
    # Prepare the dataset of non-XGM data to merge with
    if bool(merge_with):
        ds_mw = merge_with.drop(mnemonics, errors='ignore')
    else:
        ds_mw = xr.Dataset()
    # Load bunch pattern table, derive the sase 1 / sase 3 pulse masks and
    # pulse ids, and detect whether the pulse pattern changed during the run.
    bpt = load_bpt(run, ds_mw)
    if bpt is not None:
        sase1 = sase3 = None
        mask_sa1 = mask_sa3 = None
        if any("SA1" in m for m in mnemonics) or any(
                "XGM" in m for m in mnemonics):
            mask = mask_sa1 = is_sase_1(bpt)
            sase1 = np.nonzero(mask_sa1[0].values)[0]
        if any("SA3" in m for m in mnemonics) or any(
                "XGM" in m for m in mnemonics):
            mask = mask_sa3 = is_sase_3(bpt)
            sase3 = np.nonzero(mask_sa3[0].values)[0]
        if any("XGM" in m for m in mnemonics):
            mask = np.logical_or(mask_sa1, mask_sa3)
        # True if any train's pattern differs from the first train's
        pattern_changed = ~(mask == mask[0]).all().values
    ds = xr.Dataset()
    for m in mnemonics:
        if merge_with is not None and m in merge_with:
            da_xgm = merge_with[m]
        else:
            da_xgm = get_array(run, m)
        if bpt is not None:
            if not pattern_changed:
                ds_xgm = load_xgm_array(run, da_xgm, sase1, sase3)
            else:
                ds_xgm = align_xgm_array(da_xgm, bpt, mask_sa1, mask_sa3)
        else:
            # No bunch pattern table: values exactly equal to 1.0 or 0.0
            # are treated as empty slots, all-empty pulse slots are
            # dropped, and the requested pulse indices are selected.
            xgm_val = da_xgm.values
            xgm_val[xgm_val == 1] = np.nan
            xgm_val[xgm_val == 0] = np.nan
            da_xgm.values = xgm_val
            da_xgm = da_xgm.dropna(dim='XGMbunchId', how='all')
            ds_xgm = da_xgm.fillna(0).sel(XGMbunchId=indices).to_dataset()
        ds = ds.merge(ds_xgm, join='inner')
    # merge with non-XGM dataset
    ds = ds_mw.merge(ds, join='inner')
    return ds
def load_xgm_array(run, xgm, sase1, sase3):
    """
    From a raw array xgm, extract and assign pulse Id coordinates
    when the number of pulses does not change during the run.
    If 'XGM' is part of the array name, the data is split in two
    variables 'SA1' and 'SA3'.

    Parameters
    ----------
    run: extra_data.DataCollection
        DataCollection containing the xgm data. Currently unused; kept
        for interface compatibility with callers.
    xgm: xarray.DataArray
        the raw XGM array
    sase1: list or 1D array
        the sase1 pulse ids
    sase3: list or 1D array
        the sase3 pulse ids

    Returns
    -------
    ds_xgm: xarray.Dataset
        the dataset containing the aligned XGM variable(s).
    """
    # Values exactly equal to 1.0 or 0.0 are treated as empty pulse
    # slots: mark them NaN, drop all-empty slots, zero-fill the rest.
    xgm_val = xgm.values
    xgm_val[xgm_val == 1] = np.nan
    xgm_val[xgm_val == 0] = np.nan
    xgm.values = xgm_val
    xgm = xgm.dropna(dim='XGMbunchId', how='all')
    xgm = xgm.fillna(0)
    if 'XGM' in xgm.name:
        # Combined array: split it into SASE 1 and SASE 3 variables.
        # Pulse ids are unique, so searchsorted on the sorted union gives
        # the positional index of each sase's pulses within the array.
        sase1_3 = np.sort(np.concatenate([sase1, sase3]))
        sase1_idx = np.searchsorted(sase1_3, sase1)
        sase3_idx = np.searchsorted(sase1_3, sase3)
        xgm_sa1 = xgm.isel(XGMbunchId=sase1_idx).rename(XGMbunchId='sa1_pId')
        xgm_sa1 = xgm_sa1.assign_coords(sa1_pId=sase1)
        xgm_sa1 = xgm_sa1.rename(xgm.name.replace('XGM', 'SA1'))
        xgm_sa3 = xgm.isel(XGMbunchId=sase3_idx).rename(XGMbunchId='sa3_pId')
        xgm_sa3 = xgm_sa3.assign_coords(sa3_pId=sase3)
        xgm_sa3 = xgm_sa3.rename(xgm.name.replace('XGM', 'SA3'))
        xgm = xr.merge([xgm_sa1, xgm_sa3])
    elif 'SA1' in xgm.name:
        xgm = xgm.rename(XGMbunchId='sa1_pId')
        xgm = xgm.assign_coords(sa1_pId=sase1).rename(xgm.name)
        xgm = xgm.to_dataset()
    elif 'SA3' in xgm.name:
        xgm = xgm.rename(XGMbunchId='sa3_pId')
        xgm = xgm.assign_coords(sa3_pId=sase3).rename(xgm.name)
        xgm = xgm.to_dataset()
    return xgm
# NOTE(review): the legacy `get_xgm_old` implementation that used to live
# here inside a module-level string (dead code, never executed) has been
# removed. Retrieve it from version control history if ever needed.
def align_xgm_array(xgm_arr, bpt, mask_sa1, mask_sa3):
    """
    Assigns pulse ID coordinates to a pulse-resolved XGM array, according to
    the bunch pattern table. If the array contains both SASE 1 and SASE 3
    data, it is split in two arrays.

    Parameters
    ----------
    xgm_arr: xarray DataArray
        array containing pulse-resolved XGM data, with dims ['trainId',
        'XGMbunchId']. Its name must contain 'SA1', 'SA3' or 'XGM'.
    bpt: xarray DataArray
        bunch pattern table
    mask_sa1: xarray DataArray
        boolean 2D array (trainId x pulse_slot) of sase 1 pulses
    mask_sa3: xarray DataArray
        boolean 2D array (trainId x pulse_slot) of sase 3 pulses

    Returns
    -------
    xgm: xarray Dataset
        dataset with pulse ID coordinates. For SASE 1 data, the coordinates
        name is sa1_pId, for SASE 3 data, the coordinates name is sa3_pId.
    """
    key = xgm_arr.name
    compute_sa1 = False
    compute_sa3 = False
    # only keep trains present in both the XGM data and the bpt
    valid_tid = np.intersect1d(xgm_arr.trainId, bpt.trainId,
                               assume_unique=True)
    # get the relevant masks for SASE 1 and/or SASE3
    if "SA1" in key or "SA3" in key:
        if "SA1" in key:
            mask = mask_sa1.sel(trainId=valid_tid)
            compute_sa1 = True
        else:
            mask = mask_sa3.sel(trainId=valid_tid)
            compute_sa3 = True
        # drop trains without any pulse of the selected sase
        tid = mask.where(mask.sum(dim='pulse_slot') > 0, drop=True).trainId
        mask = mask.sel(trainId=tid)
        mask_sa1 = mask.rename({'pulse_slot': 'sa1_pId'})
        mask_sa3 = mask.rename({'pulse_slot': 'sa3_pId'})
    if "XGM" in key:
        # combined data: both sases contribute pulses
        compute_sa1 = True
        compute_sa3 = True
        mask_sa1 = mask_sa1.sel(trainId=valid_tid)
        mask_sa3 = mask_sa3.sel(trainId=valid_tid)
        mask = np.logical_or(mask_sa1, mask_sa3)
        tid = mask.where(mask.sum(dim='pulse_slot') > 0,
                         drop=True).trainId
        mask_sa1 = mask_sa1.sel(trainId=tid).rename({'pulse_slot': 'sa1_pId'})
        mask_sa3 = mask_sa3.sel(trainId=tid).rename({'pulse_slot': 'sa3_pId'})
        mask = mask.sel(trainId=tid)
    npulses_max = mask.sum(dim='pulse_slot').max().values
    bpt_npulses = bpt.sizes['pulse_slot']
    xgm_arr = xgm_arr.sel(trainId=tid).isel(
        XGMbunchId=slice(0, npulses_max))
    # In rare cases, some xgm data is corrupted: trainId is valid but values
    # are inf / NaN. We set them to -1 to avoid size mismatch between xgm and
    # bpt. Before returning we will drop them.
    xgm_arr = xgm_arr.where(np.isfinite(xgm_arr)).fillna(-1.)
    # pad the xgm array to match the bpt dims, flatten and
    # reorder xgm array to match the indices of the mask
    xgm_flat = np.hstack((xgm_arr.fillna(1.),
                          np.ones((xgm_arr.sizes['trainId'],
                                   bpt_npulses-npulses_max)))).flatten()
    xgm_flat_arg = np.argwhere(xgm_flat != 1.)
    mask_flat = mask.values.flatten()
    mask_flat_arg = np.argwhere(mask_flat)
    if xgm_flat_arg.shape != mask_flat_arg.shape:
        log.warning(f'{key}: XGM data and bunch pattern do not match.')
    # scatter the valid xgm values into the pulse slots flagged by the mask
    new_xgm_flat = np.ones(xgm_flat.shape)
    new_xgm_flat[mask_flat_arg] = xgm_flat[xgm_flat_arg]
    new_xgm = new_xgm_flat.reshape((xgm_arr.sizes['trainId'], bpt_npulses))
    # create a dataset with new_xgm array masked by SASE 1 or SASE 3
    xgm_dict = {}
    if compute_sa1:
        sa1_xgm = xr.DataArray(new_xgm, dims=['trainId', 'sa1_pId'],
                               coords={'trainId': xgm_arr.trainId,
                                       'sa1_pId': np.arange(bpt_npulses)},
                               name=key.replace('XGM', 'SA1'))
        sa1_xgm = sa1_xgm.where(mask_sa1, drop=True)
        # remove potential corrupted data:
        sa1_xgm = sa1_xgm.where(sa1_xgm != -1., drop=True)
        xgm_dict[sa1_xgm.name] = sa1_xgm
    if compute_sa3:
        sa3_xgm = xr.DataArray(new_xgm, dims=['trainId', 'sa3_pId'],
                               coords={'trainId': xgm_arr.trainId,
                                       'sa3_pId': np.arange(bpt_npulses)},
                               name=key.replace('XGM', 'SA3'))
        sa3_xgm = sa3_xgm.where(mask_sa3, drop=True)
        # remove potential corrupted data:
        sa3_xgm = sa3_xgm.where(sa3_xgm != -1., drop=True)
        xgm_dict[sa3_xgm.name] = sa3_xgm
    ds = xr.Dataset(xgm_dict)
    return ds
def calibrate_xgm(run, data, xgm='SCS', plot=False):
    """
    Calculates the calibration factor F between the photon flux (slow signal)
    and the fast signal (pulse-resolved) of the sase 3 pulses. The calibrated
    fast signal is equal to the uncalibrated one multiplied by F.

    Parameters
    ----------
    run: extra_data.DataCollection
        DataCollection containing the digitizer data.
    data: xarray Dataset
        dataset containing the pulse-resolved sase 3 signal, e.g. 'SCS_SA3'
    xgm: str
        one in {'XTD10', 'SCS'}
    plot: bool
        If True, shows a plot of the photon flux, averaged fast signal and
        calibrated fast signal.

    Returns
    -------
    F: float
        calibration factor F defined as:
        calibrated XGM [microJ] = F * fast XGM array ('SCS_SA3' or 'XTD10_SA3')

    Example
    -------
    >>> import toolbox_scs as tb
    >>> import toolbox_scs.detectors as tbdet
    >>> run, data = tb.load(900074, 69, ['SCS_XGM'])
    >>> ds = tbdet.get_xgm(run, merge_with=data)
    >>> F = tbdet.calibrate_xgm(run, ds, plot=True)
    >>> # Add calibrated XGM to the dataset:
    >>> ds['SCS_SA3_uJ'] = F * ds['SCS_SA3']
    """
    run_mnemonics = mnemonics_for_run(run)
    # check if bunch pattern table exists
    if 'bunchPatternTable' in data:
        bpt = data['bunchPatternTable']
    elif 'bunchPatternTable' in run_mnemonics:
        bpt = run.get_array(*run_mnemonics['bunchPatternTable'].values())
    elif 'bunchPatternTable_SA3' in run_mnemonics:
        bpt = run.get_array(*run_mnemonics['bunchPatternTable_SA3'].values())
    else:
        raise ValueError('Bunch pattern missing. Cannot calibrate XGM.')
    # number of sase 3 pulses per train (max value if pattern changed)
    mask_sa3 = is_sase_3(bpt.sel(trainId=data.trainId))
    npulses_sa3 = np.unique(mask_sa3.sum(dim='pulse_slot'))
    if len(npulses_sa3) == 1:
        npulses_sa3 = npulses_sa3[0]
    else:
        log.warning('change of pulse pattern in sase3 during the run.')
        npulses_sa3 = max(npulses_sa3)
    # number of sase 1 pulses per train (max value if pattern changed)
    mask_sa1 = is_sase_1(bpt.sel(trainId=data.trainId))
    npulses_sa1 = np.unique(mask_sa1.sum(dim='pulse_slot'))
    if len(npulses_sa1) == 1:
        npulses_sa1 = npulses_sa1[0]
    else:
        log.warning('change of pulse pattern in sase1 during the run.')
        npulses_sa1 = max(npulses_sa1)
    pflux_key = f'{xgm}_photonFlux'
    if pflux_key in data:
        pflux = data[pflux_key]
    else:
        pflux = run.get_array(*run_mnemonics[pflux_key].values())
        pflux = pflux.sel(trainId=data.trainId)
    # Rescale the slow signal to the sase 3 pulses only. NOTE(review): this
    # assumes photonFlux is an average over all (sase1 + sase3) pulses —
    # confirm against the XGM device documentation.
    pflux_sa3 = (npulses_sa1 + npulses_sa3) * pflux / npulses_sa3
    # smooth the pulse-resolved signal (200-train rolling mean) and average
    # over pulses before comparing to the slow signal
    avg_fast = data[f'{xgm}_SA3'].rolling(trainId=200).mean().mean(axis=1)
    calib = np.nanmean(pflux_sa3.values / avg_fast.values)
    if plot:
        plot_xgm_calibration(xgm, pflux, pflux_sa3, avg_fast, calib)
    return calib
def plot_xgm_calibration(xgm, pflux, pflux_sa3, avg_fast, calib):
    """
    Plots the total and sase 3 photon flux together with the averaged,
    uncalibrated and calibrated pulse-resolved XGM signal.

    Parameters
    ----------
    xgm: str
        XGM name ('XTD10' or 'SCS'), used in the plot title.
    pflux: xarray DataArray or 1D array
        photon flux of all pulses (slow signal)
    pflux_sa3: xarray DataArray or 1D array
        photon flux attributed to the sase 3 pulses
    avg_fast: xarray DataArray or 1D array
        averaged pulse-resolved XGM signal
    calib: float
        calibration factor applied to avg_fast

    Returns
    -------
    None
    """
    plt.figure(figsize=(8, 4))
    plt.plot(pflux, label='photon flux all')
    plt.plot(pflux_sa3, label='photon flux SA3')
    plt.plot(avg_fast, label='avg pulsed XGM')
    plt.plot(avg_fast*calib, label='calibrated avg pulsed XGM')
    plt.title(f'calibrated XGM = {xgm}_SA3 * {calib:.3e}')
    plt.xlabel('train number')
    plt.ylabel(r'Pulse energy [$\mu$J]')
    plt.legend()
    return