From 1f8d3b69cde95fd1a74ec396b3677a6dc9f046a7 Mon Sep 17 00:00:00 2001
From: Gamelin Alexis <gamelin@synchrotron-soleil.fr>
Date: Fri, 6 Mar 2020 18:20:49 +0100
Subject: [PATCH] Mpi class for beam and mpi_gather function for beam

New parallel module to handle Mpi class which is used in Beam
---
 Tracking/beam.py            | 26 ++++++++++++++--
 Tracking/one_turn_matrix.py | 10 +++----
 Tracking/parallel.py        | 60 +++++++++++++++++++++++++++++++++++++
 Tracking/synchrotron.py     | 16 ++--------
 4 files changed, 90 insertions(+), 22 deletions(-)
 create mode 100644 Tracking/parallel.py

diff --git a/Tracking/beam.py b/Tracking/beam.py
index 2b0e039..b7041d2 100644
--- a/Tracking/beam.py
+++ b/Tracking/beam.py
@@ -8,6 +8,7 @@ Beam and bunch elements
 
 import numpy as np
 import pandas as pd
+from Tracking.parallel import Mpi
 
 class Bunch:
     """
@@ -257,7 +258,7 @@ class Beam:
     
     def __init__(self, ring, bunch_list=None):
         self.ring = ring
-        
+        self.mpi_switch = False
         if bunch_list is None:
             self.init_beam(np.zeros((self.ring.h,1),dtype=bool))
         else:
@@ -348,7 +349,7 @@ class Beam:
         """Return an array with the filling pattern of the beam as bool"""
         filling_pattern = []
         for bunch in self:
-            if filling_pattern != 0:
+            if bunch.current != 0:
                 filling_pattern.append(True)
             else:
                 filling_pattern.append(False)
@@ -413,4 +414,25 @@ class Beam:
         for index, bunch in enumerate(self):
             bunch_emit[:,index] = np.squeeze(bunch.emit)
         return bunch_emit
+    
+    def mpi_init(self):
+        """ Initialise mpi """
+        self.mpi = Mpi(self.filling_pattern)
+        self.mpi_switch = True
+        
+    def mpi_gather(self):
+        """Gather beam, all bunches of the different processors are sent to 
+        all processors. Rather slow"""
+        
+        if(self.mpi_switch == False):
+            print("Error, mpi is not initialised.")
+        
+        bunch = self[self.mpi.bunch_num]
+        bunches = self.mpi.comm.allgather(bunch)
+        for rank in range(self.mpi.size):
+            self[self.mpi.rank_to_bunch(rank)] = bunches[rank]
+        
+        
+        
+        
     
\ No newline at end of file
diff --git a/Tracking/one_turn_matrix.py b/Tracking/one_turn_matrix.py
index bdf74c7..45452db 100644
--- a/Tracking/one_turn_matrix.py
+++ b/Tracking/one_turn_matrix.py
@@ -32,9 +32,8 @@ class Element(metaclass=ABCMeta):
         def wrapper(*args, **kwargs):
             self = args[0]
             beam = args[1]
-            if (self.ring.mpi == True):
-                rank = self.ring.mpi_init()
-                function(self, beam[rank], *args[2:], **kwargs)
+            if (beam.mpi_switch == True):
+                function(self, beam[beam.mpi.bunch_num], *args[2:], **kwargs)
             else:
                 for bunch in beam:
                     function(self, bunch, *args[2:], **kwargs)
@@ -48,9 +47,8 @@ class Element(metaclass=ABCMeta):
         def wrapper(*args, **kwargs):
             self = args[0]
             beam = args[1]
-            if (self.ring.mpi == True):
-                rank = self.ring.mpi_init()
-                function(self, beam[rank], *args[2:], **kwargs)
+            if (beam.mpi_switch == True):
+                function(self, beam[beam.mpi.bunch_num], *args[2:], **kwargs)
             else:
                 for bunch in beam.not_empty:
                     function(self, bunch, *args[2:], **kwargs)
diff --git a/Tracking/parallel.py b/Tracking/parallel.py
new file mode 100644
index 0000000..c349c12
--- /dev/null
+++ b/Tracking/parallel.py
@@ -0,0 +1,60 @@
+# -*- coding: utf-8 -*-
+
+"""
+Module to handle parallel computation
+
+@author: Alexis Gamelin
+@date: 06/03/2020
+"""
+
+import numpy as np
+
+class Mpi:
+    """ Class which handle mpi
+    
+    """
+    
+    def __init__(self, filling_pattern):
+        try:
+            from mpi4py import MPI
+        except(ModuleNotFoundError):
+            print("Please install mpi4py module.")
+        else:
+            print("MPI error.")
+        self.comm = MPI.COMM_WORLD
+        self.rank = self.comm.Get_rank()
+        self.size = self.comm.Get_size()
+        self.write_table(filling_pattern)
+        
+    def write_table(self, filling_pattern):
+        """
+        Write a table with the rank and the corresponding bunch number for each
+        bunch of the filling pattern
+        """
+        if(filling_pattern.sum() != self.size):
+            raise ValueError("The number of processors must be equal to the"
+                             "number of (non-empty) bunches.")
+        table = np.zeros((self.size, 2), dtype = int)
+        table[:,0] = np.arange(0, self.size)
+        table[:,1] = np.where(filling_pattern)[0]
+        self.table = table
+    
+    def rank_to_bunch(self, rank):
+        """Return the bunch number corresponding to rank"""
+        return self.table[rank,1]
+    
+    def bunch_to_rank(self, bunch_num):
+        """Return the rank corresponding to the bunch number bunch_num"""
+        try:
+            rank = np.where(self.table[:,1] == bunch_num)[0][0]
+        except IndexError:
+            print("The bunch " + str(bunch_num) + " is not tracked on any processor.")
+            rank = None
+        return rank
+    
+    @property
+    def bunch_num(self):
+        """Return the bunch number corresponding to the current processor"""
+        return self.rank_to_bunch(self.rank)
+    
+    
\ No newline at end of file
diff --git a/Tracking/synchrotron.py b/Tracking/synchrotron.py
index fc99fcb..38e1e23 100644
--- a/Tracking/synchrotron.py
+++ b/Tracking/synchrotron.py
@@ -68,9 +68,7 @@ class Synchrotron:
         self.emit = kwargs.get('emit') # X/Y emittances in [m.rad]
         self.chro = kwargs.get('chro') # X/Y (non-normalized) chromaticities
         self.mean_optics = kwargs.get('mean_optics') # Optics object with mean values
-        
-        self.mpi = False
-        
+                
     @property
     def h(self):
         """Harmonic number"""
@@ -183,14 +181,4 @@ class Synchrotron:
     def eta(self):
         """Momentum compaction"""
         return self.ac - 1/(self.gamma**2)
-    
-    def mpi_init(self):
-        try:
-            from mpi4py import MPI
-        except(ModuleNotFoundError):
-            print("mpi4py not found")
-        comm = MPI.COMM_WORLD
-        rank = comm.Get_rank()
-        return rank
-
-        
\ No newline at end of file
+    
\ No newline at end of file
-- 
GitLab