from cbrule import cage_break

class checkcb(object):
    """Identify atoms which undergo a cage break (CB) between two configurations of the system.

    This is a manager class: the rule for deciding whether an individual atom has undergone a CB
    is defined separately as a cbrule object and passed to this class. A checkneighs object is
    used to identify the nearest neighbours of every particle.
    If reversals are to be analysed, this is done by attaching a reversal_manager object to the
    cbrule (as its ``reversal_manager`` attribute); see ``flush``.

    Configurations may be analysed as discrete pairs, ``cbs = obj(x0, x1)``, or as a chain: call
    ``start_chain(x0)`` once and then ``obj(x_prev, x_next)`` for each subsequent configuration.
    In chain mode the neighbours of the previous configuration are re-used instead of recomputed.

    Parameters
    ----------
    checkneighs: checkneighs object
        A callable which takes a configuration and returns a dict mapping atom indices to the
        nearest neighbours of that atom.
    cb_rule: cbrule object
        A callable, invoked as ``cb_rule(i, x0, x1, old_neighbours, new_neighbours, time)``, which
        returns a cage_break object if atom ``i`` has undergone a CB and None otherwise. If it has
        a ``reversal_manager`` attribute that is not None, ``flush`` drains that manager's
        ``live_chains``.

    Attributes
    ----------
    chain: bool
        Whether configurations are being passed in as part of a chain (set by ``start_chain``)
        rather than as discrete pairs.
    oldneighs: dict
        Chain mode only: the neighbour dict of the most recently processed configuration.
    """

    def __init__(self, checkneighs, cb_rule):
        self.checkneighs = checkneighs
        self.cb_rule = cb_rule

        self.chain = False  # Becomes True once start_chain() has been called

    def start_chain(self, x0):
        """Indicate that configurations will be added one at a time, as a chain.

        Each subsequent call compares the new configuration against the previously saved one.
        For this first configuration we only compute and store the neighbours: no CB can be
        detected yet.
        """
        self.chain = True
        self.oldneighs = self.checkneighs(x0)

    def __call__(self, x0, x1, time=0):
        """Determine which atoms undergo cage breaks between configurations x0 and x1.

        Parameters
        ----------
        x0, x1: configurations
            The old and new configurations. In chain mode the neighbours of x0 are not recomputed;
            the neighbours saved by the previous call (or by ``start_chain``) are used instead,
            but x0 itself is still passed on to the cb_rule.
        time: optional, default 0
            Label for this step; stored under the 'frame' key and passed to the cb_rule.

        Returns
        -------
        cbs: dict
            Always contains the key 'frame' (mapped to ``time``). Every other key is the index of
            an atom which underwent a CB, mapped to the cage_break object returned by the cb_rule.
            If no CBs took place only the 'frame' key is present; None is never returned.

        Raises
        ------
        KeyError
            If an atom that has neighbours in the old configuration has no entry in the
            neighbour dict of the new configuration.
        """
        cbs = {'frame': time}

        # Neighbour dict for the new configuration
        neighs = self.checkneighs(x1)
        if self.chain:  # The old configuration's neighbours were saved by the previous call
            oldneighs = self.oldneighs
        else:  # Discrete pair: the old configuration's neighbours must be computed too
            oldneighs = self.checkneighs(x0)

        # Test every particle which has neighbours in the old configuration
        for i in oldneighs.keys():
            cb = self.cb_rule(i, x0, x1, oldneighs[i], neighs[i], time)
            if cb is not None:  # A CB took place for this atom: record it
                cbs[i] = cb

        if self.chain:
            # x1 becomes the "old" configuration for the next call in the chain.
            self.oldneighs = neighs.copy()

        return cbs

    def flush(self):
        """Return the cage breaks still held by the cb_rule's reversal manager.

        When a reversal manager is attached, the cb_rule may not return every cage break that
        takes place (those belonging to unterminated reversal chains are held back). Call this
        once all configurations have been passed in so that those cage breaks are counted too.

        Returns
        -------
        None if the cb_rule has no reversal manager. Otherwise a dict mapping 'frame' to
        'terminate', plus every (key, cage_break) pair popped from the manager's ``live_chains``,
        which is left empty.
        """
        manager = getattr(self.cb_rule, 'reversal_manager', None)
        if manager is None:  # All CBs have already been returned: nothing to flush
            return None

        cbs = {'frame': 'terminate'}
        live_chains = manager.live_chains
        # Extract the final cage break from each unterminated chain.
        while live_chains:
            i, cb = live_chains.popitem()
            cbs[i] = cb
        return cbs

class check_for_single_cb(checkcb):
    """A checkcb variant which stops searching as soon as one cage break has been found.

    Rather than testing every atom, each call returns after the first atom for which the cb_rule
    reports a cage break. This is intended for use with stationary-point databases, to assess how
    the connectivity of the landscape is influenced by cage breaks.

    Parameters
    ----------
    checkneighs: checkneighs object
        A callable which takes a configuration and identifies the nearest neighbours of all atoms
        in the system.
    cb_rule: cbrule object
        An object which takes two lists of neighbours for an atom and identifies whether the atom
        has undergone a cage break.

    Attributes
    ----------
    chain: logical
        Whether configurations are passed in as part of a chain, or as discrete pairs.

    Returns
    -------
    cbs: dict
        'frame' mapped to the time label, plus at most one further entry: the index of the first
        atom found to have undergone a cage break, mapped to its cage_break object.
    """
    def __call__(self, x0, x1, time=0):
        """Report whether any atom undergoes a cage break between configurations x0 and x1.

        Atoms are tested in the order of the old neighbour dict's keys and testing stops at the
        first cage break. The returned dict holds 'frame' plus, if one was found, that single atom.
        """
        new_neighbours = self.checkneighs(x1)
        # In chain mode the previous call (or start_chain) has already stored the old neighbours.
        old_neighbours = self.oldneighs if self.chain else self.checkneighs(x0)

        found = {'frame': time}
        for atom in old_neighbours.keys():
            outcome = self.cb_rule(atom, x0, x1, old_neighbours[atom], new_neighbours[atom], time)
            if outcome is not None:
                found[atom] = outcome
                break  # One cage break is all we need to know about

        if self.chain:
            # Keep x1's neighbours for comparison with the next configuration.
            self.oldneighs = new_neighbours.copy()
        return found

class atomistic_checkcb(checkcb):
    """A version of checkcb designed for (rigid-body) molecules.

    Neighbours and cage breaks are tracked independently for every site (atom) of every molecule.
    A molecular cage break occurs only when all atoms in the molecule undergo a site cage break
    in the same step. This class incorporates the functionality of checkcb and check_for_single_cb.

    Atoms are assumed to be numbered molecule by molecule: atom ``i*napm + j`` is site ``j`` of
    molecule ``i``, where ``napm`` (atoms per molecule) is inferred from the coordinate array.

    Parameters
    ----------
    checkneighs: checkneighs object
        A callable which takes a configuration and identifies the nearest neighbours of all atoms
        in the system. Although this class is intended for molecular systems, the neighbours must
        be returned atom by atom rather than molecule by molecule. It must also provide a
        ``convert.to_at`` method which turns a configuration into atomistic coordinates
        (3 per atom).
    cb_rule: cbrule object
        An object which takes two lists of neighbours for an atom and identifies whether that
        atom has undergone a (site) cage break. If it has a reversal manager, that manager's
        ``term_only`` flag is set to False.
    nmol: integer
        The number of molecules in the system.
    molecular_reversal_manager: object, default None
        An object which uses records of site cbs to determine whether a newly-identified molecular
        cage break is in fact a reversal of the previous cb. It is called with each molecular
        cage_break and returns the cage_break to report, or None. Default is to report all
        identified cbs.
    first_only: logical, default False
        When True, each call stops the first time it identifies a molecular cage break (the
        molecular reversal manager is not consulted in that case).

    Attributes
    ----------
    chain: logical
        Specifies whether configurations being passed in are given as part of a chain, or as
        discrete pairs of configurations.
    mol_rev_manager: object or None
        The molecular_reversal_manager passed to the constructor.
    """
    def __init__(self, checkneighs, cb_rule, nmol, molecular_reversal_manager=None, first_only=False):
        super(atomistic_checkcb, self).__init__(checkneighs, cb_rule)
        self.nmol = nmol
        self.first_only = first_only

        site_manager = getattr(self.cb_rule, 'reversal_manager', None)
        if site_manager is not None:
            # NOTE(review): presumably term_only=False makes the site-level manager report every
            # site cb (molecular detection needs all of them) — confirm against reversal_manager.
            site_manager.term_only = False

        # Because the cbrule reversal manager keeps track only of atomic CBs, we need a separate
        # manager to determine when the whole molecule has reversed.
        self.mol_rev_manager = molecular_reversal_manager

        self.live_chains = {}

    def __call__(self, x0, x1, time=0):
        """Identify molecules which undergo a cage break between the configurations x0 and x1.

        If the coordinates are in CoM+AA form, the checkneighs object must have a converter which
        knows to convert these coordinates to cartesian atomistic format. If the coordinates are
        already atomistic, checkneighs.convert must know not to convert them. Recall that
        checkneighs must return the neighbour dicts for every atom, not just every molecule.

        Returns
        -------
        cbs: dict
            Always contains 'frame' (mapped to ``time``). Every other key is the index of a
            molecule which underwent a cage break, mapped to a cage_break object for the whole
            molecule (possibly as returned by the molecular reversal manager).
        """
        cbs = {'frame': time}

        # Calculate the neighbour dicts for the two configurations.
        neighs = self.checkneighs(x1)
        if self.chain:
            oldneighs = self.oldneighs
        else:
            oldneighs = self.checkneighs(x0)

        # Obtain atomistic coordinates.
        x0_at = self.checkneighs.convert.to_at(x0)
        x1_at = self.checkneighs.convert.to_at(x1)
        # NOTE(review): row i of these reshaped arrays is passed to cage_break as the position of
        # molecule i — presumably the centre of mass for CoM+AA input; confirm for atomistic input.
        x0 = x0.reshape(-1, 3)
        x1 = x1.reshape(-1, 3)

        napm = self._atoms_per_molecule(x0_at, oldneighs, neighs)

        # Check each molecule in turn and identify molecular cage breaks.
        for i in range(self.nmol):
            cb = self._molecular_cb(i, napm, x0, x1, x0_at, x1_at, oldneighs, neighs, time)
            if cb is None:
                continue

            # If we're only interested in the presence or absence of cbs during this transition,
            # there's no need to check any other molecules. Return now.
            if self.first_only:
                if self.chain:
                    self.oldneighs = neighs.copy()
                cbs[i] = cb
                return cbs

            # If a reversal manager has been set, use it to decide whether this cb belongs to an
            # extant reversal chain; it may return None, which we do not store.
            if self.mol_rev_manager is not None:
                cb = self.mol_rev_manager(cb)
            if cb is not None:
                cbs[i] = cb

        # Save data for the next configuration in the chain.
        if self.chain:
            self.oldneighs = neighs.copy()
        return cbs

    def _atoms_per_molecule(self, x_at, oldneighs, neighs):
        """Return the number of atoms per molecule, printing a warning for inconsistent counts."""
        # The atomistic array contains 3 coordinates per atom. Floor division keeps these counts
        # integral under both Python 2 and Python 3.
        natoms = x_at.size // 3
        # Messages are formatted into a single string so output is the same on Python 2 and 3.
        if natoms % self.nmol != 0:
            print("Problem: atomistic_checkcb is expecting  %s molecules but coords array suggests there are %s atoms" % (self.nmol, natoms))
        napm = natoms // self.nmol

        # We assume that every atom has at least one neighbour - but this may not be true of all
        # neighbours definitions.
        if len(oldneighs) != natoms:
            print("Problem with old configuration: there are  %s atoms with neighbours and  %s atoms in total" % (len(oldneighs), natoms))
        if len(neighs) != natoms:
            print("Problem with new configuration: there are  %s atoms with neighbours and  %s atoms in total" % (len(neighs), natoms))
        return napm

    def _molecular_cb(self, i, napm, x0, x1, x0_at, x1_at, oldneighs, neighs, time):
        """Return a cage_break for molecule i if every one of its atoms had a site cb, else None."""
        if napm <= 0:
            # A molecule with no atoms is never reported as a cage break.
            return None

        site_cbs = []
        for j in range(napm):
            atom = i * napm + j
            try:
                # Apply the site cb_rule to each atom position.
                site_cb = self.cb_rule(atom, x0_at, x1_at, oldneighs[atom], neighs[atom], time)
            except KeyError:
                print("Atom  %s is missing neighbours in one or other configuration" % atom)
                return None
            if site_cb is None:
                # One atom has not broken its cage, so the molecule cannot have undergone a
                # molecular cb; the remaining atoms of this molecule are not checked.
                return None
            site_cbs.append(site_cb)

        # Every atom had a site cb, so the molecule underwent a cb as well.
        print("Cage break for molecule  %s" % i)

        # Each element of changed is the list of site-neighbour changes for one atom.
        changed = [site_cb.changed_neighbours for site_cb in site_cbs]
        return cage_break(i, x0[i], x1[i], changed, time)

    def flush(self):
        """Return molecular cage breaks still held by the molecular reversal manager.

        This is the usual flush routine, but applied to the molecular reversal manager instead of
        the atomic version. Returns None if no molecular reversal manager was given; otherwise a
        dict mapping 'frame' to 'terminate' plus every entry popped from the manager's
        ``live_chains``, which is left empty.
        """
        if self.mol_rev_manager is None:
            return None

        cbs = {'frame': 'terminate'}
        live_chains = self.mol_rev_manager.live_chains
        while live_chains:
            i, cb = live_chains.popitem()
            cbs[i] = cb
        return cbs


        
