"""
    DSSC-detector class module
    The dssc detector class. It represents a namespace for frequent evaluation
    while implicitly applying/requiring certain structure/naming conventions to
    its objects.
    comments:
        - contributions should comply with pep8 code structure guidelines.
        - Plot routines don't fit into objects since they are rather fluent.
          They have been outsourced to dssc_plot.py. They can now be accessed
          as toolbox_scs member functions.
"""
import os
import logging
import joblib
import numpy as np
import xarray as xr
import toolbox_scs as tb
from ..util.exceptions import ToolBoxValueError, ToolBoxFileError
from .dssc_data import (
    save_xarray, load_xarray, save_attributes_h5,
    search_files, get_data_formatted)
from .dssc_misc import (
        load_dssc_info, get_xgm_formatted, get_tim_formatted)
from .dssc_processing import (
        process_dssc_data, create_empty_dataset)
# Public API exported by `from ... import *`.
__all__ = ["DSSCBinner", "DSSCFormatter"]

# Module-level logger used by all DSSC binner/formatter methods.
log = logging.getLogger(__name__)
[docs]class DSSCBinner:
    def __init__(self, proposal_nr, run_nr,
                 binners=None,
                 xgm_name='SCS_SA3',
                 tim_names=None,
                 dssc_coords_stride=2,
                 ):
        """
        A dssc binner object. Loads and bins the dssc data according to the
        bins specified in 'binners'. The data can be reduced further through
        masking using XGM or TIM data.

        Parameters
        ----------
        proposal_nr: int, str
            proposal number containing run folders
        run_nr: int, str
            run number
        binners: dictionary, optional
            dictionary containing binners constructed using the
            'create_dssc_bins' toolbox_scs.detectors-method. Each entry is
            registered through add_binner(). None (default) means no binners.
        xgm_name: str
            a valid mnemonic key of the XGM data to be used to mask the dssc
            frames. Since the xgm is used in several methods its name can be
            set here globally.
        tim_names: list of strings, optional
            a list of valid mnemonic keys for an mcp in the tim. Once the
            corresponding data is loaded the different sources will be
            averaged. Defaults to ['MCP1apd', 'MCP2apd', 'MCP3apd'].
        dssc_coords_stride: int, list
            defines which dssc frames should be normalized using data from the
            xgm. The parameter may be an integer (stride parameter) or a list,
            that assigns each pulse to its corresponding dssc frame number.

        Raises
        ------
        ToolBoxValueError: Exception
            if a key of 'binners' is not a valid binner name (see
            add_binner).

        Example
        -------
        1.) quick -> generic bins, no xgm,
        >>> import toolbox_scs as tb
        >>> run235 = tb.DSSCBinner(proposal_nr=2212, run_nr=235)
        2.) detailed -> docs
        """
        # None sentinels avoid sharing mutable default objects between
        # instances (tim_names is stored on the instance below).
        if binners is None:
            binners = {}
        if tim_names is None:
            tim_names = ['MCP1apd', 'MCP2apd', 'MCP3apd']
        # ---------------------------------------------------------------------
        # object (run) properties
        # ---------------------------------------------------------------------
        self.proposal = proposal_nr
        self.runnr = run_nr
        self.info = load_dssc_info(proposal_nr, run_nr)
        self.run, _ = tb.load(proposal_nr, run_nr)
        self.binners = {}
        for name in binners:
            # add_binner validates the binner name
            self.add_binner(name, binners[name])
        self.xgm_name = xgm_name
        self.tim_names = tim_names
        self.dssc_coords_stride = dssc_coords_stride
        # populated lazily by load_xgm() / load_tim() / create_pulsemask()
        self.xgm = None
        self.tim = None
        self.pulsemask = None
        log.debug("Constructed DSSC object")
[docs]    def __del__(self):
        pass 
[docs]    def add_binner(self, name, binner):
        """
        Add additional binner to internal dictionary
        Parameters
        ----------
        name: str
            name of binner to be created
        binner: xarray.DataArray
            An array that represents a map how the respective coordinate should
            be binned.
        Raises
        ------
        ToolBoxValueError: Exception
            Raises exception in case the name does not correspond to a valid
            binner name. To be generalized.
        """
        if name in ['trainId', 'pulse', 'x', 'y']:
            self.binners[name] = binner
        else:
            msg = "Invalid binner name"
            log.info(msg+", no binner created")
            raise ToolBoxValueError(msg, name) 
[docs]    def load_xgm(self):
        """
        load xgm data and construct coordinate array according to corresponding
        dssc frame number.
        """
        self.xgm = get_xgm_formatted(self.run,
                                     self.xgm_name,
                                     self.dssc_coords_stride) 
[docs]    def load_tim(self):
        """
        load tim data and construct coordinate array according to corresponding
        dssc frame number.
        """
        self.tim = get_tim_formatted(self.proposal,
                                     self.runnr,
                                     self.tim_names,
                                     self.dssc_coords_stride) 
[docs]    def create_pulsemask(self, use_data='xgm', threshold=(0, np.inf)):
        """
        creates a mask for dssc frames according to measured xgm intensity.
        Once such a mask has been constructed, it will be used in the data
        reduction process to drop out-of-bounds pulses.
        """
        fpt = self.info['frames_per_train']
        n_trains = self.info['number_of_trains']
        trainIds = self.info['trainIds']
        data = np.ones([n_trains, fpt], dtype=bool)
        self.pulsemask = xr.DataArray(data,
                                      dims=['trainId', 'pulse'],
                                      coords={'trainId': trainIds,
                                              'pulse': range(fpt)})
        if use_data == 'xgm':
            if self.xgm is None:
                self.load_xgm()
            valid = (self.xgm > threshold[0]) * \
                    
(self.xgm < threshold[1])
        if use_data == 'tim':
            if self.tim is None:
                self.load_tim()
            valid = (self.tim > threshold[0]) * \
                    
(self.tim < threshold[1])
        self.pulsemask = \
            
(valid.combine_first(self.pulsemask).astype(bool))[:, 0:fpt]
        log.info(f'created pulse mask used during processing') 
[docs]    def get_info(self):
        """
        Returns the expected shape of the binned dataset, in case binners have
        been defined.
        """
        if any(self.binners):
            empty = create_empty_dataset(self.info, self.binners)
            return(empty.dims)
        else:
            log.info("no binner defined yet.")
            pass 
[docs]    def get_xgm_binned(self):
        """
        Bin the xgm data according to the binners of the dssc data. The result
        can eventually be merged into the final dataset by the DSSCFormatter.
        Returns
        -------
        xgm_data: xarray.DataSet
            xarray dataset containing the binned xgm data
        """
        if self.xgm is not None:
            xgm_data = self.xgm.to_dataset(name='xgm')
            xgm_binned = self._bin_metadata(xgm_data)
            log.info('binned xgm data according to dssc binners.')
            return xgm_binned
        else:
            log.warning("no xgm data. Use load_xgm() to load the xgm data.")
            pass 
[docs]    def get_tim_binned(self):
        """
        Bin the tim data according to the binners of the dssc data. The result
        can eventually be merged into the final dataset by the DSSCFormatter.
        Returns
        -------
        tim_data: xarray.DataSet
            xarray dataset containing the binned tim data
        """
        if self.tim is not None:
            tim_data = self.tim.to_dataset(name='tim')
            tim_binned = self._bin_metadata(tim_data)
            log.info('binned tim data according to dssc binners.')
            return tim_binned
        else:
            log.warning("no data. Use load_tim() to load the tim data.")
            pass 
    # -------------------------------------------------------------------------
    # Data processing
    # -------------------------------------------------------------------------
[docs]    def process_data(self, modules=[], filepath='./',
                     chunksize=512, backend='loky', n_jobs=None,
                     dark_image=None,
                     xgm_normalization=False, normevery=1
                     ):
        """
        Load and bin dssc data according to self.bins. No data is returned by
        this method. The condensed data is written to file by the worker
        processes directly.
        Parameters
        ----------
        modules: list of ints
            a list containing the module numbers that should be processed. If
            empty, all modules are processed.
        filepath: str
            the path where the files containing the reduced data should be
            stored.
        chunksize: int
            The number of trains that should be read in one iterative step.
        backend: str
            joblib multiprocessing backend to be used. At the moment it can be
            any of joblibs standard backends: 'loky' (default),
            'multiprocessing', 'threading'. Anything else than the default is
            experimental and not appropriately implemented in the dbdet member
            function 'bin_data'.
        n_jobs: int
            inversely proportional of the number of cpu's available for one
            job. Tasks within one job can grab a maximum of n_CPU_tot/n_jobs of
            cpu's.
            Note that when using the default backend there is no need to adjust
            this parameter with the current implementation.
        dark_image: xarray.DataArray
            DataArray with dimensions compatible with the loaded dssc data. If
            given, it will be subtracted from the dssc data before the binning.
            The dark image needs to be of dimension module, trainId, pulse, x
            and y.
        xgm_normalization: boolean
            if true, the dssc data is normalized by the xgm data before the
            binning.
        normevery: int
            integer indicating which out of normevery frame will be normalized.
        """
        log.info("Bin data according to binners")
        log.info(f'Process {chunksize} trains per chunk')
        mod_list = modules
        if len(mod_list) == 0:
            mod_list = [i for i in range(16)]
        log.info(f'Process modules {mod_list}')
        njobs = n_jobs
        if njobs is None:
            njobs = len(mod_list)
        module_jobs = []
        for m in mod_list:
            dark = dark_image
            if dark_image is not None:
                dark = dark_image.sel(module=m)
            module_jobs.append(dict(
                proposal=self.proposal,
                run_nr=self.runnr,
                module=m,
                chunksize=chunksize,
                path=filepath,
                info=self.info,
                dssc_binners=self.binners,
                pulsemask=self.pulsemask,
                dark_image=dark,
                xgm_normalization=xgm_normalization,
                xgm_mnemonic=self.xgm_name,
                normevery=normevery,
            ))
        log.info(f'using parallelization backend {backend}')
        joblib.Parallel(n_jobs=njobs, backend=backend)\
            
(joblib.delayed(process_dssc_data)(**module_jobs[i])
             for i in range(len(mod_list)))
        log.info(f'Binning done')