import numpy as np
import distances

class reversal_manager(object):
    """ Callable class to identify and record reversal chains in a sequence of cage breaks.
    This is a base class, and should be used to derive classes with particular definitions of reversed cage breaks
    by overriding the check_direct_reversal and check_indirect_reversal functions.

    The object should be called repeatedly, each time passing in a new cage break object (see checkcb.py).
    If the particle which moves in this cage break has not previously undergone a cage break, then this cage
    break is stored and nothing further occurs.
    Otherwise, the new cage break is checked against the previously saved cage break. If no reversal is detected,
    then the previous cage break is returned (its chain having now terminated). Otherwise, the old cage break is
    discarded (the reversal chain is still ongoing) and the new one is saved.

    This class only uses three features of a cage break object: the ``index`` attribute (the key into
    live_chains), the integer-like ``l`` attribute (reversal chain length, modified here) and a ``copy()``
    method. Derived classes may read further attributes in their reversal checks.

    Parameters
    ----------
    terminated_only: logical, default True
        Default behaviour is to return only the last cage break in each chain. If False, all cage breaks are returned
    verbose: logical, default False
        If true, detailed information about the analysis of each cage break is printed to stdout

    Attributes
    ----------
    live_chains: dict of cage breaks
        A dictionary containing the last cage break object which was given to the reversal manager for each particle.
        Keys are particle indices, values are (copies of) cage break objects.
    """
    def __init__(self, terminated_only=True, verbose=False):
        self.live_chains = {}
        self.term_only = terminated_only
        self.verbose = verbose

    def __call__(self, cb):
        """ Determine whether the current cage break constitutes a reversal of the previous one.

        If there was no previous cage break, the current cb is simply stored and None is returned.
        If the current cb is a direct reversal of the previous cb, then the earlier cb is discarded and the current
        one stored. The reversal chain length is incremented by one.
        If the current cb reveals that the previous cb was reversed by non-cage-breaking motion, the reversal chain
        length of the previous cb is incremented by one but the chain is also terminated. So the old cb is returned
        and the current one stored. To distinguish this type of reversal, we multiply the final "l" value (i.e. chain
        length) by -1. This assists with book-keeping.
        If the current cb is completely unrelated to the old cb, the chain is terminated (so the old cb is returned)
        and the current cb is stored as the first step in the new chain.

        Parameters
        ----------
        cb: cage break object
            Must provide ``index``, ``l`` and ``copy()``. Note that ``cb.l`` is incremented in place when a
            direct reversal is detected.

        Returns
        -------
        A copy of the previously stored cage break for this particle, or None (first cb for this particle, or a
        direct reversal with terminated_only=True).
        """
        if cb.index not in self.live_chains:
            if self.verbose:
                # Single-argument print: identical output under Python 2 and Python 3.
                print("Adding first cb to chain with index %s" % (cb.index,))
            self.live_chains[cb.index] = cb.copy()
            return None

        # Extract the old cage break from storage, preparatory to passing it out.
        old_cb = self.live_chains[cb.index].copy()

        if self.check_direct_reversal(old_cb, cb):  # Reversal chain is still running
            # Add one to the length of the current reversal chain (this also modifies the caller's cb).
            cb.l += 1
            self.live_chains[cb.index] = cb.copy()
            if self.term_only:  # Default behaviour: don't report the old cb at all, since it was reversed.
                return None
            return old_cb  # Report every cb regardless of where it falls in the reversal chain

        # The indirect check must run while live_chains still holds the OLD cb: derived classes may read it there.
        indirect = self.check_indirect_reversal(old_cb, cb)

        # Whether or not the old cb was reversed indirectly, its chain terminates here and the current cb
        # becomes the first in a new chain.
        self.live_chains[cb.index] = cb.copy()

        if indirect:
            # Previous cb was reversed (again), by non-cage-breaking motion.
            old_cb.l += 1
            # To distinguish chains detected by indirect reversals, we multiply the chain length by -1
            old_cb.l *= -1
        return old_cb

    def check_direct_reversal(self, oldcb, cb):
        """ Determine whether the current cb is a direct reversal of the old cb. Must be overridden. """
        raise AttributeError("Base class reversal_manager has no implementation for check_direct_reversal")

    def check_indirect_reversal(self, oldcb, cb):
        """ Determine whether the old cb was reversed by non-cage-breaking motion. Must be overridden. """
        raise AttributeError("Base class reversal_manager has no implementation for check_indirect_reversal")

class distance_reversal_manager(reversal_manager):
    """ Reversal manager derived class which identifies reversals according to the displacement method set out
    in de Souza and Wales, JCP 129 (2008).
    Note, this reversal method may be used with cage breaks defined for both atomic and molecular systems.

    Cage break objects passed in must provide ``old_pos`` and ``new_pos`` (the particle position before and
    after the cage break), in addition to the attributes required by reversal_manager.

    Parameters
    ----------
    boxvec: array of floats
        Numpy array containing the side lengths of the periodic simulation box, which must be orthorhombic
    return_threshold: float
        Threshold distance to determine when cage breaks are directly reversed (see check_direct_reversal)
    duplicate_threshold: float, optional
        Threshold distance to determine when cage breaks are indirectly reversed (see check_indirect_reversal)
        Default behaviour is to use the value of return_threshold
    terminated_only: logical, default True
        Default behaviour is to return only the last cage break in each chain. If False, all cage breaks are returned
    verbose: logical, default False
        If true, detailed information about the analysis of each cage break is printed to stdout

    Attributes
    ----------
    live_chains: dict of cage breaks
        A dictionary containing the last cage break object which was given to the reversal manager for each particle.
        Keys are particle indices, values are cage break objects
    dist: distance measure function
        The object returned by distances.get_dist_periodic(boxvec). It is called as dist(a, b) and its result is
        reduced with np.linalg.norm to obtain the separation between two points.
    """
    def __init__(self, boxvec, return_threshold, duplicate_threshold=None, terminated_only=True, verbose=False):
        # Name this class (not the base) in super(), so that reversal_manager.__init__ actually runs.
        super(distance_reversal_manager, self).__init__(terminated_only=terminated_only, verbose=verbose)
        self.return_threshold = return_threshold
        if duplicate_threshold is None:
            duplicate_threshold = return_threshold
        self.duplicate_threshold = duplicate_threshold

        # `distances` is the module imported at the top of this file.
        self.dist = distances.get_dist_periodic(boxvec)

    def _separation(self, a, b):
        """ Magnitude of the periodic separation between positions a and b. """
        return np.linalg.norm(self.dist(a, b))

    def check_direct_reversal(self, oldcb, cb):
        """ The old cb was reversed directly if the net displacement over both cbs (from the start of the old cb
        to the end of the new one) is no greater than return_threshold. """
        return bool(self._separation(oldcb.old_pos, cb.new_pos) <= self.return_threshold)

    def check_indirect_reversal(self, oldcb, cb):
        """ The old cb was reversed by non-cage-breaking moves if the net displacement over both cbs matches
        (to within duplicate_threshold) the displacement of either of the two cbs individually. """
        net_displacement = self._separation(oldcb.old_pos, cb.new_pos)

        displacement1 = self._separation(oldcb.old_pos, oldcb.new_pos)
        if np.abs(net_displacement - displacement1) <= self.duplicate_threshold:
            return True

        # Displacement of the second (current) cage break.
        displacement2 = self._separation(cb.old_pos, cb.new_pos)
        return bool(np.abs(net_displacement - displacement2) <= self.duplicate_threshold)
            
class neighbours_reversal_manager(reversal_manager):
    """ Reversal manager derived class which identifies reversals according to changes in the nearest-neighbour lists.
    We compare the lists of nearest neighbours of the central atom which change in the two cage breaks. If many neighbour
    changes from the first cb are reversed in the second, a direct reversal has taken place. If many neighbour changes
    are duplicated between the two cage breaks, an indirect reversal has taken place.

    Cage break objects must provide ``changed_neighbours``: an iterable of signed neighbour indices, where a
    negative entry denotes a neighbour that was lost and a positive entry one that was gained.

    Parameters
    ----------
    reversal_threshold: integer
        Threshold number of neighbour changes which must be reversed in order that we consider the cage break to be reversed.
        The same threshold is used for direct reversals and duplicate changes (indirect reversals)
    terminated_only: logical, default True
        Default behaviour is to return only the last cage break in each chain. If False, all cage breaks are returned
    verbose: logical, default False
        If true, detailed information about the analysis of each cage break is printed to stdout

    Attributes
    ----------
    live_chains: dict of cage breaks
        A dictionary containing the last cage break object which was given to the reversal manager for each particle.
        Keys are particle indices, values are cage break objects
    reversed_neighs: [lost-then-gained, gained-then-lost] lists for the most recently compared pair of cbs
    duplicates: [lost-twice, gained-twice] lists for the most recently compared pair of cbs
    """

    def __init__(self, reversal_threshold, terminated_only=True, verbose=False):
        # Name this class in super() and forward the caller's options (previously they were ignored).
        super(neighbours_reversal_manager, self).__init__(terminated_only=terminated_only, verbose=verbose)
        self.thresh = reversal_threshold
        self._compared = None  # (oldcb, cb) pair that reversed_neighs/duplicates currently describe

    def _compare_changes(self, oldcb, cb):
        """ Fill self.reversed_neighs and self.duplicates by comparing the changed_neighbours of oldcb and cb.
        Both lists are filled in one pass to avoid scanning the changed_neighbours lists twice. """
        self.reversed_neighs = [[], []]  # [lost then regained, gained then lost again]
        self.duplicates = [[], []]       # [lost twice, gained twice]
        for i in oldcb.changed_neighbours:
            slot = 0 if i < 0 else 1  # 0: lost in the old cb, 1: gained in the old cb
            for j in cb.changed_neighbours:
                if i == -j:  # Neighbour changed in the opposite sense the second time
                    self.reversed_neighs[slot].append(i)
                if i == j:   # Neighbour changed in the same sense the second time
                    self.duplicates[slot].append(i)
        self._compared = (oldcb, cb)

    def _ensure_compared(self, oldcb, cb):
        """ Recompute the comparison lists unless they already describe exactly this (oldcb, cb) pair. """
        compared = getattr(self, "_compared", None)
        if compared is None or compared[0] is not oldcb or compared[1] is not cb:
            self._compare_changes(oldcb, cb)

    def check_direct_reversal(self, oldcb, cb):
        """ If the number of neighbours that were lost in the old cb and regained in the new cb reaches the threshold,
        the old cb was reversed directly. Similarly if neighbours were gained in the old cb and lost again in the new. """
        self._compare_changes(oldcb, cb)
        if self.verbose:
            print("reversed_neighs %s" % (self.reversed_neighs,))
            print("duplicates %s" % (self.duplicates,))

        if len(self.reversed_neighs[0]) >= self.thresh or len(self.reversed_neighs[1]) >= self.thresh:
            if self.verbose:
                print("CB reversed, l: %s" % (cb.l,))
                print("reversed neighbours: %s" % (self.reversed_neighs,))
            return True
        return False

    def check_indirect_reversal(self, oldcb, cb):
        """ If the number of neighbours that were lost in the old cb and lost again in the new cb reaches the threshold,
        the old cb was reversed indirectly. Similarly if neighbours were gained twice.
        Reuses the lists built by check_direct_reversal when it was called on the same pair (as
        reversal_manager.__call__ does); otherwise the comparison is recomputed. """
        self._ensure_compared(oldcb, cb)
        if len(self.duplicates[0]) >= self.thresh or len(self.duplicates[1]) >= self.thresh:
            if self.verbose:
                print("CB reversed by non-CB process, l: %s" % (oldcb.l,))
                print("duplicate neighbours: %s" % (self.duplicates,))
            return True
        return False

class molecular_reversal_manager(reversal_manager):
    """ Reversal manager for cage breaks in a system of rigid body molecules.
    This class uses the "site-cb" methodology of Niblett, de Souza, Stevenson and Wales, JCP (to be published, 2016).
    Each site's neighbour changes are compared separately between the two cbs. The molecular cb is deemed
    directly (indirectly) reversed when, after removing the reversed (duplicated) neighbour changes, at least one
    site is left with fewer than reversal_threshold lost AND fewer than reversal_threshold gained changes, i.e. that
    site no longer qualifies as having undergone a site-cb.

    Cage break objects must provide ``changed_neighbours``, indexed by site (0..napm-1); each entry is an iterable
    of signed neighbour indices (negative = lost, positive = gained).

    Parameters
    ----------
    reversal_threshold: integer/float
        Minimum number of lost (or gained) neighbour changes a site needs in order to still count as a site-cb.
    napm: integer
        The number of sites ("atoms") in each rigid body ("molecule")
    terminated_only: logical, default True
        Default behaviour is to return only the last cage break in each chain. If False, all cage breaks are returned
    verbose: logical, default False
        If true, detailed information about the analysis of each cage break is printed to stdout
    method: integer, default 1
        Stored as self.method; not currently read by this class.

    Attributes
    ----------
    live_chains: dict of cage breaks
        A dictionary containing the last cage break object which was given to the reversal manager for each particle.
        Keys are particle indices, values are cage break objects
    reversed_neighs, duplicates, remaining: lists of length 2*napm
        Per-site [lost, gained] lists for the most recent comparison (entry 2*k is "lost" for site k,
        entry 2*k+1 is "gained").
    """
    def __init__(self, reversal_threshold, napm, terminated_only=True, verbose=False, method=1):
        super(molecular_reversal_manager, self).__init__(terminated_only=terminated_only, verbose=verbose)
        self.napm = napm
        self.thresh = reversal_threshold
        self.method = method
        self._compared = None  # (oldcb, cb) pair that reversed_neighs/duplicates currently describe

    def _compare_site_changes(self, oldcb, cb):
        """ For each site, sort the old cb's neighbour changes into reversed_neighs / duplicates according to the
        first matching entry in the new cb's list for that site. """
        self.reversed_neighs = [[] for _ in range(2*self.napm)]  # [lost then gained, gained then lost] per site
        self.duplicates = [[] for _ in range(2*self.napm)]       # [lost twice, gained twice] per site
        for k in range(self.napm):
            new_changes = cb.changed_neighbours[k]
            for i in oldcb.changed_neighbours[k]:
                slot = 2*k if i < 0 else 2*k + 1
                for j in new_changes:
                    # Only the first matching entry in the new cb is considered for each old change.
                    if i == -j:
                        self.reversed_neighs[slot].append(i)
                        break
                    elif i == j:
                        self.duplicates[slot].append(i)
                        break
        self._compared = (oldcb, cb)

    def _ensure_compared(self, oldcb, cb):
        """ Recompute the comparison lists unless they already describe exactly this (oldcb, cb) pair. """
        compared = getattr(self, "_compared", None)
        if compared is None or compared[0] is not oldcb or compared[1] is not cb:
            self._compare_site_changes(oldcb, cb)

    def _remaining_changes(self, oldcb, removed):
        """ Per-site [lost, gained] lists of the old cb's neighbour changes that do not appear in either of
        removed[2*k] / removed[2*k+1] for that site. """
        remaining = [[] for _ in range(2*self.napm)]
        for k in range(self.napm):
            excluded = removed[2*k] + removed[2*k+1]
            for i in oldcb.changed_neighbours[k]:
                if i not in excluded:
                    remaining[2*k if i < 0 else 2*k + 1].append(i)
        return remaining

    def _any_site_disqualified(self, remaining):
        """ True if some site has fewer than thresh remaining lost changes and fewer than thresh remaining gained
        changes, i.e. it no longer qualifies as a site-cb. False when napm is 0. """
        return any(len(remaining[2*k]) < self.thresh and len(remaining[2*k+1]) < self.thresh
                   for k in range(self.napm))

    def check_direct_reversal(self, oldcb, cb):
        """ The old cb was reversed directly if removing its reversed neighbour changes disqualifies at least one
        of its site-cbs. """
        self._compare_site_changes(oldcb, cb)
        self.remaining = self._remaining_changes(oldcb, self.reversed_neighs)

        if self.verbose:
            print("reversed_neighs %s" % (self.reversed_neighs,))
            print("duplicates %s" % (self.duplicates,))
            print("remaining after removing reversals %s" % (self.remaining,))

        return self._any_site_disqualified(self.remaining)

    def check_indirect_reversal(self, oldcb, cb):
        """ The old cb was reversed by non-cage-breaking processes if removing its duplicated neighbour changes
        disqualifies at least one of its site-cbs. Reuses the lists from check_direct_reversal when it was called on
        the same pair (as reversal_manager.__call__ does); otherwise the comparison is recomputed. """
        self._ensure_compared(oldcb, cb)
        self.remaining = self._remaining_changes(oldcb, self.duplicates)

        if self.verbose:
            print("Remaining after removing duplicates: %s" % (self.remaining,))

        return self._any_site_disqualified(self.remaining)
