diff --git a/pyzebra/ccl_process.py b/pyzebra/ccl_process.py index 9ef8cd4..6ba08f6 100644 --- a/pyzebra/ccl_process.py +++ b/pyzebra/ccl_process.py @@ -99,30 +99,39 @@ def merge_scans(scan_into, scan_from): scan_motor = scan_into["scan_motor"] # the same as scan_from["scan_motor"] - if ( - scan_into[scan_motor].shape == scan_from[scan_motor].shape - and np.max(np.abs(scan_into[scan_motor] - scan_from[scan_motor])) < 0.0005 - ): - counts_tmp = 0 - counts_err_tmp = 0 + pos_all = np.array([]) + val_all = np.array([]) + err_all = np.array([]) + for scan in [scan_into["init_scan"], *scan_into["merged_scans"]]: + pos_all = np.append(pos_all, scan[scan_motor]) + val_all = np.append(val_all, scan["counts"]) + err_all = np.append(err_all, scan["counts_err"] ** 2) - for scan in [scan_into["init_scan"], *scan_into["merged_scans"]]: - counts_tmp += scan["counts"] - counts_err_tmp += scan["counts_err"] ** 2 + sort_index = np.argsort(pos_all) + pos_all = pos_all[sort_index] + val_all = val_all[sort_index] + err_all = err_all[sort_index] - scan_into["counts"] = counts_tmp / (1 + len(scan_into["merged_scans"])) - scan_into["counts_err"] = np.sqrt(counts_err_tmp) + pos_tmp = pos_all[:1] + val_tmp = val_all[:1] + err_tmp = err_all[:1] + num_tmp = np.array([1]) + for pos, val, err in zip(pos_all[1:], val_all[1:], err_all[1:]): + if pos - pos_tmp[-1] < 0.0005: + # the repeated motor position + val_tmp[-1] += val + err_tmp[-1] += err + num_tmp[-1] += 1 + else: + # a new motor position + pos_tmp = np.append(pos_tmp, pos) + val_tmp = np.append(val_tmp, val) + err_tmp = np.append(err_tmp, err) + num_tmp = np.append(num_tmp, 1) - else: - motor_pos = np.concatenate((scan_into[scan_motor], scan_from[scan_motor])) - counts = np.concatenate((scan_into["counts"], scan_from["counts"])) - counts_err = np.concatenate((scan_into["counts_err"], scan_from["counts_err"])) - - index = np.argsort(motor_pos) - - scan_into[scan_motor] = motor_pos[index] - scan_into["counts"] = counts[index] - scan_into["counts_err"] = counts_err[index] + scan_into[scan_motor] = pos_tmp + scan_into["counts"] = val_tmp / num_tmp + scan_into["counts_err"] = np.sqrt(err_tmp) scan_from["export"] = False