import numpy as np
from math import ceil, floor, pi
from read_configurations import ReadConfigurationsList
import pele.angleaxis._otp_bulk as OTP
import distances
import sann

class converter(object):
    """ A class used by checkneighs objects for converting between rigid-body, atomistic cartesian and centre-of-mass
    coordinates.
    Note, at the moment the converter is hard-coded for Lewis-Wahnstrom OTP systems. Atomic systems such as BLJ are
    probably fine as well, provided appropriate checkneighs objects are used.
    The system-dependent functionality is only required when converting from rigid-body to atomistic coordinates.

    Parameters
    ----------
    nmol: integer
        The number of particles in the system (atoms or molecules)
    boxvec: array of floats
        np array containing the side lengths of the periodic simulation box (assumed to be orthorhombic)
    AA: logical
        Specifies whether the system expects coordinate arrays in rigid-body form (centre-of-mass + angle-axis)
        or simple atomistic cartesian coordinates.

    Attributes
    ----------
    nmol: integer
        As above.
    AA: logical
        As above.
    system: pele system
        A system class to allow angle-axis coordinates to be converted into atomistic. At the moment, only OTP is coded.
        Note that it is constructed even when AA is False.
    """

    def __init__(self, nmol, boxvec, AA):
        self.nmol = nmol
        self.system = OTP.OTPBulk(nmol, boxvec, 1.0)
        self.AA = AA

    def to_at(self, x):
        """ Return a copy of coordinate array x in atomistic cartesian coordinates.

        If AA is True, x is treated as CoM+angle-axis coordinates and converted with the system's
        aatopology.to_atomistic (the shape of its output is whatever that method returns).
        Otherwise x is already atomistic, and an (N, 3) copy of it is returned.
        """
        if self.AA:
            # CoM+AA coordinates can only be converted to cartesian by knowing the system-specific reference structure.
            return self.system.aatopology.to_atomistic(x.flatten())
        # The coordinates are atomistic anyway; copy so the caller's array is never aliased.
        return x.reshape(-1, 3).copy()

    def to_CoM(self, x):
        """ Return a copy of coordinate array x in centre-of-mass coordinates, as an (nmol, 3) array.

        If AA is True, the first nmol rows of x.reshape(-1, 3) are taken to be the CoM positions. This also works
        for CoM-only coordinate arrays.
        Otherwise the atoms are assumed to be stored molecule by molecule, each molecule having the same number of
        atoms, and each CoM is the unweighted mean of that molecule's atom positions.

        Raises
        ------
        ValueError
            If AA is False and the number of atoms is not a multiple of nmol.
        """
        x = x.reshape(-1, 3)
        if self.AA:
            # The first half of the coordinates vector is CoM coords only.
            return x[:self.nmol].copy()

        # Atomistic coords: average over the atom positions within each molecule.
        natoms_per_mol, remainder = divmod(len(x), self.nmol)
        if remainder:
            raise ValueError("Number of atoms (%d) is not a multiple of the number of molecules (%d)"
                             % (len(x), self.nmol))
        # NOTE(review): positions are averaged directly, with no minimum-image correction. This assumes no molecule
        # is split across a periodic boundary in x; confirm for the coordinate source in use.
        return x.reshape(self.nmol, natoms_per_mol, 3).mean(axis=1, dtype=np.float64)

class checkneighs(object):
    """ Base class for all neighbour checking.
    Different neighbour definitions work very differently so the only common functions are to do with storing neighbour
    pairs once identified.
    Subclasses must set self.nmol (the number of particles) before calling pad.
    """
    @staticmethod
    def _append_neighbour(neighs, key, value):
        """ Append value to the neighbour array of key, creating the array if key has no entry yet. """
        if key in neighs:
            neighs[key] = np.append(neighs[key], value)
        else:
            neighs[key] = np.array([value])

    def addpair(self, neighs, i, j):
        """ Add a pair of indices (i,j) to the dictionary neighs, indicating that they are nearest neighbours.
        The keys of neighs are particle indices, and the value for each key is an array of indices of neighbouring
        particles. The dictionary is modified in place, and the relation is recorded in both directions. """
        # Record j as a neighbour of i...
        self._append_neighbour(neighs, i, j)
        # ...and i as a neighbour of j, keeping the relation symmetric.
        self._append_neighbour(neighs, j, i)

    def pad(self, neighs):
        """ Other classes will assume that neighs has a key for every particle in the system. If there are any particles
        completely without neighbours, we must pad the dictionary with dummy entries (empty integer arrays, matching
        the array type used by addpair). This shouldn't ever have any effect for a sensible neighbour definition, as
        all particles will have multiple neighbours.
        neighs is modified in place and also returned. """
        for i in range(self.nmol):
            if i not in neighs:
                neighs[i] = np.array([], dtype=int)
        return neighs
            
class cutoff_checkneighs(checkneighs):
    """ Identify neighbours of a particle i as all particles within a specified distance of particle i.
    This global cutoff method is the usual way of identifying nearest neighbours in atomistic systems.
    This class acts as a base for several different cutoff-based checkneighs classes.

    Parameters
    ----------
    get_dist: function
        A metric for calculating the distance between two particles
    rcut: float
        The global cutoff distance for identifying nearest neighbours
    nmol: integer
        The number of particles in the system (atoms or molecules)
    boxvec: array of floats
        np array containing the side lengths of the periodic simulation box (assumed to be orthorhombic)
    AA: logical, default False
        Specifies whether the system expects coordinate arrays in rigid-body form (centre-of-mass + angle-axis)
        or simple atomistic cartesian coordinates.
    cell_list: cell_lists object, default None
        If provided, this object will be used to reduce the number of pair distances which must be tested

    Attributes
    ----------
    convert: converter object
        An object to convert coordinate arrays between different forms
    """

    def __init__(self, get_dist, rcut, nmol, boxvec,  AA=False, cell_list=None):
        self.nmol = nmol
        self.AA = AA  # Note, AA=True for CoM-only coords as well.

        self.get_dist = get_dist
        # rcut must be stored before the cell-size check below, which compares against it.
        self.rcut = rcut
        self.cell_list = cell_list
        if (cell_list is not None) and (np.any(self.cell_list.cell_size < self.rcut)):
            print("Warning: One of the dimensions in cell_list is smaller than rcut. Some neighbours will be missed.")

        self.convert = converter(nmol, boxvec, AA)

    def __call__(self, x):
        """ Take a configuration x of the system and return a dictionary neighs which contains lists of the nearest
        neighbours for every particle in the system (padded so that every index in range(nmol) is a key). """
        neighs = {}
        x = x.reshape(-1, 3)

        if self.cell_list is not None:  # Then we need only check neighbours in adjacent cells.
            # Work out which cell each particle is in
            self.cell_list.fill_cells(x)
            cells = self.cell_list.cells
            for a in cells:
                for b in a.neighbours:  # a.neighbours contains all the cells within rcut of a.
                    for i in a.particles:
                        for j in cells[b].particles:
                            # i > j avoids adding each pair twice. This presumes the cell-neighbour relation is
                            # symmetric (and includes a itself); confirm against the cell_lists implementation.
                            if i > j and self.ispair(i, j, x[i], x[j]):
                                self.addpair(neighs, i, j)
        else:  # Then all pairs of molecules must be compared
            for i in range(self.nmol):
                for j in range(i + 1, self.nmol):  # Avoid double counting by setting the lower limit here
                    if self.ispair(i, j, x[i], x[j]):
                        self.addpair(neighs, i, j)
        return self.pad(neighs)

    def ispair(self, i, j, xi, xj):
        """ Identify whether particles i and j qualify as a nearest-neighbour pair """
        r = self.get_dist(xi, xj)  # Measure the distance between each possible pair
        # Check the neighbour distance criterion, and that the two particles are not coincident.
        # We shouldn't ever be comparing an atom with itself, so the second check shouldn't be relevant.
        # If your potential does allow atoms to sit on top of one another, this ispair method will not work for you.
        return (r <= self.rcut and r > 1e-12)
                  
class binary_atomistic_cutoff(cutoff_checkneighs):
    """ Global-cutoff neighbour identification for a binary (two-species) system, using a separate cutoff for each
    kind of pair. The coordinates array must list every A atom before any B atom.
    Only atomic systems are currently supported.

    Parameters
    ----------
    get_dist: function
        Metric returning the distance between two atoms
    rcutAA: float
        Cutoff distance below which an A-A pair counts as nearest neighbours
    rcutAB: float
        Cutoff distance below which an A-B pair counts as nearest neighbours
    rcutBB: float
        Cutoff distance below which a B-B pair counts as nearest neighbours
    nmol: integer
        Total number of atoms in the system
    ntypeA: integer
        How many of those atoms are of type A
    boxvec: array of floats
        np array with the side lengths of the periodic simulation box (assumed orthorhombic); not used by this class
    cell_list: cell_lists object, default None
        Optional helper that limits which pair distances are evaluated

    Attributes
    ----------
    types: list of integers
        Per-atom species label, in coordinate-array order: 0 for type A, 1 for type B.
    rcuts: list of floats
        [rcutAA, rcutAB, rcutBB], so that rcuts[types[i] + types[j]] is the cutoff for the pair (i, j).
    """

    def __init__(self, get_dist, rcutAA, rcutAB, rcutBB, nmol, ntypeA, boxvec, cell_list=None):
        self.nmol = nmol
        ntypeB = nmol - ntypeA
        self.types = [0] * ntypeA + [1] * ntypeB

        self.rcuts = [rcutAA, rcutAB, rcutBB]

        self.get_dist = get_dist
        self.cell_list = cell_list
        if cell_list is None:
            return
        largest_cut = max(self.rcuts)
        if np.any(self.cell_list.cell_size < largest_cut):
            print("Warning: One of the dimensions in cell_list is smaller than the maximum rcut.")
            print("Some neighbours will be missed.")

    def ispair(self, i, j, xi, xj):
        """ Return whether atoms i and j are closer than the cutoff for their combination of species
        (and not coincident). """
        r = self.get_dist(xi, xj)
        # The label sum is 0 for A-A, 1 for A-B and 2 for B-B, which is exactly the layout of self.rcuts.
        pair_kind = self.types[i] + self.types[j]
        cutoff = self.rcuts[pair_kind]
        return (r <= cutoff and r > 1e-12)
                      
class CoM_cutoff(cutoff_checkneighs):
    """ Unary global-cutoff neighbour detection in which the distance between two particles (assumed to be rigid
    bodies) is measured between their centres of mass only. """
    def __call__(self, x):
        """ Reduce x to centre-of-mass coordinates with the converter, then apply the ordinary cutoff search.
        Either CoM+AA or atomistic cartesian input is accepted, provided the converter has the right system class. """
        com_coords = self.convert.to_CoM(x)
        parent = super(CoM_cutoff, self)
        return parent.__call__(com_coords)


class SANN(checkneighs):
    """ A checkneighs class using the SANN algorithm of van Meel, Filion, Valeriani and Frenkel, JCP 136 (2012).
    This method uses a non-local cutoff to identify nearest neighbours.
    The SANN algorithm is implemented as a Fortran shared object which takes an entire configuration as its
    argument and computes neighbour lists for all molecules in one go. So there is no need to iterate over pairs
    in this class.

    Parameters
    ----------
    nmol: integer
        The number of molecules in the system
    boxvec: array of floats
        numpy array containing side lengths for the periodic box
    rcut: float
        A fixed distance which is the maximum permitted value of the local cutoff distance for any given particle.
        Set it large enough that the local cutoffs chosen by the SANN algorithm never attempt to exceed rcut.
    AA: logical, default False
        Specifies whether the system expects coordinate arrays in rigid-body form (centre-of-mass + angle-axis)
        or simple atomistic cartesian coordinates. AA=True is also used for CoM-only coordinates.

    Attributes
    ----------
    convert: converter object
        An object to convert coordinate arrays between different forms
    """
    def __init__(self, nmol, boxvec, rcut, AA=False):
        self.nmol = nmol
        self.boxvec = boxvec
        self.AA = AA  # Note, AA=True for CoM-only coords as well.

        self.rcut = rcut    # Here, rcut is a simple fixed cutoff distance
                            # used to exclude obvious non-neighbours.

        self.convert = converter(nmol, boxvec, AA)

    def __call__(self, x):
        """ Return the SANN neighbour dictionary for configuration x, computed on centre-of-mass coordinates. """
        x1 = self.convert.to_CoM(x)
        return self.sann_wrapper(x1)

    def sann_wrapper(self, x1):
        """ Run the Fortran SANN routine on the (npart, 3) coordinate array x1.

        Returns a dictionary mapping each particle index in range(npart) to an integer array of the (0-based)
        indices of its neighbours.
        """
        # npart is the number of particles in the system (atoms or molecules, it doesn't matter now that we've
        # transformed to CoM coordinates).
        npart = x1.shape[0]

        # The Fortran routine returns, per particle, a row of neighbour indices (1-based) and the count nb of
        # how many entries in that row are valid.
        neighlist, nb = sann.neighbours(self.boxvec, x1.flatten(), self.rcut, npart)
        neighlist -= 1  # Convert from Fortran to Python indexing

        # Copy the valid part of each row into an integer index array.
        return {i: np.array(neighlist[i][:nb[i]], dtype=int) for i in range(npart)}

class SANN_atomistic(SANN):
    """ Another SANN-based checkneighs method. Here, two molecules are considered to be neighbours if any of
    their constituent atoms are found to be neighbours by the SANN method operating on all atoms.
    Construction is inherited from SANN: (nmol, boxvec, rcut, AA=False). """
    def __call__(self, x):
        """ Return the molecular neighbour dictionary for configuration x, computed from atomistic coordinates. """
        x1 = self.convert.to_at(x)
        return self.sann_wrapper(x1)

    def sann_wrapper(self, x1):
        """ Run SANN on all atoms in x1 and collapse atom-atom neighbours into molecule-molecule neighbours.

        Atoms are assumed to be stored molecule by molecule with the same number of atoms per molecule.
        Returns a dictionary with a (padded) entry for every molecule index.
        """
        neighs = {}
        natoms_per_mol = len(x1) // self.nmol

        # Compute neighbours using the SANN method based on fully atomistic coordinates
        neighlist, nb = sann.neighbours(self.boxvec, x1.flatten(), self.rcut, npart=x1.shape[0])
        neighlist -= 1  # Converting from Fortran to Python indexing

        for i in range(self.nmol):
            # Atoms of molecule i occupy a contiguous run of rows in x1, and hence in neighlist.
            for atom_index in range(natoms_per_mol * i, natoms_per_mol * (i + 1)):
                for k in range(nb[atom_index]):
                    # The molecule to which the neighbouring atom belongs
                    neigh_index = neighlist[atom_index][k] // natoms_per_mol
                    if neigh_index == i:
                        # Intramolecular atom pair: a molecule is not its own neighbour.
                        continue
                    if i in neighs and neigh_index in neighs[i]:
                        # Already recorded (addpair stores both directions).
                        continue
                    self.addpair(neighs, i, neigh_index)

        return self.pad(neighs)
              
class new_atomistic_SANN(SANN):
    ''' A checkneighs method that uses the SANN algorithm but records neighbours for each atom separately rather than per molecule.
    Then the cage breaking definition must combine the neighbour changes for all the atoms according to some criterion.
    Because this method uses atomistic coordinates, it relies on a system-dependent coordinates converter.

    Parameters
    ----------
    nmol: integer
        The number of molecules in the system
    boxvec: array of floats
        numpy array containing side lengths for the periodic box
    rcut: float
        A fixed distance which is the maximum permitted value of the local cutoff distance for any given particle.
        Set it large enough that the local cutoffs chosen by the SANN algorithm never attempt to exceed rcut.
    AA: logical
        Specifies whether the class expects coordinate arrays in rigid-body form (centre-of-mass + angle-axis)
        or simple atomistic cartesian coordinates.

    '''
    def __call__(self, x):
        """ Return the per-atom SANN neighbour dictionary (keyed by atom index) for configuration x. """
        x1 = self.convert.to_at(x)
        return self.sann_wrapper(x1)


class atomistic_SANN(SANN):
    ''' A checkneighs method that uses the SANN algorithm but records neighbours for each atom separately rather than per molecule.
    Then the cage breaking definition must combine the neighbour changes for all the atoms according to some criterion.
    Because this method uses atomistic coordinates, it relies on a system-dependent coordinates converter.
    Construction is inherited from SANN: (nmol, boxvec, rcut, AA=False). '''

    def __call__(self, x):
        """ Return a dictionary mapping every atom index to an integer array of its SANN neighbours' (0-based)
        atom indices, for configuration x.
        NOTE: the lists are taken directly from the SANN routine and are not symmetrised here. """
        x1 = self.convert.to_at(x)
        npart = x1.shape[0]

        neighlist, nb = sann.neighbours(self.boxvec, x1.flatten(), self.rcut, npart)
        neighlist -= 1  # Converting from Fortran to Python indexing

        # Only the first nb[i] entries of row i are valid neighbours.
        return {i: np.array(neighlist[i][:nb[i]], dtype=int) for i in range(npart)}
