diff --git a/src/cristallina/__init__.py b/src/cristallina/__init__.py index e451f10..1b78d24 100644 --- a/src/cristallina/__init__.py +++ b/src/cristallina/__init__.py @@ -14,3 +14,8 @@ except PackageNotFoundError: # pragma: no cover __version__ = "unknown" finally: del version, PackageNotFoundError + + +from . import utils +from . import plot +from . import analysis diff --git a/src/cristallina/analysis.py b/src/cristallina/analysis.py index 2c93a22..756dea2 100644 --- a/src/cristallina/analysis.py +++ b/src/cristallina/analysis.py @@ -176,9 +176,10 @@ def get_contrast_images( ) -def fit_2d_gaussian(image, roi: Optional[ROI] = None): +def fit_2d_gaussian(image, roi: Optional[ROI] = None, plot=False): """ - 2D Gaussian fit using LMFit for a given image and an optional region of interest. + 2D Gaussian fit using LMFit for a given image and an optional region of interest. + plot=True as optional argument plots the fit results. Returns the x, y coordinates of the center and the results object which contains further fit statistics. @@ -218,5 +219,46 @@ def fit_2d_gaussian(image, roi: Optional[ROI] = None): else: center_x = result.params["centerx"].value center_y = result.params["centery"].value + + if plot == True: + + X, Y = np.meshgrid(np.arange(len_y), np.arange(len_x)) + Z = griddata((x, y), z, (X, Y), method='linear', fill_value=0) + fig, axs = plt.subplots(2, 2, figsize=(10, 10)) + + # vmax = np.nanpercentile(Z, 99.9) + vmax = np.max(Z) + + ax = axs[0, 0] + art = ax.pcolor(X, Y, Z, vmin=0, vmax=vmax, shading='auto') + plt.colorbar(art, ax=ax, label='z') + ax.set_title('Data') + + ax = axs[0, 1] + fit = model.func(X, Y, **result.best_values) + art = ax.pcolor(X, Y, fit, vmin=0, vmax=vmax, shading='auto') + plt.colorbar(art, ax=ax, label='z') + ax.set_title('Fit') + + ax = axs[1, 0] + fit = model.func(X, Y, **result.best_values) + art = ax.pcolor(X, Y, Z-fit, vmin=0, shading='auto') + plt.colorbar(art, ax=ax, label='z') + ax.set_title('Data - Fit') + + ax = axs[1, 1] + fit = model.func(X, Y, **result.best_values) + art = ax.pcolor(X, Y, fit, vmin=0, vmax=vmax, shading='auto') + ax.contour(X, Y, fit, 8, colors='r',alpha=0.4) + plt.colorbar(art, ax=ax, label='z') + ax.set_title('Data & Fit') + + for ax in axs.ravel(): + ax.set_xlabel('x') + ax.set_ylabel('y') + plt.suptitle('2D Gaussian fit results') + plt.tight_layout() + plt.show() + return center_x, center_y, result diff --git a/src/cristallina/utils.py b/src/cristallina/utils.py index bdcbb07..8e9604b 100644 --- a/src/cristallina/utils.py +++ b/src/cristallina/utils.py @@ -24,6 +24,24 @@ def scan_info(run_number, base_path=None, small_data=True): return scan +def get_scan_from_run_number_or_scan(run_number_or_scan,small_data=True): + """Returns SFScanInfo object from run number or SFScanInfo object (then just passes that to output)""" + if type(run_number_or_scan) == SFScanInfo: + scan = run_number_or_scan + else: + scan = scan_info(run_number_or_scan,small_data=small_data) + return scan + +def get_run_number_from_run_number_or_scan(run_number_or_scan,small_data=True): + """Returns run number from run number or SFScanInfo object""" + if type(run_number_or_scan) == int: + rn = run_number_or_scan + elif type(run_number_or_scan) == SFScanInfo: + rn = int(str(run_number_or_scan.fs)[-19:-15]) + else: + raise ValueError("Input must be an int or SFScanInfo object") + return rn + def channel_names(run_number,verbose=False): """Prints channel names for a given run_number or scan object""" if type(run_number) == SFScanInfo: @@ -79,16 +97,22 @@ def print_run_info( break -def process_run(run_number, rois,detector='JF16T03V01', calculate =None, only_shots=slice(None), n_jobs=cpu_count()): +def process_run(run_number, rois,detector='JF16T03V01', calculate =None, only_shots=slice(None), n_jobs=cpu_count()-2): """Process rois for a given detector. Save the results small data in the res/small_data/run... By default only sum of rois is calculated, [mean,std,img] can be added to the "calculate" optional parameter. """ - # Load scan object with SFScanInfo - scan = scan_info(run_number,small_data=False) + rn = get_run_number_from_run_number_or_scan(run_number) + # Load scan object with SFScanInfo + scan = scan_info(rn,small_data=False) + + # Make the small data folder if it doesn't exist + if not os.path.exists( heuristic_extract_smalldata_path() ): + os.mkdir( heuristic_extract_smalldata_path() ) + # Set the path for later small data saving path_with_run_folder = heuristic_extract_smalldata_path()+'run'+str(run_number).zfill(4) - + # Make the small data run folder if it doesn't exist if not os.path.exists( path_with_run_folder ): os.mkdir( path_with_run_folder ) @@ -117,8 +141,8 @@ def process_run(run_number, rois,detector='JF16T03V01', calculate =None, only_sh det_pids = data[detector].pids sd[roi.name] = det_pids[only_shots], data[detector][only_shots, bottom:top,left:right].sum(axis=(1, 2)) if calculate: - if 'mean' in calculate: - sd[roi.name+"_mean"] = (det_pids[only_shots], data[detector][only_shots, bottom:top,left:right].mean(axis=(1, 2))) + if 'means' in calculate: + sd[roi.name+"_means"] = (det_pids[only_shots], data[detector][only_shots, bottom:top,left:right].mean(axis=(1, 2))) if 'std' in calculate: sd[roi.name+"_std"] = (det_pids[only_shots], data[detector][only_shots, bottom:top,left:right].std(axis=(1, 2))) if 'img' in calculate: @@ -129,8 +153,8 @@ def process_run(run_number, rois,detector='JF16T03V01', calculate =None, only_sh # These channels have only one dataset per step of the scan, so we take the first pulseID sd[roi.name + "_info"] =([det_pids[0]], [f"roi {roi.name}: {left},{right}; {bottom},{top} (left, right, bottom, top)"]) - sd[roi.name + "_mean_img"] = ([det_pids[0]], [data[detector][:, bottom:top,left:right].mean(axis=(0))] ) - + sd[roi.name + "_mean_img"] = ([det_pids[0]], [data[detector][only_shots, bottom:top,left:right].mean(axis=(0))] ) + sd[roi.name + "_step_sum"] = ([det_pids[0]], [data[detector][only_shots, bottom:top,left:right].sum()] ) Parallel(n_jobs=n_jobs,verbose=10)(delayed(process_step)(i) for i in range(len(scan))) class ROI: