From fae90cbeaee5ef2507ca3d77f9e5fd9fbb7fb472 Mon Sep 17 00:00:00 2001
From: Ivan Usov <ivan.usov@psi.ch>
Date: Tue, 26 Apr 2022 16:37:20 +0200
Subject: [PATCH] Initial version of (m)hkl sorting

---
 pyzebra/app/panel_ccl_prepare.py |  71 +++++++++---
 pyzebra/sxtal_refgen.py          | 180 ++++++++++++++++++++++++++++++-
 2 files changed, 236 insertions(+), 15 deletions(-)

diff --git a/pyzebra/app/panel_ccl_prepare.py b/pyzebra/app/panel_ccl_prepare.py
index d84a46b..b89844d 100644
--- a/pyzebra/app/panel_ccl_prepare.py
+++ b/pyzebra/app/panel_ccl_prepare.py
@@ -14,6 +14,7 @@ from bokeh.models import (
     Div,
     FileInput,
     MultiSelect,
+    NumericInput,
     Panel,
     Plot,
     RadioGroup,
@@ -47,6 +48,9 @@ for (let i = 0; i < js_data.data['fname'].length; i++) {
 }
 """
 
+ANG_CHUNK_DEFAULTS = {"2theta": 30, "gamma": 30, "omega": 30, "chi": 35, "phi": 35, "nu": 10}
+SORT_OPT_BI = ["2theta", "chi", "phi", "omega"]
+SORT_OPT_NB = ["gamma", "nu", "omega"]
 
 def create():
     ang_lims = None
@@ -157,12 +161,33 @@ def create():
     magstruct_lattice = TextInput(title="lattice", width=100)
     magstruct_kvec = TextAreaInput(title="k vector", width=150)
 
+    def sorting0_callback(_attr, _old, new):
+        sorting_0_dt.value = ANG_CHUNK_DEFAULTS[new]
+
+    def sorting1_callback(_attr, _old, new):
+        sorting_1_dt.value = ANG_CHUNK_DEFAULTS[new]
+
+    def sorting2_callback(_attr, _old, new):
+        sorting_2_dt.value = ANG_CHUNK_DEFAULTS[new]
+
+    sorting_0 = Select(title="1st", width=100)
+    sorting_0.on_change("value", sorting0_callback)
+    sorting_0_dt = NumericInput(title="Δ", width=70)
+    sorting_1 = Select(title="2nd", width=100)
+    sorting_1.on_change("value", sorting1_callback)
+    sorting_1_dt = NumericInput(title="Δ", width=70)
+    sorting_2 = Select(title="3rd", width=100)
+    sorting_2.on_change("value", sorting2_callback)
+    sorting_2_dt = NumericInput(title="Δ", width=70)
+
     def geom_radiogroup_callback(_attr, _old, new):
         nonlocal ang_lims, params
         if new == 0:
             geom_file = pyzebra.get_zebraBI_default_geom_file()
+            sort_opt = SORT_OPT_BI
         else:
             geom_file = pyzebra.get_zebraNB_default_geom_file()
+            sort_opt = SORT_OPT_NB
         cfl_file = pyzebra.get_zebra_default_cfl_file()
 
         ang_lims = pyzebra.read_geom_file(geom_file)
@@ -170,6 +195,11 @@ def create():
         params = pyzebra.read_cfl_file(cfl_file)
         _update_params(params)
 
+        sorting_0.options = sorting_1.options = sorting_2.options = sort_opt
+        sorting_0.value = sort_opt[0]
+        sorting_1.value = sort_opt[1]
+        sorting_2.value = sort_opt[2]
+
     geom_radiogroup_div = Div(text="Geometry:", margin=(5, 5, 0, 5))
     geom_radiogroup = RadioGroup(labels=["bisecting", "normal beam"], width=150)
     geom_radiogroup.on_change("active", geom_radiogroup_callback)
@@ -208,6 +238,14 @@ def create():
             with open(geom_path) as f:
                 print(f.read())
 
+            priority = [sorting_0.value, sorting_1.value, sorting_2.value]
+            chunks = [sorting_0_dt.value, sorting_1_dt.value, sorting_2_dt.value]
+            if geom_radiogroup.active == 0:
+                sort_hkl_file = pyzebra.sort_hkl_file_bi
+                priority.extend(set(SORT_OPT_BI) - set(priority))
+            else:
+                sort_hkl_file = pyzebra.sort_hkl_file_nb
+
             # run sxtal_refgen for each kvect provided
             for i, kvect in enumerate(kvects, start=1):
                 params["kvect"] = kvect
@@ -240,26 +278,32 @@ def create():
 
                 if i == 1:  # all hkl files are identical, so keep only one
                     hkl_fname = base_fname + ".hkl"
-                    with open(os.path.join(temp_dir, hkl_fname)) as f:
+                    hkl_fpath = os.path.join(temp_dir, hkl_fname)
+                    with open(hkl_fpath) as f:
                         res_files[hkl_fname] = f.read()
 
+                    hkl_fname_sorted = base_fname + "_sorted.hkl"
+                    hkl_fpath_sorted = os.path.join(temp_dir, hkl_fname_sorted)
+                    sort_hkl_file(hkl_fpath, hkl_fpath_sorted, priority, chunks)
+                    with open(hkl_fpath_sorted) as f:
+                        res_files[hkl_fname_sorted] = f.read()
+
                 mhkl_fname = base_fname + ".mhkl"
-                with open(os.path.join(temp_dir, mhkl_fname)) as f:
+                mhkl_fpath = os.path.join(temp_dir, mhkl_fname)
+                with open(mhkl_fpath) as f:
                     res_files[mhkl_fname] = f.read()
 
+                mhkl_fname_sorted = base_fname + "_sorted.mhkl"
+                mhkl_fpath_sorted = os.path.join(temp_dir, hkl_fname_sorted)
+                sort_hkl_file(mhkl_fpath, mhkl_fpath_sorted, priority, chunks)
+                with open(mhkl_fpath_sorted) as f:
+                    res_files[mhkl_fname_sorted] = f.read()
+
             created_lists.options = list(res_files)
 
     go_button = Button(label="GO", button_type="primary", width=50)
     go_button.on_click(go_button_callback)
 
-    sorting_cb = CheckboxGroup(labels=["Apply sorting"], width=120)
-    sorting_1 = Select(title="1st", width=70)
-    sorting_1_dt = TextInput(title="Δ", width=70)
-    sorting_2 = Select(title="2nd", width=70)
-    sorting_2_dt = TextInput(title="Δ", width=70)
-    sorting_3 = Select(title="3rd", width=70)
-    sorting_3_dt = TextInput(title="Δ", width=70)
-
     def created_lists_callback(_attr, _old, new):
         sel_file = new[0]
         file_text = res_files[sel_file]
@@ -304,15 +348,14 @@ def create():
     ranges_layout = column(ranges_div, row(ranges_hkl, ranges_srang))
     magstruct_layout = column(magstruct_div, row(magstruct_lattice, magstruct_kvec))
     sorting_layout = row(
-        column(Spacer(height=25), sorting_cb),
+        sorting_0,
+        sorting_0_dt,
+        Spacer(width=30),
         sorting_1,
         sorting_1_dt,
         Spacer(width=30),
         sorting_2,
         sorting_2_dt,
-        Spacer(width=30),
-        sorting_3,
-        sorting_3_dt,
     )
 
     column1_layout = column(
diff --git a/pyzebra/sxtal_refgen.py b/pyzebra/sxtal_refgen.py
index 7a10f2a..6058b7a 100644
--- a/pyzebra/sxtal_refgen.py
+++ b/pyzebra/sxtal_refgen.py
@@ -1,7 +1,10 @@
 import io
 import os
-import tempfile
 import subprocess
+import tempfile
+from math import ceil, floor
+
+import numpy as np
 
 SXTAL_REFGEN_PATH = "/afs/psi.ch/project/sinq/rhel7/bin/Sxtal_Refgen"
 
@@ -303,3 +306,178 @@ def export_cfl_file(path, params, template=None):
             out_file.write("\n")
             for atom_line in params["ATOM"]:
                 out_file.write(f"ATOM {atom_line}\n")
+
+
+def sort_hkl_file_bi(file_in, file_out, priority, chunks):
+    with open(file_in) as fileobj:
+        file_in_data = fileobj.readlines()
+
+    data = np.genfromtxt(file_in, skip_header=3)
+    stt = data[:, 4]
+    omega = data[:, 5]
+    chi = data[:, 6]
+    phi = data[:, 7]
+
+    lines = file_in_data[3:]
+    lines_update = []
+
+    angles = {"2theta": stt, "omega": omega, "chi": chi, "phi": phi}
+
+    # Reverse flag
+    to_reverse = False
+    to_reverse_p2 = False
+    to_reverse_p3 = False
+
+    # Get indices within first priority
+    ang_p1 = angles[priority[0]]
+    begin_p1 = floor(min(ang_p1))
+    end_p1 = ceil(max(ang_p1))
+    delta_p1 = chunks[0]
+    for p1 in range(begin_p1, end_p1, delta_p1):
+        ind_p1 = [j for j, x in enumerate(ang_p1) if p1 <= x and x < p1 + delta_p1]
+
+        stt_new = [stt[x] for x in ind_p1]
+        omega_new = [omega[x] for x in ind_p1]
+        chi_new = [chi[x] for x in ind_p1]
+        phi_new = [phi[x] for x in ind_p1]
+        lines_new = [lines[x] for x in ind_p1]
+
+        angles_p2 = {"stt": stt_new, "omega": omega_new, "chi": chi_new, "phi": phi_new}
+
+        # Get indices for second priority
+        ang_p2 = angles_p2[priority[1]]
+        if len(ang_p2) > 0 and to_reverse_p2:
+            begin_p2 = ceil(max(ang_p2))
+            end_p2 = floor(min(ang_p2))
+            delta_p2 = -chunks[1]
+        elif len(ang_p2) > 0 and not to_reverse_p2:
+            end_p2 = ceil(max(ang_p2))
+            begin_p2 = floor(min(ang_p2))
+            delta_p2 = chunks[1]
+        else:
+            end_p2 = 0
+            begin_p2 = 0
+            delta_p2 = 1
+
+        to_reverse_p2 = not to_reverse_p2
+
+        for p2 in range(begin_p2, end_p2, delta_p2):
+            min_p2 = min([p2, p2 + delta_p2])
+            max_p2 = max([p2, p2 + delta_p2])
+            ind_p2 = [j for j, x in enumerate(ang_p2) if min_p2 <= x and x < max_p2]
+
+            stt_new2 = [stt_new[x] for x in ind_p2]
+            omega_new2 = [omega_new[x] for x in ind_p2]
+            chi_new2 = [chi_new[x] for x in ind_p2]
+            phi_new2 = [phi_new[x] for x in ind_p2]
+            lines_new2 = [lines_new[x] for x in ind_p2]
+
+            angles_p3 = {"stt": stt_new2, "omega": omega_new2, "chi": chi_new2, "phi": phi_new2}
+
+            # Get indices for third priority
+            ang_p3 = angles_p3[priority[2]]
+            if len(ang_p3) > 0 and to_reverse_p3:
+                begin_p3 = ceil(max(ang_p3)) + chunks[2]
+                end_p3 = floor(min(ang_p3)) - chunks[2]
+                delta_p3 = -chunks[2]
+            elif len(ang_p3) > 0 and not to_reverse_p3:
+                end_p3 = ceil(max(ang_p3)) + chunks[2]
+                begin_p3 = floor(min(ang_p3)) - chunks[2]
+                delta_p3 = chunks[2]
+            else:
+                end_p3 = 0
+                begin_p3 = 0
+                delta_p3 = 1
+
+            to_reverse_p3 = not to_reverse_p3
+
+            for p3 in range(begin_p3, end_p3, delta_p3):
+                min_p3 = min([p3, p3 + delta_p3])
+                max_p3 = max([p3, p3 + delta_p3])
+                ind_p3 = [j for j, x in enumerate(ang_p3) if min_p3 <= x and x < max_p3]
+
+                angle_new3 = [angles_p3[priority[3]][x] for x in ind_p3]
+
+                ind_final = [x for _, x in sorted(zip(angle_new3, ind_p3), reverse=to_reverse)]
+
+                to_reverse = not to_reverse
+
+                for i in ind_final:
+                    lines_update.append(lines_new2[i])
+
+    with open(file_out, "w") as fileobj:
+        for _ in range(3):
+            fileobj.write(file_in_data.pop(0))
+
+        fileobj.writelines(lines_update)
+
+
+def sort_hkl_file_nb(file_in, file_out, priority, chunks):
+    with open(file_in) as fileobj:
+        file_in_data = fileobj.readlines()
+
+    data = np.genfromtxt(file_in, skip_header=3)
+    gamma = data[:, 4]
+    omega = data[:, 5]
+    nu = data[:, 6]
+
+    lines = file_in_data[3:]
+    lines_update = []
+
+    angles = {"gamma": gamma, "omega": omega, "nu": nu}
+
+    to_reverse = False
+    to_reverse_p2 = False
+
+    # Get indices within first priority
+    ang_p1 = angles[priority[0]]
+    begin_p1 = floor(min(ang_p1))
+    end_p1 = ceil(max(ang_p1))
+    delta_p1 = chunks[0]
+    for p1 in range(begin_p1, end_p1, delta_p1):
+        ind_p1 = [j for j, x in enumerate(ang_p1) if p1 <= x and x < p1 + delta_p1]
+
+        # Get angles from within nu range
+        lines_new = [lines[x] for x in ind_p1]
+        gamma_new = [gamma[x] for x in ind_p1]
+        omega_new = [omega[x] for x in ind_p1]
+        nu_new = [nu[x] for x in ind_p1]
+
+        angles_p2 = {"gamma": gamma_new, "omega": omega_new, "nu": nu_new}
+
+        # Get indices for second priority
+        ang_p2 = angles_p2[priority[1]]
+        if len(gamma_new) > 0 and to_reverse_p2:
+            begin_p2 = ceil(max(ang_p2))
+            end_p2 = floor(min(ang_p2))
+            delta_p2 = -chunks[1]
+        elif len(gamma_new) > 0 and not to_reverse_p2:
+            end_p2 = ceil(max(ang_p2))
+            begin_p2 = floor(min(ang_p2))
+            delta_p2 = chunks[1]
+        else:
+            end_p2 = 0
+            begin_p2 = 0
+            delta_p2 = 1
+
+        to_reverse_p2 = not to_reverse_p2
+
+        for p2 in range(begin_p2, end_p2, delta_p2):
+            min_p2 = min([p2, p2 + delta_p2])
+            max_p2 = max([p2, p2 + delta_p2])
+            ind_p2 = [j for j, x in enumerate(ang_p2) if min_p2 <= x and x < max_p2]
+
+            angle_new2 = [angles_p2[priority[2]][x] for x in ind_p2]
+
+            ind_final = [x for _, x in sorted(zip(angle_new2, ind_p2), reverse=to_reverse)]
+
+            to_reverse = not to_reverse
+
+            for i in ind_final:
+                lines_update.append(lines_new[i])
+
+    with open(file_out, "w") as fileobj:
+        for _ in range(3):
+            fileobj.write(file_in_data.pop(0))
+
+        fileobj.writelines(lines_update)