""" XGM related sub-routines
    Copyright (2019) SCS Team.
    (contributions should preferably comply with the PEP 8 code style
    guidelines.)
"""
import logging
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from ..misc.bunch_pattern_external import is_sase_1, is_sase_3
from ..misc.bunch_pattern import (npulses_has_changed,
                                  get_unique_sase_pId, load_bpt)
from ..mnemonics_machinery import mnemonics_for_run
from toolbox_scs.load import get_array
# public API of this module (names exported by `from ... import *`)
__all__ = ['calibrate_xgm', 'get_xgm']

# module-level logger
log = logging.getLogger(__name__)
def get_xgm(run, mnemonics=None, merge_with=None,
            indices=slice(0, None)):
    """
    Load and/or computes XGM data. Sources can be loaded on the
    fly via the mnemonics argument, or processed from an existing dataset
    (merge_with). The bunch pattern table is used to assign the pulse
    id coordinates if the number of pulses has changed during the run.

    Parameters
    ----------
    run: extra_data.DataCollection
        DataCollection containing the xgm data.
    mnemonics: str or list of str
        mnemonics for XGM, e.g. "SCS_SA3" or ["XTD10_XGM", "SCS_XGM"].
        If None, defaults to "SCS_SA3" in case no merge_with dataset
        is provided.
    merge_with: xarray Dataset
        If provided, the resulting Dataset will be merged with this
        one. The XGM variables of merge_with (if any) will also be
        computed and merged.
    indices: slice, list, 1D array
        Pulse indices of the XGM array in case bunch pattern is missing.

    Returns
    -------
    xarray Dataset with pulse-resolved XGM variables aligned,
     merged with Dataset *merge_with* if provided. If there is no XGM
     mnemonic to process, *merge_with* is returned as is.

    Example
    -------
    >>> import toolbox_scs as tb
    >>> run, ds = tb.load(2212, 213, 'SCS_SA3')
    >>> ds['SCS_SA3']
    """
    # substrings identifying pulse-resolved XGM mnemonics
    xgm_mnemos = ['XTD10_SA', 'XTD10_XGM', 'SCS_SA', 'SCS_XGM']

    def is_xgm(name):
        return any(k in name for k in xgm_mnemos)

    if mnemonics is None and merge_with is None:
        # documented default when nothing else is requested
        mnemonics = 'SCS_SA3'
    to_process = []
    if mnemonics is not None:
        mnemonics = [mnemonics] if isinstance(mnemonics, str) else mnemonics
        for m in mnemonics:
            # mnemonics already present in merge_with are handled below
            if merge_with is not None and m in merge_with:
                continue
            if is_xgm(m):
                to_process.append(m)
    if merge_with is not None:
        # only variables that still have the raw XGMbunchId dimension
        # need to be (re)processed
        to_process += [m for m, da in merge_with.items()
                       if is_xgm(m) and 'XGMbunchId' in da.dims]
    if len(to_process) == 0:
        log.info('no XGM mnemonics to process. Skipping.')
        return merge_with
    # remove duplicates while keeping a deterministic order
    mnemonics = list(dict.fromkeys(to_process))

    # Prepare the dataset of non-XGM data to merge with
    if bool(merge_with):
        ds_mw = merge_with.drop(mnemonics, errors='ignore')
    else:
        ds_mw = xr.Dataset()

    # Load bunch pattern table and check pattern changes
    bpt = load_bpt(run, ds_mw)
    sase1 = sase3 = None
    mask_sa1 = mask_sa3 = None
    pattern_changed = False
    if bpt is not None:
        if any(('SA1' in m) or ('XGM' in m) for m in mnemonics):
            mask_sa1 = is_sase_1(bpt)
            sase1 = np.nonzero(mask_sa1[0].values)[0]
        if any(('SA3' in m) or ('XGM' in m) for m in mnemonics):
            mask_sa3 = is_sase_3(bpt)
            sase3 = np.nonzero(mask_sa3[0].values)[0]
        # The pattern changed if any computed mask differs from its first
        # train. Checking each mask separately also catches a SASE 1 change
        # when SASE 1 and SASE 3 mnemonics are mixed, and pulses moving
        # between SASE 1 and SASE 3 (which leaves their union unchanged).
        pattern_changed = any(
            not bool((mask == mask[0]).all())
            for mask in (mask_sa1, mask_sa3) if mask is not None)

    ds = xr.Dataset()
    for m in mnemonics:
        if merge_with is not None and m in merge_with:
            da_xgm = merge_with[m]
        else:
            da_xgm = get_array(run, m)
        if bpt is None:
            # Values equal to 1 or 0 are treated as empty pulse slots.
            # `where` returns a new array, so merge_with is not modified.
            da_xgm = da_xgm.where((da_xgm != 1) & (da_xgm != 0))
            da_xgm = da_xgm.dropna(dim='XGMbunchId', how='all')
            ds_xgm = da_xgm.fillna(0).sel(XGMbunchId=indices).to_dataset()
        elif pattern_changed:
            ds_xgm = align_xgm_array(da_xgm, bpt, mask_sa1, mask_sa3)
        else:
            ds_xgm = load_xgm_array(run, da_xgm, sase1, sase3)
        ds = ds.merge(ds_xgm, join='inner')
    # merge with non-XGM dataset
    ds = ds_mw.merge(ds, join='inner')
    return ds
def load_xgm_array(run, xgm, sase1, sase3):
    """
    from a raw array xgm, extract and assign pulse Id coordinates
    when the number of pulses does not change during the run.
    If 'XGM' in mnemonic, the data is split in two variables
    'SA1' and 'SA3'.

    Values equal to 1 or 0 are treated as empty pulse slots. Pulse slots
    that are empty in every train are dropped; the remaining empty values
    are set to 0.

    Parameters
    ----------
    run: extra_data.DataCollection
        DataCollection containing the xgm data (not used).
    xgm: xarray.DataArray
        the raw XGM array, with an 'XGMbunchId' dimension. It is not
        modified.
    sase1: list or 1D array
        the sase1 pulse ids
    sase3: list or 1D array
        the sase3 pulse ids

    Returns
    -------
    ds_xgm: xarray.Dataset
        the dataset containing the aligned XGM variable(s).
    """
    # `where` returns a new array: writing into `xgm.values` would modify
    # the caller's array (e.g. a variable of a merge_with dataset).
    xgm = xgm.where((xgm != 1) & (xgm != 0))
    xgm = xgm.dropna(dim='XGMbunchId', how='all')
    xgm = xgm.fillna(0)
    name = xgm.name
    if 'XGM' in name:
        # The columns of the combined array follow the sorted pulse ids.
        # searchsorted on the sorted array gives the column of each pulse.
        sase1_3 = np.sort(np.concatenate([sase1, sase3]))
        sase1_idx = np.searchsorted(sase1_3, sase1)
        sase3_idx = np.searchsorted(sase1_3, sase3)
        xgm_sa1 = xgm.isel(XGMbunchId=sase1_idx).rename(XGMbunchId='sa1_pId')
        xgm_sa1 = xgm_sa1.assign_coords(sa1_pId=sase1)
        xgm_sa1 = xgm_sa1.rename(name.replace('XGM', 'SA1'))
        xgm_sa3 = xgm.isel(XGMbunchId=sase3_idx).rename(XGMbunchId='sa3_pId')
        xgm_sa3 = xgm_sa3.assign_coords(sa3_pId=sase3)
        xgm_sa3 = xgm_sa3.rename(name.replace('XGM', 'SA3'))
        return xr.merge([xgm_sa1, xgm_sa3])
    if 'SA1' in name:
        xgm = xgm.rename(XGMbunchId='sa1_pId')
        return xgm.assign_coords(sa1_pId=sase1).rename(name).to_dataset()
    if 'SA3' in name:
        xgm = xgm.rename(XGMbunchId='sa3_pId')
        return xgm.assign_coords(sa3_pId=sase3).rename(name).to_dataset()
    # unknown SASE: keep the XGMbunchId dimension, but honor the
    # documented Dataset return type
    return xgm.to_dataset()
# NOTE(review): the string below is a disabled, legacy implementation of
# get_xgm kept for reference only; it is never executed. It would not run
# as-is if re-enabled: it calls `mnemonics_to_process`, which is not
# imported at the top of this module, and it calls align_xgm_array with
# two arguments whereas align_xgm_array takes four
# (xgm_arr, bpt, mask_sa1, mask_sa3).
'''
def get_xgm_old(run, mnemonics=None, merge_with=None, keepAllSase=False,
            indices=slice(0, None)):
    """
    Load and/or computes XGM data. Sources can be loaded on the
    fly via the key argument, or processed from an existing data set
    (merge_with). The bunch pattern table is used to assign the pulse
    id coordinates.
    Parameters
    ----------
    run: extra_data.DataCollection
        DataCollection containing the digitizer data.
    mnemonics: str or list of str
        mnemonics for XGM, e.g. "SCS_SA3" or ["XTD10_XGM", "SCS_XGM"].
        If None, defaults to "SCS_SA3" in case no merge_with dataset
        is provided.
    merge_with: xarray Dataset
        If provided, the resulting Dataset will be merged with this
        one. The XGM variables of merge_with (if any) will also be
        computed and merged.
    keepAllSase: bool
        Only relevant in case of sase-dedicated trains. If True, all
        trains are kept, else only those of the bunchPattern are kept.
    indices: slice, list, 1D array
        Pulse indices of the XGM array in case bunch pattern is missing.
    Returns
    -------
    xarray Dataset with pulse-resolved XGM variables aligned,
     merged with Dataset *merge_with* if provided.
    Example
    -------
    >>> import toolbox_scs as tb
    >>> import toolbox_scs.detectors as tbdet
    >>> run, _ = tb.load(2212, 213)
    >>> xgm = tbdet.get_xgm(run)
    """
    # get the list of mnemonics to process
    mnemonics = mnemonics_to_process(mnemonics, merge_with, 'XGM')
    if len(mnemonics) == 0:
        log.info('No array with unaligned XGM peaks to extract. Skipping.')
        return merge_with
    else:
        log.info(f'Extracting XGM data from {mnemonics}.')
    # Prepare the dataset of non-XGM data to merge with
    if bool(merge_with):
        ds = merge_with.drop(mnemonics, errors='ignore')
    else:
        ds = xr.Dataset()
    run_mnemonics = mnemonics_for_run(run)
    # check if bunch pattern table exists
    if bool(merge_with) and 'bunchPatternTable' in merge_with:
        bpt = merge_with['bunchPatternTable']
        log.debug('Using bpt from merge_with dataset.')
    elif 'bunchPatternTable' in run_mnemonics:
        bpt = run.get_array(*run_mnemonics['bunchPatternTable'].values())
        log.debug('Loaded bpt from DataCollection.')
    else:
        bpt = None
    # Load the arrays, assign pulse ID and merge
    for m in mnemonics:
        if bool(merge_with) and m in merge_with:
            arr = merge_with[m]
            log.debug(f'Using {m} from merge_with dataset.')
        else:
            arr = run.get_array(*run_mnemonics[m].values(), name=m)
            log.debug(f'Loading {m} from DataCollection.')
        if bpt is not None:
            arr = align_xgm_array(arr, bpt)
        else:
            arr = arr.where(arr != 1., drop=True).sel(XGMbunchId=indices)
        ds = ds.merge(arr, join='inner')
    return ds
'''
def align_xgm_array(xgm_arr, bpt, mask_sa1, mask_sa3):
    """
    Assigns pulse ID coordinates to a pulse-resolved XGM array, according to
    the bunch pattern table. If the arrays contains both SASE 1 and SASE 3
    data, it is split in two arrays.

    Parameters
    ----------
    xgm_arr: xarray DataArray
        array containing pulse-resolved XGM data, with dims ['trainId',
        'XGMbunchId']
    bpt: xarray DataArray
        bunch pattern table
    mask_sa1: xarray DataArray
        boolean 2D array (trainId x pulseId) of sase 1 pulses
    mask_sa3: xarray DataArray
        boolean 2D array (trainId x pulseId) of sase 3 pulses

    Returns
    -------
    xgm: xarray Dataset
        dataset with pulse ID coordinates. For SASE 1 data, the coordinates
        name is sa1_pId, for SASE 3 data, the coordinates name is sa3_pId.

    Raises
    ------
    ValueError
        if the array name contains none of 'XGM', 'SA1' or 'SA3', or if the
        number of XGM values does not match the number of pulses in the
        bunch pattern.
    """
    key = xgm_arr.name
    valid_tid = np.intersect1d(xgm_arr.trainId, bpt.trainId,
                               assume_unique=True)
    # get the relevant masks for SASE 1 and/or SASE 3
    if "XGM" in key:
        compute_sa1 = compute_sa3 = True
        mask_sa1 = mask_sa1.sel(trainId=valid_tid)
        mask_sa3 = mask_sa3.sel(trainId=valid_tid)
        mask = np.logical_or(mask_sa1, mask_sa3)
    elif "SA1" in key:
        compute_sa1, compute_sa3 = True, False
        mask = mask_sa1 = mask_sa1.sel(trainId=valid_tid)
    elif "SA3" in key:
        compute_sa1, compute_sa3 = False, True
        mask = mask_sa3 = mask_sa3.sel(trainId=valid_tid)
    else:
        raise ValueError(f'{key}: cannot determine SASE from array name '
                         '(expected "XGM", "SA1" or "SA3").')
    # keep only the trains with at least one pulse in the relevant mask
    tid = mask.where(mask.sum(dim='pulse_slot') > 0, drop=True).trainId
    mask = mask.sel(trainId=tid)
    if compute_sa1:
        mask_sa1 = mask_sa1.sel(trainId=tid).rename({'pulse_slot': 'sa1_pId'})
    if compute_sa3:
        mask_sa3 = mask_sa3.sel(trainId=tid).rename({'pulse_slot': 'sa3_pId'})

    npulses_max = int(mask.sum(dim='pulse_slot').max())
    bpt_npulses = bpt.sizes['pulse_slot']
    xgm_arr = xgm_arr.sel(trainId=tid).isel(
                XGMbunchId=slice(0, npulses_max))
    # In rare cases, some xgm data is corrupted: trainId is valid but values
    # are inf / NaN. We set them to -1 to avoid size mismatch between xgm and
    # bpt. Before returning we will drop them.
    xgm_arr = xgm_arr.where(np.isfinite(xgm_arr)).fillna(-1.)
    # pad the xgm array with ones (empty slots) to match the bpt width, then
    # flatten it; padding uses the actual number of XGM columns so that the
    # flattened size is always ntrains * bpt_npulses
    ntrains = xgm_arr.sizes['trainId']
    padding = np.ones((ntrains, bpt_npulses - xgm_arr.sizes['XGMbunchId']))
    xgm_flat = np.hstack((xgm_arr.values, padding)).flatten()
    xgm_flat_arg = np.argwhere(xgm_flat != 1.)
    mask_flat_arg = np.argwhere(mask.values.flatten())
    if xgm_flat_arg.shape != mask_flat_arg.shape:
        # assigning mismatched sizes would either fail inside numpy or
        # silently broadcast a single value to every pulse
        raise ValueError(f'{key}: XGM data and bunch pattern do not match '
                         f'({len(xgm_flat_arg)} XGM values for '
                         f'{len(mask_flat_arg)} pulses).')
    # reorder xgm values to the pulse slots given by the mask
    new_xgm_flat = np.ones(xgm_flat.shape)
    new_xgm_flat[mask_flat_arg] = xgm_flat[xgm_flat_arg]
    new_xgm = new_xgm_flat.reshape((ntrains, bpt_npulses))
    # create a dataset with new_xgm array masked by SASE 1 or SASE 3
    xgm_dict = {}
    for compute, sase, sase_mask in ((compute_sa1, 'SA1', mask_sa1),
                                     (compute_sa3, 'SA3', mask_sa3)):
        if not compute:
            continue
        dim = f'{sase.lower()}_pId'
        da = xr.DataArray(new_xgm, dims=['trainId', dim],
                          coords={'trainId': xgm_arr.trainId,
                                  dim: np.arange(bpt_npulses)},
                          name=key.replace('XGM', sase))
        da = da.where(sase_mask, drop=True)
        # remove potential corrupted data:
        da = da.where(da != -1., drop=True)
        xgm_dict[da.name] = da
    return xr.Dataset(xgm_dict)
def _npulses_per_train(mask, sase):
    """
    Return the number of pulses per train in a boolean pulse mask (summed
    over 'pulse_slot'); if it is not unique, log a warning and return the
    maximum.
    """
    npulses = np.unique(mask.sum(dim='pulse_slot'))
    if len(npulses) != 1:
        log.warning(f'change of pulse pattern in {sase} during the run.')
    return max(npulses)


def calibrate_xgm(run, data, xgm='SCS', plot=False):
    """
    Calculates the calibration factor F between the photon flux (slow signal)
    and the fast signal (pulse-resolved) of the sase 3 pulses. The calibrated
    fast signal is equal to the uncalibrated one multiplied by F.

    Parameters
    ----------
    run: extra_data.DataCollection
        DataCollection containing the XGM data.
    data: xarray Dataset
        dataset containing the pulse-resolved sase 3 signal, e.g. 'SCS_SA3'
    xgm: str
        one in {'XTD10', 'SCS'}
    plot: bool
        If True, shows a plot of the photon flux, averaged fast signal and
        calibrated fast signal.

    Returns
    -------
    F: float
        calibration factor F defined as:
        calibrated XGM [microJ] = F * fast XGM array ('SCS_SA3' or 'XTD10_SA3')

    Raises
    ------
    ValueError
        if no bunch pattern table is available, or if the bunch pattern
        contains no SASE 3 pulse.

    Example
    -------
    >>> import toolbox_scs as tb
    >>> import toolbox_scs.detectors as tbdet
    >>> run, data = tb.load(900074, 69, ['SCS_XGM'])
    >>> ds = tbdet.get_xgm(run, merge_with=data)
    >>> F = tbdet.calibrate_xgm(run, ds, plot=True)
    >>> # Add calibrated XGM to the dataset:
    >>> ds['SCS_SA3_uJ'] = F * ds['SCS_SA3']
    """
    run_mnemonics = mnemonics_for_run(run)
    # check if bunch pattern table exists
    if 'bunchPatternTable' in data:
        bpt = data['bunchPatternTable']
    elif 'bunchPatternTable' in run_mnemonics:
        bpt = run.get_array(*run_mnemonics['bunchPatternTable'].values())
    elif 'bunchPatternTable_SA3' in run_mnemonics:
        bpt = run.get_array(*run_mnemonics['bunchPatternTable_SA3'].values())
    else:
        raise ValueError('Bunch pattern missing. Cannot calibrate XGM.')
    bpt = bpt.sel(trainId=data.trainId)
    npulses_sa3 = _npulses_per_train(is_sase_3(bpt), 'sase3')
    npulses_sa1 = _npulses_per_train(is_sase_1(bpt), 'sase1')
    if npulses_sa3 == 0:
        # the rescaling below divides by the number of SASE 3 pulses
        raise ValueError('No SASE 3 pulse in the bunch pattern. '
                         'Cannot calibrate XGM.')
    pflux_key = f'{xgm}_photonFlux'
    if pflux_key in data:
        pflux = data[pflux_key]
    else:
        pflux = run.get_array(*run_mnemonics[pflux_key].values())
        pflux = pflux.sel(trainId=data.trainId)
    # rescale the slow signal to the SASE 3 pulses only (the formula treats
    # pflux as shared over all SASE 1 + SASE 3 pulses)
    pflux_sa3 = (npulses_sa1 + npulses_sa3) * pflux / npulses_sa3
    # rolling average of the fast signal over 200 trains, then mean over
    # the second axis (assumed to be the pulse axis)
    avg_fast = data[f'{xgm}_SA3'].rolling(trainId=200).mean().mean(axis=1)
    calib = np.nanmean(pflux_sa3.values / avg_fast.values)
    if plot:
        plot_xgm_calibration(xgm, pflux, pflux_sa3, avg_fast, calib)
    return calib
def plot_xgm_calibration(xgm, pflux, pflux_sa3, avg_fast, calib):
    """
    Plot the quantities used by calibrate_xgm to compute the XGM
    calibration factor.

    Parameters
    ----------
    xgm: str
        XGM prefix (e.g. 'SCS' or 'XTD10'), shown in the plot title.
    pflux: array-like
        photon flux (slow signal).
    pflux_sa3: array-like
        photon flux rescaled to the SASE 3 pulses.
    avg_fast: array-like
        averaged pulse-resolved (fast) XGM signal.
    calib: float
        calibration factor; the calibrated curve is avg_fast * calib.
    """
    plt.figure(figsize=(8, 4))
    curves = [(pflux, 'photon flux all'),
              (pflux_sa3, 'photon flux SA3'),
              (avg_fast, 'avg pulsed XGM'),
              (avg_fast*calib, 'calibrated avg pulsed XGM')]
    for curve, label in curves:
        plt.plot(curve, label=label)
    plt.title(f'calibrated XGM = {xgm}_SA3 * {calib:.3e}')
    plt.xlabel('train number')
    plt.ylabel(r'Pulse energy [$\mu$J]')
    plt.legend()