"""
Defines class CleftRegions that holds cleft-related data (geometry and layers)
from one or more observations (experiments) divided (classified) in groups.

The observations are expected to be generated by scripts/cleft.py
or classify_connections.py.

# Author: Vladan Lucic (Max Planck Institute for Biochemistry)
# $Id$
"""
from __future__ import unicode_literals
from __future__ import absolute_import
from __future__ import division
from builtins import range
#from past.utils import old_div

__version__ = "$Revision$"


import warnings
import logging
from copy import copy, deepcopy

import numpy
import scipy

import pyto
from ..util import nested
from .observations import Observations
from .groups import Groups


class CleftRegions(Groups):
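    """
    Holds cleft-related data (geometry and layers) from one or more
    observations (experiments) classified in groups.

    Modes:
      - 'layers': layers made on the whole cleft region
      - 'layers_cleft': layers made on segments detected in the cleft
      - 'layers_on_columns': presumably layers made on columns (this mode 
      is accepted by __init__() and read())
      - 'columns': columns made on the whole cleft region
      - None: no properties are predefined
    """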

    ###############################################################
    #
    # Initialization
    #
    ##############################################################

    def __init__(self, mode=None):
        """
        Initializes attributes.

        Defines following attributes:
          - _mode: (str or None) the mode given by arg mode
          - _deep: (str) passed to pyto.util.attributes.get_deep_name() to 
          convert full property names to attribute names
          - _full_properties: (set of str) names of properties read from the 
          pickles, may include '.' (if a desired property is an attribute of 
          a class that is saved as an attribute of pickled CleftRegions object)
          - _properties: (set of str) attribute names of this instance, where 
          the above properties are to be stored
          - _full_indexed: (set of str) names of indexed properties of the 
          pickles, may include '.'
          - _indexed: (set of str) attribute names of this instance that 
          correspond to the above indexed properties 

        Argument:
          - mode: 'layers', 'layers_cleft', 'layers_on_columns', 'columns'
          or None (in which case no properties are predefined)

        Raises ValueError if arg mode is not one of the above values.
        """

        # initialize super 
        super(CleftRegions, self).__init__()

        # determines the conversion of property names
        self._deep = 'last'

        # mode
        self._mode = mode

        # property groups shared between modes
        density = [
            'regions.ids', 'regionDensity.mean', 'regionDensity.std', 
            'regionDensity.min', 'regionDensity.max', 
            'regionDensity.volume']
        min_density = [
            'minCleftDensityId', 'minCleftDensityPosition', 
            'relativeMinCleftDensity']
        geometry = ['width', 'widthVector.phiDeg', 'widthVector.thetaDeg']

        # definitions used in read()
        if (mode == 'layers') or (mode == 'layers_cleft'):

            # layers and layers_cleft modes
            self._full_properties = set(
                density + min_density + geometry + ['boundThick'])
            self._full_indexed = set(density)

        elif mode == 'layers_on_columns':

            # layers_on_columns mode: like layers, but without cleft geometry
            self._full_properties = set(
                density + min_density + ['boundThick'])
            self._full_indexed = set(density)

        elif mode == 'columns':

            # columns mode
            self._full_properties = set(density)
            self._full_indexed = set(density)

        elif mode is None:

            self._full_properties = set([])
            self._full_indexed = set([])

        else:

            # fail here rather than with an AttributeError further down
            raise ValueError(
                "Argument mode: " + str(mode) + " not understood. "
                + "Acceptable values are 'layers', 'layers_cleft', "
                + "'layers_on_columns', 'columns' and None.")

        # get attribute names corresponding to the full property names
        self._properties = set(
            [pyto.util.attributes.get_deep_name(attr, mode=self._deep)
             for attr in self._full_properties])
        self._indexed = set(
            [pyto.util.attributes.get_deep_name(attr, mode=self._deep)
             for attr in self._full_indexed])


    ###############################################################
    #
    # Input
    #
    ##############################################################

    @classmethod
    def read(cls, files, catalog, mode=None, reference=None, 
             categories=None, order=None):
        """
        Reads one or more scene.CleftRegions pickles specified by arg files and
        puts all of them in a new instance of this class.

        Each pickle contains data from a single experiment. The properties
        read are specified by attribute self._full_properties.

        In addition, reads other data corresponding to the experiments from 
        arg catalog and puts them together with the data from pickles. The
        only obligatory property is 'pixel_size'.

        If ids for an observation is None, all indexed properties (specified 
        in self._indexed) for that observation are set to empty arrays.

        Arg reference is used only in modes 'layers_on_columns' and 
        'columns'. It specifies another instance of this class that should 
        be used for the density normalization by normalizeByMean(mode='0to1').

        A warning is generated for a category specified by arg categories 
        that does not exist in the data (arg files), but reading of that 
        category is still attempted. An experiment identifier specified by 
        arg order that does not exist in the data is ignored and a warning 
        is generated. These conditions often generate an exception at a 
        later point.

        Arguments:
          - files: dictionary of cleft regions result pickle files
          - catalog: (Catalog) data about experiments
          - mode: cleft regions mode, 'layers', 'layers_cleft', 'columns',
          or 'layers_on_columns'
          - reference: another instance of this class used for normalization
          (see above)
          - categories: list of categories
          - order: another Groups instance (or just a dictionary with group 
          names as keys and identifier lists as values), used only to define 
          the order of identifiers in this instance

        Sets properties:
          - identifiers: identifiers
          - ids: ids
          - width, width_nm: cleft width in pixels and nm
          - phiDeg, thetaDeg: cleft orientation (angles phi, theta in degrees)
          - minCleftDensityId: id of min density cleft layer
          - minCleftDensityPosition: relative position of the min density
          cleft layer (1/n_cleft_layers for the cleft layer closest to the 
          first boundary, 1 - 1/n_cleft_layers for the cleft layer closest
          to the second boundary
          - relativeMinCleftDensity: relative layer density of the cleft layer
          with min density (0 if the same as mean boundary density, 1 if the 
          same as mean cleft density)
          - mean/std/min/max: layer density mean/std/min/max
          - volume, volume_nm: layer density volume in pixels^3, nm^3
          - cleftIds: list of layer ids that belong to the cleft
          - boundIds: list of layer ids read from attribute boundLayerIds 
          of the pickled object (presumably both boundaries)
          - bound1Ids: list of layer ids that belong to the boundary 1
          - bound2Ids: list of layer ids that belong to the boundary 2
          - boundThick: boundary thickness (in number of layers)
          - angleToYDeg: absolute value of the angle between the cleft (phi,
          theta is assumed to be 90 deg) and the y axis
          - all properties set in catalog files
          - normalMean: normalized mean density. If mode is 'layers' the
          mean density is normalized so that the mean of boundary values is 0
          and the mean of cleft values is 1. If mode is 'layers_cleft', the
          absolute normalization is used, that is the mean cleft density is
          subtracted. In modes 'layers_on_columns' and 'columns' it is set 
          only if arg reference is given, using the boundary and cleft means 
          of the reference.
          - normalVolume: (mode 'layers_cleft' only) volume normalized so 
          that the mean of boundary values is 0 and the mean of cleft values 
          is 1

        Raises TypeError if arg order is given and its value for a category 
        is neither an Observations instance nor a list or tuple.

          ToDo: remove non-cleft layers in layers_cleft mode?
          """

        # initialize
        # NOTE(review): the name suggests that pickles are loaded here; they 
        # should come from a trusted source only
        db = pyto.io.Pickled(files)
        inst = cls(mode=mode)

        # categories existing in the data
        existing = list(db.categories())

        # use all categories if not specified
        if categories is None:
            categories = list(existing)

        # loop over categories
        for categ in categories:

            # check if data for the current category exist 
            logging.debug('CleftRegions: Reading group ' + categ)
            if categ not in existing:
                logging.warning(
                    'CleftRegions: Data for group ' + categ + ' do not exist')

            # make sure the identifier order is the same; reset for each
            # category so that identifiers of a previous one are never reused
            identifier = None
            if order is not None:
                categ_order = order[categ]
                if isinstance(categ_order, Observations):
                    identifier = categ_order.identifiers
                elif isinstance(categ_order, (list, tuple)):
                    identifier = categ_order
                else:
                    raise TypeError(
                        "Argument order for group " + categ + " has to be "
                        + "an Observations instance, a list or a tuple.")

            # check if requested identifiers exist in the database
            if identifier is not None:
                clean = []
                for requested in identifier:
                    if requested in db.identifiers():
                        clean.append(requested)
                    else:
                        logging.warning(
                            'CleftRegions: Data for experiment ' + requested + 
                            ' do not exist')
                identifier = clean

            # get data; the generator keeps updating the same observ object
            observ = Observations()
            for observ, obj, categ_tmp, name_tmp in db.readPropertiesGen(
                category=categ, identifier=identifier, deep=inst._deep, 
                properties=inst._full_properties, index='regions.ids', 
                indexed=inst._full_indexed, multi=observ):

                logging.debug('Read data of experiment ' + name_tmp) 

                # extract cleft and boundary ids
                if (mode == 'layers') or (mode == 'layers_cleft'):
                    observ.setValue(property='cleftIds', identifier=name_tmp, 
                                    value=obj.cleftLayerIds)
                    observ.setValue(property='boundIds', identifier=name_tmp, 
                                    value=obj.boundLayerIds)
                    observ.setValue(property='bound1Ids', identifier=name_tmp, 
                                    value=obj.bound1LayerIds)
                    observ.setValue(property='bound2Ids', identifier=name_tmp, 
                                    value=obj.bound2LayerIds)
                elif mode == 'layers_on_columns':
                    observ.setValue(property='cleftIds', identifier=name_tmp, 
                                    value=obj.cleftLayerIds)

            # add data for this category
            inst[categ] = observ

            # set array properties to empty arrays for observations without ids
            for obs_index in range(len(inst[categ].identifiers)):
                if inst[categ].ids[obs_index] is None:
                    for name in inst._indexed:
                        value = getattr(inst[categ], name)
                        value[obs_index] = numpy.array([])

            # set book-keeping attributes
            inst[categ].index = 'ids'
            inst[categ].indexed.update(inst._indexed)

            # add properties from catalog 
            inst[categ].addCatalog(catalog=catalog)

        # calculate additional data properties
        inst.calculateProperties()

        # convert to nm
        if mode is not None:
            inst.convertToNm(catalog=catalog)

        # calculate mode dependent data properties
        if mode == 'layers':
            inst.normalizeByMean(name='mean', region=['bound', 'cleft'],
                                 mode='0to1', categories=categories)
        elif (mode == 'layers_on_columns') and (reference is not None):
            inst.normalizeByMean(
                name='mean', region=['bound', 'cleft'], mode='0to1', 
                categories=categories, reference=reference)
        elif mode == 'layers_cleft':
            inst.normalizeByMean(name='mean', mode='absolute', region='cleft',
                                 categories=categories)
            inst.normalizeByMean(name='volume', region=['bound', 'cleft'],
                                 mode='0to1', categories=categories)
        elif (mode == 'columns') and (reference is not None):
            inst.normalizeByMean(
                name='mean', region=['bound', 'cleft'], mode='0to1', 
                categories=categories, reference=reference)

        return inst


    ###############################################################
    #
    # Data modifying methods
    #
    ##############################################################

    def calculateProperties(self, categories=None):
        """
        Calculates additional properties. 

        Sets following properties to each group (Observations instance)
        contained in this object:
          - angleToYDeg: absolute value of the angle between the cleft (phi,
          theta is assumed to be 90 deg) and the y axis; set only if the 
          group has property 'phiDeg'
          - minCleftDensityId, minCleftDensityPosition: (modes 'layers', 
          'layers_cleft' and 'layers_on_columns', only if the group has 
          property 'minCleftDensityId') reduced to a single value for each 
          experiment. If an experiment has more than one value, a warning
          is logged and the mean is taken.

        Experiments for which minCleftDensityId is None are left unchanged. 

        Values that are already scalars are accepted, so this method can
        be called more than once on the same data.

        Argument:
          - categories: list of group names, if None all groups are used
        """

        if categories is None:
            categories = list(self.keys())

        # modes in which min cleft density properties are reduced
        layer_mode = (
            (self._mode == 'layers') or (self._mode == 'layers_cleft') 
            or (self._mode == 'layers_on_columns'))

        for categ in categories:
            for ident in self[categ].identifiers:

                # angle with y axis
                if 'phiDeg' in self[categ].properties:
                    phi = self[categ].getValue(
                        identifier=ident, property='phiDeg')
                    alpha = numpy.abs(numpy.mod(phi, 180) - 90)
                    self[categ].setValue(identifier=ident, 
                                         property='angleToYDeg', value=alpha)

                # cleft density position
                if not (layer_mode 
                        and ('minCleftDensityId' in self[categ].properties)):
                    continue
                min_id = self[categ].getValue(
                    identifier=ident, property='minCleftDensityId')
                min_pos = self[categ].getValue(
                    identifier=ident, property='minCleftDensityPosition')
                if min_id is None:
                    continue

                # scalars (e.g. set by a previous call) become 1-element 
                # arrays, so that len() and indexing below are valid
                min_id = numpy.atleast_1d(min_id)
                min_pos = numpy.atleast_1d(min_pos)

                if len(min_id) > 1:
                    logging.warning(
                        "Experiment " + ident + " of group " + categ + " has"
                        + " more than one cleft density minimum position. "
                        + "Taking the mean value.")
                    self[categ].setValue(
                        identifier=ident, value=min_id.mean(),
                        property='minCleftDensityId')
                    self[categ].setValue(
                        identifier=ident, value=min_pos.mean(),
                        property='minCleftDensityPosition')
                else:
                    self[categ].setValue(identifier=ident, value=min_id[0],
                                         property='minCleftDensityId')
                    self[categ].setValue(identifier=ident, value=min_pos[0],
                                         property='minCleftDensityPosition')

    def normalizeByMean(self, name, normalName=None, mode='relative', 
                        region=None, ids=None, reference=None, categories=None):
        """
        Normalizes indexed property specified by arg name based on mean
        value(s) of subset(s) of that property values. 

        Args region and ids determine how the mean(s) is (are) calculated.
        If arg region is given, the mean(s) of the values corresponding to that 
        region(s) are used for the normalization. If arg region is 
        None, the values corresponding to ids are used to calculate the mean(s).
        If both region and ids are None, all ids (property 'ids') are used.

        Arg mode determines how the normalization is performed. If it is 
        'absolute', the absolute difference between the property values and 
        the mean (values - mean) is calculated. If it is 'relative', the 
        relative difference ((values - mean) / mean) is calculated.

        If arg mode is '0to1', two regions need to be specified. This can be 
        done by specifying two regions in arg region (list of length 2), or by
        specifying arg ids as a list of length 2 where each element of ids is 
        a list or an ndarray. Two mean values (for each element of region or 
        ids) are calculated (called mean_0 and mean_1). The values are 
        normalized by:

          (values - mean_0) / (mean_1 - mean_0)

        Specifying arg reference allows the use of another object to calculate 
        mean_0 and mean_1. For example if reference=[object_a, object_b], 
        object_a will be used to calculate mean_0 and object_b for mean_1. A
        reference object has to be an instance of this class and it needs to 
        have data corresponding to regions or ids specified in args. If 
        one of these objects is None, this object will be used instead. If 
        reference=None, this object will be used for both references. If
        reference is a single object, it is used for both references. In
        modes 'relative' and 'absolute' only the first reference is used.

        The normalized values are saved with name given by arg normalName. If
        this arg is None, the new name is 'normal' + name.capitalize()

        Arguments:
          - name: property name
          - normalName: name of the normalized property
          - mode: normalization mode, 'absolute', 'relative' or '0to1'
          - region: specifies regions used to calculate the mean, currently 
          'cleft', 'bound', or 'cleft&bound'
          - ids: directly specifies ids used to calculate the mean
          - reference: specifies other object(s) (of this class) to be used 
          for the calculation of mean values. None to use this object.
          - categories: categories (groups), if None all groups are used

        Sets property containing normalized values.

        Raises ValueError if arg mode or an element of arg region is not 
        one of the acceptable values.
        """

        # check mode before anything is calculated
        if mode not in ('relative', 'absolute', '0to1'):
            raise ValueError(
                "Argument mode: " + str(mode) + " not understood. "
                + "Acceptable values are 'relative', 'absolute' "
                + "and '0to1'.")

        # one mean is needed in relative and absolute modes, two in 0to1
        single = (mode != '0to1')
        if single:
            region_indices = [0]
        else:
            region_indices = [0, 1]

        if categories is None:
            categories = list(self.keys())

        # figure out if regions specified by ids (or by region) arg
        by_ids = region is None
        all_ids = by_ids and (ids is None)

        # put region or ids in a list
        if single:
            if by_ids:
                ids = [ids]
            else:
                region = [region]

        # figure out references
        if reference is None:
            ref_obj = [self, self]
        elif isinstance(reference, (list, tuple)):
            ref_obj = [None, None]
            if reference[0] is None:
                ref_obj[0] = self
            else:
                ref_obj[0] = reference[0]
            if reference[1] is None:
                ref_obj[1] = self
            else:
                ref_obj[1] = reference[1]
        else:
            ref_obj = [reference, reference]

        # name of the normalized property
        if normalName is None:
            normalName = 'normal' + name.capitalize()

        # normalize
        for categ in categories:
            for ident in self[categ].identifiers:

                # find ids, unless they were given directly by arg ids
                if (not by_ids) or all_ids:
                    ids = []
                    for index in region_indices:
                        ref_group = ref_obj[index][categ]

                        if by_ids:

                            # all ids 
                            # NOTE(review): ids are wrapped in a list here 
                            # but not for regions - confirm that getValue() 
                            # expects that
                            ids.append([ref_group.getValue(
                                property='ids', identifier=ident)])

                        elif region[index] == 'cleft':
                            ids.append(ref_group.getValue(
                                property='cleftIds', identifier=ident))

                        elif region[index] == 'bound':
                            ids.append(ref_group.getValue(
                                property='boundIds', identifier=ident))

                        elif region[index] == 'cleft&bound':
                            cleft_ids = ref_group.getValue(
                                property='cleftIds', identifier=ident)
                            bound_ids = ref_group.getValue(
                                property='boundIds', identifier=ident)

                            # arrays have to be passed as one sequence, the 
                            # second positional arg of concatenate is axis
                            ids.append(
                                numpy.concatenate((cleft_ids, bound_ids)))

                        else:
                            raise ValueError(
                                "Argument region not understood. Acceptable "
                                + "values are None, 'cleft', 'bound' and "
                                + "'cleft&bound'.")

                # get all values
                values = self[categ].getValue(property=name, identifier=ident)

                # mean of the first (or the only) region
                values_0 = ref_obj[0][categ].getValue(
                    property=name, identifier=ident, ids=ids[0])
                mean_0 = values_0.mean()

                # normalize
                if mode == 'relative':
                    normalized = (values - mean_0) / float(mean_0)
                elif mode == 'absolute':
                    normalized = values - mean_0
                else:
                    values_1 = ref_obj[1][categ].getValue(
                        property=name, identifier=ident, ids=ids[1])
                    mean_1 = float(values_1.mean())
                    normalized = (values - mean_0) / (mean_1 - mean_0)

                # set normalized
                self[categ].setValue(
                    property=normalName, identifier=ident, value=normalized,
                    indexed=True)
                   
    def getRelative(self, fraction, new, name='mean', region=None, ids=None, 
                    weight=None, categories=None):
        """
        Calculates a value relative to two reference values. The reference 
        values are obtained from property given as arg name at regions 
        specified by arg region. The new value is saved as a new property 
        named (arg) new.

        The new value is calculated using arg fraction as follows:

          region_0 + fraction * (region_1 - region_0)

        where region_0 and region_1 are the reference values, that is the 
        values of property specified by arg name.

        If arg weight is None, the references are calculated as a simple mean
        of the values of property name for all ids comprising the 
        corresponding regions. Otherwise, arg weight should be the name of
        the property used to weight the average. For example, mean greyscale
        density may be weighted by volume.

        Typically used to find values between cleft and boundary densities.

        If arg region is None, regions are specified by arg ids. If both are
        None, all ids (property 'ids') are used for both regions.

        Arguments:
          - fraction: fraction
          - new: name of the newly calculated property
          - name: name of the property used as a reference
          - region: list of two regions, each 'cleft', 'bound' or 
          'cleft&bound'
          - ids: list of length 2 where each element is a list of ids
          - weight: weight used to calculate the mean region values
          - categories: categories, if None all groups are used

        Sets:
          - property new

        Raises ValueError if an element of arg region is not one of the
        acceptable values.
        """

        # set categories if not specified
        if categories is None:
            categories = list(self.keys())

        # figure out if regions specified by ids (or by region) arg
        by_ids = region is None
        all_ids = by_ids and (ids is None)

        for categ in categories:

            for ident in self[categ].identifiers:

                # find ids, unless they were given directly by arg ids
                if (not by_ids) or all_ids:
                    ids = []
                    for index in [0, 1]:

                        if by_ids:

                            # all ids
                            # NOTE(review): ids are wrapped in a list here 
                            # but not for regions - confirm that getValue() 
                            # expects that
                            ids.append([self[categ].getValue(
                                property='ids', identifier=ident)])

                        elif region[index] == 'cleft':
                            ids.append(self[categ].getValue(
                                property='cleftIds', identifier=ident))

                        elif region[index] == 'bound':
                            ids.append(self[categ].getValue(
                                property='boundIds', identifier=ident))

                        elif region[index] == 'cleft&bound':
                            cleft_ids = self[categ].getValue(
                                property='cleftIds', identifier=ident)
                            bound_ids = self[categ].getValue(
                                property='boundIds', identifier=ident)

                            # arrays have to be passed as one sequence, the 
                            # second positional arg of concatenate is axis
                            ids.append(
                                numpy.concatenate((cleft_ids, bound_ids)))

                        else:
                            raise ValueError(
                                "Argument region not understood. Acceptable "
                                + "values are None, 'cleft', 'bound' and "
                                + "'cleft&bound'.")

                # get values
                values_0 = self[categ].getValue(
                    property=name, identifier=ident, ids=ids[0])
                values_1 = self[categ].getValue(
                    property=name, identifier=ident, ids=ids[1])

                # get means
                if weight is None:
                    mean_0 = values_0.mean()
                    mean_1 = float(values_1.mean())
                else:
                    weight_0 = self[categ].getValue(
                        property=weight, identifier=ident, ids=ids[0])
                    weight_1 = self[categ].getValue(
                        property=weight, identifier=ident, ids=ids[1])
                    mean_0 = (
                        (values_0 * weight_0).sum() / float(weight_0.sum()))
                    mean_1 = (
                        (values_1 * weight_1).sum() / float(weight_1.sum()))

                # set 
                value = mean_0 + (mean_1 - mean_0) * fraction
                self[categ].setValue(property=new, identifier=ident, 
                                     value=value)

    def convertToNm(self, catalog, categories=None):
        """
        Converts pixel-based properties of each group to nm and stores the
        results as properties having suffix '_nm'.

        The conversion factor for a group is catalog.pixel_size[group name].

        Sets properties:
          - width_nm: from 'width' (modes 'layers' and 'layers_cleft')
          - volume_nm: from 'volume' using power 3 (all modes), indexed
          - surface_nm: from 'volume' using power 2 (modes 'layers', 
          'layers_cleft' and 'layers_on_columns'), indexed. A TypeError
          raised during this conversion is ignored if the volume of the
          group, or any of its elements, is None.

        Arguments:
          - catalog: object having attribute pixel_size that is indexed by 
          group names
          - categories: list of group names, if None all groups are used
        """

        group_names = categories
        if group_names is None:
            group_names = list(self.keys())

        has_width = self._mode in ('layers', 'layers_cleft')
        has_surface = has_width or (self._mode == 'layers_on_columns')

        for categ in group_names:
            pixel = catalog.pixel_size
            group = self[categ]

            # width
            if has_width:
                group.width_nm = group.pixels2nm(
                    name='width', conversion=pixel[categ])
                group.properties.add('width_nm')

            # volume
            group.volume_nm = group.pixels2nm(
                name='volume', power=3, conversion=pixel[categ])
            group.properties.add('volume_nm')
            group.indexed.add('volume_nm')

            # surface
            if not has_surface:
                continue
            try:
                group.surface_nm = group.pixels2nm(
                    name='volume', conversion=pixel[categ], power=2)
                group.properties.add('surface_nm')
                group.indexed.add('surface_nm')
            except TypeError:
                # tolerated only when caused by missing volume data
                volume = self[categ].volume
                if (volume is not None) and not any(
                        value is None for value in self[categ].volume):
                    raise
                    
    def getBoundarySurfaces(self, names, surface='surface', categories=None, 
                            factor=1):
        """
        Finds surfaces of the two boundary layers that are adjacent to the
        cleft layers. Does nothing unless the mode is 'layers' or 
        'layers_cleft'.

        The adjacent layers are those having ids one below the smallest and
        one above the largest cleft layer id. These have to agree with the
        largest id of 'bound1Ids' and the smallest id of 'bound2Ids',
        respectively.

        Layer surfaces are read from the property named (arg) surface and
        multiplied by (arg) factor to get the final values.

        Arguments:
          - names: (list of 2 strings) property names where the values of the 
          two calculated surfaces are saved
          - surface: property name that contains layer surfaces
          - categories: categories, if None all groups are used
          - factor: multiplicative factor

        Sets:
          - properties named by elements of the arg names, where names[0] is 
          used for the boundary layer with smaller index

        Raises ValueError if the ids obtained from the cleft layer ids and
        from the boundary layer ids differ.
        """

        # only layers and layers_cleft modes are handled
        if self._mode not in ('layers', 'layers_cleft'):
            return

        # use all groups by default
        group_names = categories
        if group_names is None:
            group_names = list(self.keys())

        for categ in group_names:
            group = self[categ]
            for ident in group.identifiers:

                # layers just outside of the cleft 
                cleft_layers = group.getValue(
                    identifier=ident, property='cleftIds')
                below_id = min(cleft_layers) - 1
                above_id = max(cleft_layers) + 1

                # cross-check with the boundary layer ids
                last_bound1 = group.getValue(
                    identifier=ident, property='bound1Ids').max()
                if last_bound1 != below_id:
                    raise ValueError("The layer index of the last bound1 layer"
                                     + " could not be determined.")
                first_bound2 = group.getValue(
                    identifier=ident, property='bound2Ids').min()
                if first_bound2 != above_id:
                    raise ValueError("The layer index of the first bound2 layer"
                                     + " could not be determined.")

                # get both surfaces before any value is set
                scaled_1 = group.getValue(
                    property=surface, ids=below_id, identifier=ident) * factor
                scaled_2 = group.getValue(
                    property=surface, ids=above_id, identifier=ident) * factor

                # set values
                group.setValue(
                    property=names[0], value=scaled_1, identifier=ident)
                group.setValue(
                    property=names[1], value=scaled_2, identifier=ident)

