import numpy as np
from math import sqrt

# This file contains classes for calculating the distance between two sets of coordinates,
# supporting different metrics and different types of input coordinates. It is used as a library
# by other functions in the Cage Breaking software.

class get_distance(object):
    """ The simplest distance calculator. Input is two cartesian coordinate arrays, of shape (m,3) where m
        is the number of particles. Output is an array of shape (m) containing a list of distances between the
        two configurations for each particle. If two 1-D coordinate vectors are given instead, a single scalar
        distance is returned."""
    def __init__(self):
        pass

    def __call__(self, xi, xj):
        """ Return the Euclidean distance(s) between the coordinates xi and xj.

            xi, xj: array-like coordinates. Numpy arrays are used as-is; plain Python sequences
                (lists/tuples) are converted with np.asarray.
            Returns: for input with more than one dimension, the norm taken along axis 1 (an array
                of shape (m,) for (m,3) input); for 1-D input, a single scalar norm.

            Warning: this preserves the shape of the coords array passed in, so won't
            correct or complain if you give it an oddly formatted array (in particular,
            it won't complain if it receives AA coords) """
        # np.asarray lets plain sequences through and is a no-op for existing ndarrays.
        dx = np.asarray(xj) - np.asarray(xi)
        if dx.ndim > 1:
            # One distance per particle (row).
            return np.linalg.norm(dx, axis=1)
        return np.linalg.norm(dx)

class get_dist_periodic(object):
    """ Calculates the shortest periodic distance between two configurations of a system. Input is two cartesian
        coordinate arrays of shape (m,3). Output is an array of shape (m) containing the distances between the two
        closest images of each particle in the two configurations. If two 1-D coordinate vectors are given instead,
        a single scalar distance is returned."""
    def __init__(self, boxvec):
        """ Boxvec is an array of size 3 containing the three side lengths for the periodic box. Each axis has its
            own length, so rectangular boxes are supported (but not skewed/triclinic cells). Note, the i'th
            coordinates of the box run from -boxvec[i]/2 to +boxvec[i]/2 """
        self.boxvec = boxvec

    def __call__(self, xi, xj):
        """ Return the minimum-image distance(s) between the coordinates xi and xj.

            xi, xj: array-like coordinates (numpy arrays, or lists/tuples converted with np.asarray).
                Integer coordinate arrays are accepted; the result is floating point.
            Returns: for input with more than one dimension, the norm taken along axis 1 (an array
                of shape (m,) for (m,3) input); for 1-D input, a single scalar norm.

            Warning: this preserves the shape of the coords array passed in, so won't
            correct or complain if you give it an oddly formatted array (in particular,
            it won't complain if it receives AA coords)
        """
        dx = np.asarray(xj) - np.asarray(xi)
        box = np.asarray(self.boxvec)
        # Shift each component by the nearest whole number of box lengths so it lies within
        # half a box length of zero. Done out-of-place: an in-place `-=` fails when dx has an
        # integer dtype, because the float correction cannot be cast back into it.
        dx = dx - box * np.round(dx / box)
        if dx.ndim > 1:
            # One distance per particle (row).
            return np.linalg.norm(dx, axis=1)
        return np.linalg.norm(dx)
