#############################################################################
# Author: parenti
# Created on April 5, 2024, 07:37 AM
# Copyright (C) European XFEL GmbH Schenefeld. All rights reserved.
#############################################################################
from karabo.middlelayer import (
    Configurable, DaqDataType, Device, EncodingType, Image, ImageData, Node,
    OutputChannel, UInt16)
from ._version import version as deviceVersion
def channelSchema(shape, encoding, dtype):
    """
    Build the schema class for an image output channel.

    The returned Configurable has a single node 'data', which in turn holds
    an 'image' element created from the given parameters.

    :param shape: the shape of the image, forwarded to Image
    :param encoding: the encoding of the image, forwarded to Image
    :param dtype: the data type of the pixels, forwarded to Image
    :return: the Configurable class describing the channel
    """
    # The descriptor is created once here and bound to DataNode below
    image_element = Image(
        dtype=dtype,
        encoding=encoding,
        shape=shape,
        displayedName="Image")

    class DataNode(Configurable):
        # Content of the 'data' node: DAQ data type and the image itself
        daqDataType = DaqDataType.TRAIN
        image = image_element

    class ChannelNode(Configurable):
        # Top level of the channel schema
        data = Node(DataNode)

    return ChannelNode
class ImageSource(Device):
    """
    Base class for image sources.

    It provides two output channels - 'output' and 'daqOutput' - for sending
    out images, and three functions - 'update_output_schema', 'write_channels'
    and 'signal_eos'.

    The function 'update_output_schema' will update the schema for the output
    channels and make it fit for the DAQ.

    The function 'write_channels' will write the input data to both the
    output channels, taking care of reshaping them for the DAQ.

    The function 'signal_eos' will send an end-of-stream signal to both the
    output channels.
    """
    # provide version for classVersion property
    __version__ = deviceVersion

    # Parameters of the placeholder schema used until
    # 'update_output_schema' is called with the real ones
    INITIAL_SHAPE = [1, 1]
    INITIAL_ENCODING = EncodingType.GRAY
    INITIAL_DTYPE = UInt16

    output = OutputChannel(
        channelSchema(INITIAL_SHAPE, INITIAL_ENCODING, INITIAL_DTYPE),
        displayedName="Output")

    # Second output channel for the DAQ: same image, reversed shape
    daq_shape = list(reversed(INITIAL_SHAPE))
    daqOutput = OutputChannel(
        channelSchema(daq_shape, INITIAL_ENCODING, INITIAL_DTYPE),
        displayedName="DAQ Output")

    def __init__(self, configuration):
        super().__init__(configuration)

    def get_current_shape(self):
        """
        Get the image shape currently set in the schema of 'output'

        :return: the 'defaultValue' of 'output.schema.data.image.dims',
            as a list
        """
        shape = self.getDeviceSchema().hash.getAttribute(
            'output.schema.data.image.dims', 'defaultValue')
        return list(shape)

    def get_current_encoding(self):
        """
        Get the image encoding currently set in the schema of 'output'

        :return: the 'defaultValue' of 'output.schema.data.image.encoding'
        """
        return self.getDeviceSchema().hash.getAttribute(
            "output.schema.data.image.encoding", "defaultValue")

    def get_current_type(self):
        """
        Get the pixel type currently set in the schema of 'output'

        :return: the 'defaultValue' of 'output.schema.data.image.pixels.type'
        """
        return self.getDeviceSchema().hash.getAttribute(
            "output.schema.data.image.pixels.type", "defaultValue")

    async def update_output_schema(self, shape, encoding, dtype):
        """
        Update the schema of 'output' and 'daqOutput' channels

        The schema is left untouched if shape, encoding and data type are
        all equal to the ones currently set.

        :param shape: the shape of image, e.g. (height, width)
        :param encoding: the encoding of the image. e.g. EncodingType.GRAY
        :param dtype: the data type, e.g. UInt16
        :return:
        """
        # 'get_current_shape' always returns a list, and a tuple never
        # compares equal to a list: convert before comparing, otherwise the
        # shortcut below is never taken for shapes like (height, width)
        requested_shape = list(shape)
        if (requested_shape == self.get_current_shape()
                and encoding == self.get_current_encoding()
                and dtype == self.get_current_type()):
            # Nothing to be done
            return

        output_schema = channelSchema(shape, encoding, dtype)
        # The DAQ channel gets the same image with the shape reversed
        daq_shape = list(reversed(shape))
        daq_output_schema = channelSchema(daq_shape, encoding, dtype)
        await self.setOutputSchema(
            "output", output_schema, "daqOutput", daq_output_schema)

    async def write_channels(
            self, data, binning=None, bpp=None, encoding=None,
            roi_offsets=None, timestamp=None):
        """
        Write an image to 'output' and 'daqOutput' channels

        The image is first written to 'output' as it is, then to 'daqOutput'
        with its shape reversed.

        :param data: the image data as numpy.ndarray
        :param binning: the image binning, e.g. (1, 1)
        :param bpp: the bits-per-pixel, e.g. 12
        :param encoding: the image encoding, e.g. EncodingType.GRAY
        :param roi_offsets: the ROI offset, e.g. (0, 0)
        :param timestamp: the image timestamp - if none the current timestamp
            will be used
        :return:
        """
        image_data = ImageData(
            data, binning=binning, bitsPerPixel=bpp, encoding=encoding,
            roiOffsets=roi_offsets)
        self.output.schema.data.image = image_data
        await self.output.writeData(timestamp)

        # Reshape image for DAQ
        # NB DAQ wants shape in CImg order, eg (width, height)
        # This is a reshape and not a transpose: only the order of the
        # dimensions is reversed
        data = data.reshape(*reversed(data.shape))
        image_data = ImageData(
            data, binning=binning, bitsPerPixel=bpp, encoding=encoding,
            roiOffsets=roi_offsets)
        self.daqOutput.schema.data.image = image_data
        await self.daqOutput.writeData(timestamp)

    async def signal_eos(self):
        """
        Send an end-of-stream signal to 'output' and 'daqOutput' channels

        :return:
        """
        await self.output.writeEndOfStream()
        await self.daqOutput.writeEndOfStream()