#!/usr/bin/python
# author J.Beale
"""
aim - run CrystFEL check_hkl/compare_hkl on merged .hkl files and
summarize the statistics

usage - python partialator.py
    -n name of the job (required)
    -c path to CrystFEL cell file (required)
    -b number of resolution bins - should be at least 20 (default = 20)
    -r high-resolution limit in Angstrom (default = 1.3)
    -d output debug to terminal

output
    - useful plots
    - summarized .dat files and a tab-separated summary table
    - log file of output
"""

# modules
import os
import re
import time
import argparse
import subprocess

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy.optimize import curve_fit
from loguru import logger

import warnings
warnings.filterwarnings( "ignore", category=RuntimeWarning )


def submit_job( job_file ):
    submit_cmd = [ "sbatch", "--", job_file ]
    logger.info( "using slurm command = {0}".format( submit_cmd ) )
    job_output = subprocess.check_output( submit_cmd )
    logger.info( "submitted job = {0}".format( job_output ) )
    # scrub job id from output - example: Submitted batch job 742403
    pattern = r"Submitted batch job (\d+)"
    job_id = re.search( pattern, job_output.decode().strip() ).group(1)
    return int( job_id )


def wait_for_jobs( job_ids, total_jobs ):
    # poll squeue until every submitted job has left the queue
    with tqdm( total=total_jobs, desc="Jobs Completed", unit="job" ) as pbar:
        while job_ids:
            completed_jobs = set()
            for job_id in job_ids:
                status_cmd = [ "squeue", "-h", "-j", str( job_id ) ]
                status = subprocess.check_output( status_cmd )
                # empty output means the job is no longer queued or running
                if not status:
                    completed_jobs.add( job_id )
                    pbar.update( 1 )
            job_ids.difference_update( completed_jobs )
            time.sleep( 2 )
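
# note - on some clusters squeue exits non-zero once a job has aged out of
# the queue, which would raise CalledProcessError in wait_for_jobs above.
# a minimal sketch of a more robust completion check via slurm accounting
# (assumes sacct is available; job_is_done is a hypothetical helper and is
# not used elsewhere in this script):
def job_is_done( job_id ):
    sacct_cmd = [ "sacct", "-n", "-X", "-j", str( job_id ), "-o", "State" ]
    state = subprocess.check_output( sacct_cmd ).decode().strip()
    if not state:
        # accounting can lag briefly after submission
        return False
    # "CANCELLED by <uid>" still starts with the terminal state name
    return state.split()[0] in ( "COMPLETED", "FAILED", "CANCELLED", "TIMEOUT" )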


def run_compare_check( proc_dir, name, cell, shells, part_h_res ):
    # check file name
    check_run_file = "{0}/check_{1}.sh".format( proc_dir, name )
    # write file
    check_sh = open( check_run_file, "w" )
    check_sh.write( "#!/bin/sh\n\n" )
    check_sh.write( "module purge\n" )
    check_sh.write( "module use MX unstable\n" )
    check_sh.write( "module load crystfel/0.10.2-rhel8\n" )
    check_sh.write( "check_hkl --shell-file=mult.dat *.hkl -p {0} --nshells={1} --highres={2} &> check_hkl.log\n".format( cell, shells, part_h_res ) )
    check_sh.write( "check_hkl --ltest --ignore-negs --shell-file=ltest.dat *.hkl -p {0} --nshells={1} --highres={2} &> ltest.log\n".format( cell, shells, part_h_res ) )
    check_sh.write( "check_hkl --wilson --shell-file=wilson.dat *.hkl -p {0} --nshells={1} --highres={2} &> wilson.log\n".format( cell, shells, part_h_res ) )
    check_sh.write( "compare_hkl --fom=Rsplit --shell-file=Rsplit.dat *.hkl1 *.hkl2 -p {0} --nshells={1} --highres={2} &> Rsplit.log\n".format( cell, shells, part_h_res ) )
    check_sh.write( "compare_hkl --fom=cc --shell-file=cc.dat *.hkl1 *.hkl2 -p {0} --nshells={1} --highres={2} &> cc.log\n".format( cell, shells, part_h_res ) )
    check_sh.write( "compare_hkl --fom=ccstar --shell-file=ccstar.dat *.hkl1 *.hkl2 -p {0} --nshells={1} --highres={2} &> ccstar.log\n".format( cell, shells, part_h_res ) )
    check_sh.close()
    # make file executable
    subprocess.call( [ "chmod", "+x", "{0}".format( check_run_file ) ] )
    # add check script to log
    check_input = open( check_run_file, "r" )
    logger.info( "check input file =\n{0}".format( check_input.read() ) )
    check_input.close()
    # return check file name
    return check_run_file


def summary_stats( cc_dat, ccstar_dat, mult_dat, rsplit_dat, wilson_dat ):
    # read all files into pd
    # helper to sort out the different column names of each .dat file
    def read_dat( dat, var ):
        if var == "cc":
            cols = [ "d(nm)", "cc", "nref", "d", "min", "max" ]
        elif var == "ccstar":
            cols = [ "d(nm)", "ccstar", "nref", "d", "min", "max" ]
        elif var == "mult":
            cols = [ "d(nm)", "nref", "poss", "comp", "obs", "mult", "snr", "I", "d", "min", "max" ]
        elif var == "rsplit":
            cols = [ "d(nm)", "rsplit", "nref", "d", "min", "max" ]
        elif var == "wilson":
            cols = [ "bin", "s2", "d", "lnI", "nref" ]
        df = pd.read_csv( dat, names=cols, skiprows=1, sep=r"\s+" )
        return df

    # make df
    cc_df = read_dat( cc_dat, "cc" )
    ccstar_df = read_dat( ccstar_dat, "ccstar" )
    mult_df = read_dat( mult_dat, "mult" )
    rsplit_df = read_dat( rsplit_dat, "rsplit" )
    wilson_df = read_dat( wilson_dat, "wilson" )

    # remove unwanted cols
    cc_df = cc_df[ [ "cc" ] ]
    ccstar_df = ccstar_df[ [ "ccstar" ] ]
    rsplit_df = rsplit_df[ [ "rsplit" ] ]
    wilson_df = wilson_df[ [ "lnI" ] ]

    # merge dfs
    stats_df = pd.concat( [ mult_df, cc_df, ccstar_df, rsplit_df, wilson_df ], axis=1, join="inner" )

    # make 1/d, 1/d^2 columns
    stats_df[ "1_d" ] = 1 / stats_df.d
    stats_df[ "1_d2" ] = 1 / stats_df.d**2

    # change nan to 0
    stats_df = stats_df.fillna(0)

    return stats_df


def get_metric( d2_series, cc_series, cut_off ):
    # Define the tanh function from scitbx
    def tanh( x, r, s0 ):
        z = ( x - s0 ) / r
        return 0.5 * ( 1 - np.tanh(z) )

    # analytic inverse of the tanh model, used to find where the fitted
    # curve crosses the cut-off value
    def arctanh( y, r, s0 ):
        return r * np.arctanh( 1 - 2*y ) + s0

    # Fit the tanh to the data
    params, covariance = curve_fit( tanh, d2_series, cc_series )

    # Extract the fitted parameters
    r, s0 = params

    # calculate cut-off point
    cc_stat = arctanh( cut_off, r, s0 )

    # convert back from 1/d^2 to d
    cc_stat = np.sqrt( 1 / cc_stat )

    # get curve for plotting
    cc_tanh = tanh( d2_series, r, s0 )

    return round( cc_stat, 3 ), cc_tanh
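
# a minimal sanity check for get_metric on synthetic data (parameter values
# are illustrative only; this helper is not called anywhere in the pipeline):
# build a tanh fall-off with known r and s0 and confirm the recovered
# cut-off matches the analytic inverse.
def _selftest_get_metric():
    # noise-free tanh fall-off with r = 0.1, s0 = 0.35 on a 1/d^2 grid
    x = pd.Series( np.linspace( 0.05, 0.6, 20 ) )
    y = 0.5 * ( 1 - np.tanh( ( x - 0.35 ) / 0.1 ) )
    cut, fit = get_metric( x, y, 0.3 )
    # analytic answer: d = sqrt( 1 / ( r*arctanh(1 - 2*0.3) + s0 ) ) ~ 1.6 A
    assert abs( cut - 1.6 ) < 0.05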


def get_overall_cc():
    # open cc log file - regex example: Overall CC = 0.5970865
    with open( "cc.log" ) as cc_log_file:
        cc_log = cc_log_file.read()
    overcc_pattern = r"Overall\sCC\s=\s(\d\.\d+)"
    try:
        overcc = re.search( overcc_pattern, cc_log ).group(1)
    except AttributeError:
        overcc = np.nan
    return overcc


def get_overall_rsplit():
    # open rsplit log file - regex example: Overall Rsplit = 54.58 %
    with open( "Rsplit.log" ) as rsplit_log_file:
        rsplit_log = rsplit_log_file.read()
    overrsplit_pattern = r"Overall\sRsplit\s=\s(\d+\.\d+)"
    try:
        overrsplit = re.search( overrsplit_pattern, rsplit_log ).group(1)
    except AttributeError:
        overrsplit = np.nan
    return overrsplit


def get_b():
    # open wilson log file - regex example: B = 41.63 A^2
    with open( "wilson.log" ) as wilson_log_file:
        wilson_log = wilson_log_file.read()
    b_factor_pattern = r"B\s=\s(\d+\.\d+)\sA"
    try:
        b_factor = re.search( b_factor_pattern, wilson_log ).group(1)
    except AttributeError:
        b_factor = np.nan
    return b_factor
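
# the three scrapers above share one shape; a minimal consolidated sketch
# (scrape_log is a hypothetical helper, shown for illustration and not used
# by main below):
def scrape_log( log_file, pattern ):
    with open( log_file ) as handle:
        text = handle.read()
    match = re.search( pattern, text )
    # fall back to nan when the log lacks the expected line
    return match.group(1) if match else np.nan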


def summary_fig( name, stats_df, cc_tanh, ccstar_tanh, cc_cut, ccstar_cut ):
    # 1/x is its own inverse, so one function serves both axis directions
    def dto1_d( x ):
        return 1 / x

    # plot results
    cc_fig, axs = plt.subplots( 2, 2 )
    cc_fig.suptitle( "cc and cc* vs resolution" )

    # cc plot
    color = "tab:red"
    axs[0,0].set_xlabel( "1/d (1/A)" )
    axs[0,0].set_ylabel( "CC", color=color )
    axs[0,0].set_ylim( 0, 1 )
    axs[0,0].axhline( y=0.3, color="black", linestyle="dashed" )
    # plot cc
    axs[0,0].plot( stats_df[ "1_d" ], stats_df.cc, color=color )
    # plot fit
    axs[0,0].plot( stats_df[ "1_d" ], cc_tanh, color="tab:grey", linestyle="dashed" )
    sax1 = axs[0,0].secondary_xaxis( "top", functions=( dto1_d, dto1_d ) )
    sax1.set_xlabel( "d (A)" )
    axs[0,0].tick_params( axis="y", labelcolor=color )
    axs[0,0].text( 0.1, 0.1, "CC0.5 @ 0.3 = {0}".format( cc_cut ), fontsize=8 )

    # cc* plot
    color = "tab:blue"
    axs[0,1].set_xlabel( "1/d (1/A)" )
    axs[0,1].set_ylabel( "CC*", color=color )
    axs[0,1].set_ylim( 0, 1 )
    axs[0,1].axhline( y=0.7, color="black", linestyle="dashed" )
    axs[0,1].plot( stats_df[ "1_d" ], stats_df.ccstar, color=color )
    # plot fit
    axs[0,1].plot( stats_df[ "1_d" ], ccstar_tanh, color="tab:grey", linestyle="dashed" )
    sax2 = axs[0,1].secondary_xaxis( "top", functions=( dto1_d, dto1_d ) )
    sax2.set_xlabel( "d (A)" )
    axs[0,1].tick_params( axis="y", labelcolor=color )
    axs[0,1].text( 0.1, 0.1, "CC* @ 0.7 = {0}".format( ccstar_cut ), fontsize=8 )

    # rsplit plot
    color = "tab:green"
    axs[1,0].set_xlabel( "1/d (1/A)" )
    axs[1,0].set_ylabel( "Rsplit", color=color )
    axs[1,0].plot( stats_df[ "1_d" ], stats_df.rsplit, color=color )
    sax3 = axs[1,0].secondary_xaxis( "top", functions=( dto1_d, dto1_d ) )
    sax3.set_xlabel( "d (A)" )
    axs[1,0].tick_params( axis="y", labelcolor=color )

    # wilson plot
    color = "tab:purple"
    axs[1,1].set_xlabel( "1/d**2 (1/A**2)" )
    axs[1,1].set_ylabel( "lnI", color=color )
    axs[1,1].plot( stats_df[ "1_d2" ], stats_df.lnI, color=color )
    axs[1,1].tick_params( axis="y", labelcolor=color )

    # save figure
    plt.tight_layout()
    plt.savefig( "{0}_plots.png".format( name ) )


def main( cwd, name, cell, shells, part_h_res ):
    # submitted job set
    submitted_job_ids = set()

    # now run the check and compare scripts
    print( "running check/compare" )
    check_run_file = run_compare_check( cwd, name, cell, shells, part_h_res )
    check_job_id = submit_job( check_run_file )
    print( "job submitted: {0}".format( check_job_id ) )
    submitted_job_ids.add( check_job_id )
    time.sleep( 10 )
    wait_for_jobs( submitted_job_ids, 1 )
    print( "done" )

    # stats file names
    cc_dat = "cc.dat"
    ccstar_dat = "ccstar.dat"
    mult_dat = "mult.dat"
    rsplit_dat = "Rsplit.dat"
    wilson_dat = "wilson.dat"

    # make summary data table
    stats_df = summary_stats( cc_dat, ccstar_dat, mult_dat, rsplit_dat, wilson_dat )
    logger.info( "stats table from .dat files =\n{0}".format( stats_df.to_string() ) )
    print_df = stats_df[ [ "1_d", "d", "min", "max", "nref", "poss", "comp", "obs", "mult", "snr", "I", "rsplit", "cc", "ccstar" ] ]
    print_df.to_csv( "{0}_summary_table.csv".format( name ), sep="\t", index=False )

    # calculate cc metrics
    cc_cut, cc_tanh = get_metric( stats_df[ "1_d2" ], stats_df.cc, 0.3 )
    ccstar_cut, ccstar_tanh = get_metric( stats_df[ "1_d2" ], stats_df.ccstar, 0.7 )
    print( "resolution at CC0.5 = 0.3: {0}".format( cc_cut ) )
    print( "resolution at CC* = 0.7: {0}".format( ccstar_cut ) )
    logger.info( "resolution at CC0.5 = 0.3: {0}".format( cc_cut ) )
    logger.info( "resolution at CC* = 0.7: {0}".format( ccstar_cut ) )

    # scrub other metrics
    overcc = get_overall_cc()
    overrsplit = get_overall_rsplit()
    b_factor = get_b()
    logger.info( "overall CC0.5 = {0}".format( overcc ) )
    logger.info( "overall Rsplit = {0}".format( overrsplit ) )
    logger.info( "overall B = {0}".format( b_factor ) )

    # save plots
    summary_fig( name, stats_df, cc_tanh, ccstar_tanh, cc_cut, ccstar_cut )

    # move back to top dir
    os.chdir( cwd )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument( "-n", "--name", help="name of check.", type=str, required=True )
    parser.add_argument( "-c", "--cell_file", help="path to CrystFEL cell file.", type=os.path.abspath, required=True )
    parser.add_argument( "-b", "--bins", help="number of resolution bins to use. Should be at least 20. Default = 20.", type=int, default=20 )
    parser.add_argument( "-r", "--resolution", help="high-resolution limit in Angstrom. Default = 1.3.", type=float, default=1.3 )
    parser.add_argument( "-d", "--debug", help="output debug to terminal.", action="store_true" )
    args = parser.parse_args()

    # set loguru
    if not args.debug:
        logger.remove()
    logfile = "{0}.log".format( args.name )
    logger.add( logfile, format="{message}", level="INFO" )

    # run main
    cwd = os.getcwd()
    print( "top working directory = {0}".format( cwd ) )
    main( cwd, args.name, args.cell_file, args.bins, args.resolution )
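
# example invocation (file names are illustrative; run from the directory
# containing the merged .hkl/.hkl1/.hkl2 files):
#
#   python partialator.py -n lyso_merge -c lyso.cell -b 20 -r 1.5
#
# this writes and submits check_lyso_merge.sh via sbatch, then produces
# lyso_merge_summary_table.csv, lyso_merge_plots.png and lyso_merge.log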