import numpy as np
import distances
import checkneighs
import checkcb
import CellLists
import reversal_manager
from read_configurations import ReadConfigurations

class cb_manager(object):
    """Class to manage cage-breaking analysis on a set of configurations.

    There are two modes of operation. In the first, an ordered series of configurations is passed
    to the cb_manager, which identifies which atoms have undergone a cage break since the previous
    configuration. In the second, a pair of configurations is passed to the cb_manager, which identifies
    which atoms undergo a cage break between those two configurations.

    For the first mode, the functions start_chain, continue_chain and terminate_chain are used to
    pass in configurations. For the second mode, the function process_pair is called for each pair
    of configurations.

    A checkcb object must be defined and given to the cb_manager as an attribute at initialisation.
    This is what the cb_manager uses to identify cage breaks.

    Parameters
    ----------
    checkcb: checkcb object
        An object which identifies atoms that undergo a cage break between two configurations.
        It is called as checkcb(x0, x1, time) and must also provide start_chain(x0) and flush().
    outname: string, optional
        A path to the file where details of the cage breaks will be written. If None (the
        default), nothing is written to disk.

    Attributes
    ----------
    cb_list: list of dicts
        A list which holds the dictionaries returned in chain mode, each recording which atoms
        underwent a cage break in a particular pair of configurations.
    config_count: int
        Number of transitions (chain mode) or pairs (pair mode) processed so far.
    outfile: bool
        True while an output file is open for writing, False otherwise.
    out: file object or None
        The output file handle, or None if no outname was given.
    """
    def __init__(self, checkcb, outname=None):
        self.checkcb = checkcb

        self.cb_list = []
        self.config_count = 0

        if outname is None:
            self.outfile = False
            self.out = None
        else:
            self.outfile = True
            self.out = open(outname, 'w')

    def _write_cbs(self, cbs):
        """ Write one line per cage break in cbs to the output file (no-op if no file is open). """
        if not self.outfile:
            return
        for key in cbs:
            # 'frame' holds the timestep, every other key is an atom index
            if key != 'frame':
                self.out.write(cbs[key].lineprint())

    def close(self):
        """ Close the output file, if one is open. Safe to call more than once.

        After this call self.outfile is False, so later calls to continue_chain or process_pair
        still return their cage breaks but no longer attempt to write to the closed file.
        """
        if self.out is not None and not self.out.closed:
            self.out.close()
        self.outfile = False

    def start_chain(self, x0):
        """ Call this to indicate that configurations will be passed in as an ordered sequence and
        that cage breaks should be identified for every pair in the sequence. x0 is the first configuration
        in the sequence, an np.array containing the coordinates of the system in a format which will be
        recognised by your checkcb object.

        Parameters
        ----------
        x0: array of floats
            A numpy array containing coordinates of the first configuration in the chain.
            These coordinates may be given in any format recognised by the checkcb object. (Normally atomistic
            cartesian or Centre-of-Mass + Angle-Axis)
        """
        # Keep a private snapshot, as continue_chain does, so that a caller which reuses its
        # coordinate buffer cannot silently alter the reference configuration.
        self.last_x = x0.copy()
        self.checkcb.start_chain(x0)

    def continue_chain(self, x1):
        """ Check a new configuration for CBs when using the 'chain mode' of the manager.
        Any cage breaks which were identified are returned in the "cbs" dictionary, and added to cb_list, a list
        of dictionaries with the same format as cbs.

        Parameters
        ----------
        x1: array of floats
            A numpy array containing coordinates of the next configuration in the chain.
            These coordinates may be given in any format recognised by the checkcb object. (Normally atomistic
            cartesian or Centre-of-Mass + Angle-Axis)

        Returns
        -------
        cbs: dict
            A dictionary with one special key ('frame') that corresponds to the number of this configuration in
            the chain. Other keys correspond to the indices of atoms which have undergone cage breaks since the
            preceding timestep; the corresponding values are cage_break objects holding data about that cage break.
        """
        self.config_count += 1

        # checkcb returns a dictionary containing the timestep of the transition (given by self.config_count) and
        # a set of entries of the form "atom_index: cb" where cb is a cage_break object.
        cbs = self.checkcb(self.last_x, x1, self.config_count)

        self.cb_list.append(cbs)
        self.last_x = x1.copy()

        # print out the essential information associated with each cage break that occurred in this step
        self._write_cbs(cbs)

        return cbs

    def terminate_chain(self):
        """ Indicate that the last configuration of the chain has now been passed in. This causes the
        manager to flush through any remaining unterminated CB chains (note, there won't be any unless
        the checkcb object has been instructed to return only terminated CBs). Any CBs extracted by the
        flushing are added to the cb_list as though they were terminated at the last CB in the chain.
        If output is being written to a file, we write some statistics about the cage breaks and
        close the file.

        Returns
        -------
        cbs: dict or None
            Whatever checkcb.flush() returned: a dictionary containing the last cage_break object in
            each reversal chain at the time when the configuration chain was terminated (the value of
            'frame' corresponds to the last timestep in the chain), or None if there was nothing to flush.
        """
        cbs = self.checkcb.flush()

        if cbs is not None:
            self.cb_list.append(cbs)
            self._write_cbs(cbs)

        if self.outfile:
            self.final_statistics()

        return cbs

    def process_pair(self, x0, x1):
        """ This is the function to use if you want to check an isolated pair of configurations for cage
        breaks, rather than looking at configurations in a chain.

        Unlike continue_chain, the result is not added to cb_list.

        Parameters
        ----------
        x0: array of floats
            A numpy array containing coordinates of the first configuration in the pair.
            These coordinates may be given in any format recognised by the checkcb object. (Normally atomistic
            cartesian or Centre-of-Mass + Angle-Axis)
        x1: array of floats
            A numpy array containing coordinates of the second configuration in the pair.

        Returns
        -------
        cbs: dict or None
            None if no cage break was found for this pair. Otherwise a dictionary with one special key
            ('frame') that corresponds to the number of this pair of configurations. Other keys correspond
            to the indices of atoms which have undergone cage breaks between the two configurations; the
            corresponding values are cage_break objects holding data about that cage break.
        """
        self.config_count += 1  # Here, this plays the role of a counter for transition states
        # Check the pair for cage breaks and extract a cbs object containing any cage breaks that took place.
        cbs = self.checkcb(x0, x1, time=self.config_count)
        if len(cbs) == 1:
            # Only the 'frame' entry is present: nothing happened in this pair.
            return None
        self._write_cbs(cbs)
        return cbs

    def final_statistics(self):
        """ Write some overall statistics about the identified cage breaks to the output file,
        then close the file.

        Note, these statistics are intended for use with the chain-of-configurations method, and may not be
        meaningful for isolated configuration pairs passed in with process_pair (which does not
        populate cb_list). If no output file is open, a warning is printed and nothing else happens.
        """
        if not self.outfile:
            print("Warning: printing final statistics with no outfile set")
            return

        n_cbs = 0
        null_frames = 0
        cb_atoms = set()

        # Keys in l_dist correspond to different reversal chain lengths. The corresponding values are the number of
        # chains identified with that length.
        l_dist = {}

        for cbs in self.cb_list:
            if len(cbs) == 1:
                null_frames += 1
            for key, cb in cbs.items():
                if key == 'frame':
                    continue
                n_cbs += 1
                cb_atoms.add(key)
                l_dist[cb.l] = l_dist.get(cb.l, 0) + 1

        # NOTE: a dictionary appended by terminate_chain is part of cb_list but did not increment
        # config_count, so it can affect null_frames without a matching step being counted.
        self.out.write("Final statistics:\n")
        self.out.write("Number of cage breaking steps:"+str(self.config_count-null_frames)+"\n")
        self.out.write("Number of cage breaks:"+str(n_cbs)+"\n")
        self.out.write("Number of cage breaking atoms:"+str(len(cb_atoms))+"\n")
        self.out.write("Distribution of reversals:\n")
        for length in sorted(l_dist):
            self.out.write(str(length)+" "+str(l_dist[length])+"\n")

        # Mark the file as closed so that later calls do not write to a dead handle.
        self.close()
            
if __name__ == '__main__':

    # An example of how this program might be used to produce a "breaksfile" - a list of the cage breaks which
    # take place in each transition of a quenched MD trajectory.

    #################################################################
    # Specify the input parameters necessary for the cb manager
    #################################################################
    # Number of molecules in an OTP system
    nmol = 324
    # Periodic simulation box side lengths
    boxvec = np.array([10.107196061523553, 10.107196061523553, 10.107196061523553])
    # Global cutoff to identify nearest neighbours (not a sensible value)
    nncutoff = 1.5
    # Movement cutoff to identify changes to the nearest neighbours (not a sensible value)
    cbcutoff = 0.1

    # A distance metric function
    dist = distances.get_dist_periodic(boxvec)
    # A cell list object
    cell_list = CellLists.cell_lists(boxvec, nncutoff)

    # A method to identify the nearest neighbours. Atomistic cutoff is not recommended for OTP
    check_neighs = checkneighs.atomistic_cutoff(dist, nncutoff, boxvec, nmol, AA=True, cell_list=cell_list)
    check_neighs.add_criterion(checkneighs.medium_criterion)

    # Reversal manager (needed for the cb_rule) to identify reversals
    reversals = reversal_manager.neighbours_reversal_manager(2, terminated_only=False)
    # cb_rule (needed for checkcb) to identify cage breaks for a particular molecule
    cbrule = checkcb.neighchange_dcut_cb_rule(dist, 2, cbcutoff, reversal_manager=reversals)
    # checkcb object to identify cage breaks in a particular pair of configurations
    check_cb = checkcb.checkcb(check_neighs, cbrule)

    # The manager object itself, which controls the other classes and collates the output
    manager = cb_manager(check_cb, outname='test_cb_new.dat')

    #################################################################
    # Specify the input file and how to read the trajectory
    #################################################################

    # The name of the input file, which contains the quenched trajectory.
    # We need to know how many columns are in this file in order to read it correctly (default 1)
    fname, ncols = '/scratch/sn402/OTP/MDruns/T291/minimised_configs.dat', 3

    # Object to read the trajectory file
    rc = ReadConfigurations(fname, nmol, ncols)

    #################################################################
    # Analyse the trajectory
    #################################################################

    # Analyse every stride-th configuration (1 means every configuration).
    stride = 1
    # For demonstration purposes, we analyse only the first max_configs configurations.
    max_configs = 100

    chain_started = False
    for x in rc.configuration_iterator():
        if rc.configuration_count == 1:
            manager.start_chain(x)
            chain_started = True
        elif rc.configuration_count % stride == 0:
            print("starting configuration  %s" % (rc.configuration_count - 1))
            manager.continue_chain(x)

        if rc.configuration_count == max_configs:
            break

    # Terminate the chain whether we stopped at max_configs or simply ran out of configurations,
    # so that the final statistics are always written and the output file is always closed.
    if chain_started:
        manager.terminate_chain()

    # The output from this test program is a file "test_cb_new.dat" containing details of all cage breaks in
    # the first 100 configurations of this trajectory.
    # Sample input and output files are provided with this source code, and the README explains how to read
    # the output file.