import numpy as np

class cell_lists(object):
    """ A class to set up and update cell lists for general systems in periodic boundary conditions.
    The cell list is used to reduce the number of pairs which must be compared when computing a pairwise potential
    or doing a nearest-neighbour analysis, by ensuring that only nearby pairs of atoms/molecules are compared.
    The simulation box is divided into small cells, and when looking for pairs that interact with a particular particle
    we only consider particles in the same cell or neighbouring cells. So long as the cell width is at least as large
    as the cutoff for the potential/neighbour method, this method should improve the efficiency considerably without
    changing the result of the analysis.

    Parameters
    ----------
    boxvec: array of floats
        numpy array containing the (3) dimensions of the periodic box (which must be orthorhombic).
    cutoff: float
        Minimum side length of the cells, should be at least as large as the cutoff for the potential being used.
        The actual cell dimensions are stretched so that a whole number of cells fits along each box axis.

    Attributes
    ----------
    boxvec: array of floats
        The box dimensions, stored as a float numpy array.
    ncells: array of integers
        numpy array containing the number of cells which are stacked along each of the 3 coordinate axes to make up the
        simulation box.
    cell_size: array of floats
        numpy array containing the dimensions of a cell (each component is >= cutoff).
    cells: array of cells
        numpy object array of cell objects which is the main data structure for this class.

    Raises
    ------
    ValueError
        If cutoff is not positive, or if it is larger than a box dimension (no cell would fit along that axis).
    """

    def __init__(self, boxvec, cutoff):
        if cutoff <= 0:
            raise ValueError("cutoff must be positive, got {0}".format(cutoff))

        self.boxvec = np.asarray(boxvec, dtype=float)

        # The cell counts must be integers: they are used as an array size and to build cell indices.
        self.ncells = np.floor(self.boxvec / cutoff).astype(int)
        if np.any(self.ncells < 1):
            raise ValueError("cutoff {0} is larger than the box dimensions {1}".format(cutoff, self.boxvec))

        # Stretch the cells so that they tile the box exactly; each side is therefore >= cutoff.
        self.cell_size = self.boxvec / self.ncells

        self.cells = np.empty(int(np.prod(self.ncells)), dtype=object)
        for i in range(self.cells.size):
            self.cells[i] = cell(i, self.ncells)

    def fill_cells(self, x):
        """ x is a np array containing all the particle coordinates in a configuration (any shape is acceptable,
        provided the total number of entries is a multiple of 3).
        For a molecule, "particle coordinates" means centre-of-mass coordinates.
        This function sorts the particles into the various cells. The list of particles in cell [i] is accessed by
        self.cells[i].particles. The input array is not modified. """

        # Reshape the array so that each entry corresponds to a particle position vector
        positions = np.asarray(x).reshape(-1, 3)

        # Clear the existing lists of particles in each cell
        for current_cell in self.cells:
            current_cell.particles = []

        # Consider each particle in turn and put it into the relevant cell
        for particle_index, particle_x in enumerate(positions):
            self.put_in_cell(particle_index, particle_x)

    def put_in_cell(self, particle_index, particle_x):
        """ particle_index and particle_x specify the "name" and position of a particle. This function identifies the
        correct cell for this particle, and adds it to the relevant cell list. particle_x is left untouched. """

        boxvec = np.asarray(self.boxvec, dtype=float)
        ncells = np.asarray(self.ncells).astype(int)

        # Shift the origin from the centre of the box to one corner and wrap the particle back into the periodic
        # box, giving coordinates in [0, boxvec). This is done on a new array so that the caller's coordinates
        # (particle_x is usually a view into them) are not altered.
        # NOTE: cell_lists is compatible with input generated using either convention for origin position.
        wrapped = np.mod(np.asarray(particle_x, dtype=float) + boxvec / 2.0, boxvec)

        # Identify the coordinates of the (lower-left-corner of the) cell in which the particle resides.
        # The final modulo guards against floating point round-off placing a particle that sits on the upper
        # box face into the non-existent cell number ncells.
        cell_coords = np.floor(wrapped / self.cell_size).astype(int) % ncells

        # Generate the index of the cell from those coordinates (x runs fastest, then y, then z).
        cell_index = int(cell_coords[0] + cell_coords[1] * ncells[0] +
                         cell_coords[2] * ncells[0] * ncells[1])

        # Add the particle index to the relevant list.
        self.cells[cell_index].add_particle(particle_index)

        
class cell(object):
    """ A class to record the state of a cell in the simulation box.
    Data we need to save includes the position of a cell (specified by its index), the list of particles it contains and
    the indices of neighbouring cells. If a particular particle is located in this cell, all particles closer than
    the cutoff will be contained within the same cell or the set of neighbouring cells.

    Parameters
    ----------
    index: integer
        Index number to identify the cell. This increases in integer steps, running first along the x axis, then y and
        finally z. So cell 0 has one corner at the origin, cell 1 shares the '+x' face of cell 0, and so on.
    ncells: array of integers
        Numpy array holding the number of cells which fit into the simulation box along each of the 3 axes.
        Whole-valued floats are accepted and converted to integers for the index arithmetic.

    Attributes
    ----------
    particles: list of integers
        A list of the particle indices which fall within this cell.
    neighbours: list of integers
        A list of cell indices for the cells which neighbour this cell (including this cell itself). Each cell
        index appears exactly once, even when periodic wrapping makes several displacements land on the same cell.
    """

    def __init__(self, index, ncells):
        self.index = index
        self.particles = []
        self.neighbours = []
        self.ncells = ncells

        self.find_neighbours()

    def find_neighbours(self):
        """ Set up the list of cells which neighbour this one, taking periodic boundary conditions into account.
        The list is rebuilt from scratch on every call, so calling this method again does not grow it. """

        # Work with plain integers so that the resulting neighbour indices can be used directly to index arrays.
        nx, ny, nz = (int(n) for n in self.ncells)
        index = int(self.index)

        row = nx             # change in index for one step along y
        plane = nx * ny      # change in index for one step along z
        volume = plane * nz  # total number of cells

        # Change to the cell index when we move by [-1, 0, +1] cell widths along each axis.
        x_steps = [-1, 0, 1]
        y_steps = [-row, 0, row]
        z_steps = [-plane, 0, plane]

        # If we lie on one of the faces of the simulation box, we must treat the wrapping that arises through
        # periodic boundary conditions.
        if index % row == 0:                 # x=0 plane: stepping in -x lands on the high-x end of the box
            x_steps[0] += row
        if index % row == row - 1:           # x=xmax: stepping in +x lands on the x=0 plane
            x_steps[2] -= row
        if index % plane < row:              # y=0 plane
            y_steps[0] += plane
        if index % plane >= row * (ny - 1):  # y=ymax
            y_steps[2] -= plane
        if index < plane:                    # z=0 plane
            z_steps[0] += volume
        if index >= plane * (nz - 1):        # z=zmax
            z_steps[2] -= volume

        # Combine the displacements in the 27 possible ways (one of which is to stay in the same place) to build up
        # the list of cells which need to be checked when looking for the neighbours of a particle in this cell.
        # With fewer than 3 cells along an axis, different displacements wrap onto the same cell; such repeats are
        # skipped so that no pair of cells is visited more than once.
        neighbours = []
        seen = set()
        for i in x_steps:
            for j in y_steps:
                for k in z_steps:
                    neighbour_index = index + i + j + k
                    if neighbour_index not in seen:
                        seen.add(neighbour_index)
                        neighbours.append(neighbour_index)
        self.neighbours = neighbours

    def add_particle(self, particle_index):
        """ Add a particle to this cell. """
        self.particles.append(particle_index)
