# -*- coding: utf-8 -*-

"""
Module to handle parallel computation

@author: Alexis Gamelin
@date: 06/03/2020
"""

import numpy as np

class Mpi:
    """
    Class which handles parallel computation via the mpi4py module [1].
    
    Each processor (MPI rank) of the world communicator is associated with
    exactly one non-empty bunch of the filling pattern.
    
    Parameters
    ----------
    filling_pattern : bool array of shape (h,)
        Filling pattern of the beam, like Beam.filling_pattern
        
    Attributes
    ----------
    comm : MPI.Intracomm object
        MPI intra-communicator of the processor group, used to manage 
        communication between processors. 
    rank : int
        Rank of the processor which runs the program
    size : int
        Number of processors within the processor group (in fact in the 
        intra-communicator group)
    table : int array of shape (size, 2)
        Table of correspondence between the rank of the processor (column 0)
        and its associated bunch number (column 1)
    bunch_num : int
        Return the bunch number corresponding to the current processor
        
    Methods
    -------
    write_table(filling_pattern)
        Write a table with the rank and the corresponding bunch number for each
        bunch of the filling pattern
    rank_to_bunch(rank)
        Return the bunch number corresponding to rank
    bunch_to_rank(bunch_num)
        Return the rank corresponding to the bunch number bunch_num
        
    Raises
    ------
    ImportError
        If the mpi4py module cannot be imported.
    ValueError
        If the number of processors differs from the number of non-empty
        bunches in filling_pattern.
        
    References
    ----------
    [1] L. Dalcin, P. Kler, R. Paz, and A. Cosimo, Parallel Distributed 
    Computing using Python, Advances in Water Resources, 34(9):1124-1139, 2011.
    """
    
    def __init__(self, filling_pattern):
        try:
            from mpi4py import MPI
        except ImportError as err:
            # Nothing below can work without mpi4py. Fail loudly here rather
            # than printing and then crashing on the unbound ``MPI`` name.
            raise ImportError("Please install mpi4py module.") from err

        self.comm = MPI.COMM_WORLD
        self.rank = self.comm.Get_rank()
        self.size = self.comm.Get_size()
        self.write_table(filling_pattern)
        
    def write_table(self, filling_pattern):
        """
        Write a table with the rank and the corresponding bunch number for each
        bunch of the filling pattern.
        
        Rank i is associated with the i-th non-empty bunch, in increasing
        bunch-number order. The result is stored in ``self.table``.
        
        Parameters
        ----------
        filling_pattern : bool array of shape (h,)
            Filling pattern of the beam, like Beam.filling_pattern. Any
            array-like is accepted; non-zero entries count as filled bunches.
            
        Raises
        ------
        ValueError
            If the number of non-empty bunches differs from the number of
            processors (self.size).
        """
        # Normalise to a bool array so that plain lists work. This also makes
        # the count check and the index selection use the same "non-zero"
        # criterion.
        filling_pattern = np.asarray(filling_pattern, dtype=bool)
        if np.count_nonzero(filling_pattern) != self.size:
            raise ValueError("The number of processors must be equal to the "
                             "number of (non-empty) bunches.")
        table = np.zeros((self.size, 2), dtype=int)
        table[:, 0] = np.arange(self.size)
        table[:, 1] = np.flatnonzero(filling_pattern)
        self.table = table
    
    def rank_to_bunch(self, rank):
        """
        Return the bunch number corresponding to rank
        
        Parameters
        ----------
        rank : int
            Rank of a processor
            
        Returns
        -------
        bunch_num : int
            Bunch number corresponding to the input rank
        """
        return self.table[rank, 1]
    
    def bunch_to_rank(self, bunch_num):
        """
        Return the rank corresponding to the bunch number bunch_num
        
        Parameters
        ----------
        bunch_num : int
            Bunch number
            
        Returns
        -------
        rank : int or None
            Rank of the processor which tracks the input bunch number, or None
            (after printing a message) if no processor tracks that bunch.
        """
        matches = np.flatnonzero(self.table[:, 1] == bunch_num)
        if matches.size == 0:
            print("The bunch " + str(bunch_num) + " is not tracked on any processor.")
            return None
        return matches[0]
    
    @property
    def bunch_num(self):
        """Return the bunch number corresponding to the current processor"""
        return self.rank_to_bunch(self.rank)