#!/usr/bin/python # author J.Beale """ # aim to merge .stream files and calculate statistics and MTZ # usage python partialator.py -s -n name (name of job - default = partialator) -p pointgroup -m model (unity or xsphere - default is unity) -i iterations - number of iterations in partialator -c -b number of resolution bins - must be > 20 -r high-res limt. Needs a default. Default set to 1.3 -a max-adu. Default = 12000 -v ra reservation name if available -g spacegroup -c list of cell lengths and angles to use - 59.3,59.3,153.1,90.0,90.0,90.0 -r number of residues # output - scaled/merged files - an mtz file - useful plots - useful summerized .dat files - log file of output """ # modules from sys import exit import pandas as pd import numpy as np import subprocess import os, errno import time import argparse from tqdm import tqdm import regex as re import matplotlib.pyplot as plt from scipy.optimize import curve_fit import warnings warnings.filterwarnings( "ignore", category=RuntimeWarning ) from loguru import logger def submit_job( job_file, reservation ): # submit the job if reservation: print( "using a ra beamtime reservation = {0}".format( reservation ) ) logger.info( "using ra reservation to process data = {0}".format( reservation ) ) submit_cmd = [ "sbatch", "-p", "day", "--reservation={0}".format( reservation ), "--cpus-per-task=32", "--" , job_file ] else: submit_cmd = [ "sbatch", "-p", "day", "--cpus-per-task=32", "--" , job_file ] logger.info( "using slurm command = {0}".format( submit_cmd ) ) try: job_output = subprocess.check_output( submit_cmd ) logger.info( "submited job = {0}".format( job_output ) ) except subprocess.CalledProcessError as e: print( "please give the correct ra reservation or remove the -v from the arguements" ) exit() # scrub job id from - example Submitted batch job 742403 pattern = r"Submitted batch job (\d+)" job_id = re.search( pattern, job_output.decode().strip() ).group(1) return int( job_id ) def wait_for_jobs( job_ids, total_jobs ): with tqdm( total=total_jobs, desc="Jobs Completed", unit="job" ) as pbar: while job_ids: completed_jobs = set() for job_id in job_ids: status_cmd = [ "squeue", "-h", "-j", str( job_id ) ] status = subprocess.check_output( status_cmd ) if not status: completed_jobs.add( job_id ) pbar.update( 1 ) job_ids.difference_update( completed_jobs ) time.sleep( 2 ) def run_partialator( proc_dir, name, stream, pointgroup, model, iterations, adu ): # partialator file name part_run_file = "{0}/partialator_{1}.sh".format( proc_dir, name ) # write file part_sh = open( part_run_file, "w" ) part_sh.write( "#!/bin/sh\n\n" ) part_sh.write( "module purge\n" ) part_sh.write( "module use MX unstable\n" ) part_sh.write( "module load crystfel/0.10.2-rhel8\n" ) part_sh.write( "partialator -i {0} \\\n".format( stream ) ) part_sh.write( " -o {0}.hkl \\\n".format( name ) ) part_sh.write( " -y {0} \\\n".format( pointgroup ) ) part_sh.write( " --model={0} \\\n".format( model ) ) part_sh.write( " --max-adu={0} \\\n".format( adu ) ) part_sh.write( " -j 32 \\\n" ) part_sh.write( " --iterations={0}\n\n".format( iterations ) ) part_sh.close() # make file executable subprocess.call( [ "chmod", "+x", "{0}".format( part_run_file ) ] ) # add partialator script to log part_input = open( part_run_file, "r" ) logger.info( "partialator input file =\n{0}".format( part_input.read() ) ) part_input.close() # return partialator file name return part_run_file def run_compare_check( proc_dir, name, cell, shells, part_h_res ): # check file name check_run_file = "{0}/check_{1}.sh".format( proc_dir, name ) # write file check_sh = open( check_run_file, "w" ) check_sh.write( "#!/bin/sh\n\n" ) check_sh.write( "module purge\n" ) check_sh.write( "module use MX unstable\n" ) check_sh.write( "module load crystfel/0.10.2-rhel8\n" ) check_sh.write( "check_hkl --shell-file=mult.dat *.hkl -p {0} --nshells={1} --highres={2} &> check_hkl.log\n".format( cell, shells, part_h_res ) ) check_sh.write( "check_hkl --ltest --ignore-negs --shell-file=ltest.dat *.hkl -p {0} --nshells={1} --highres={2} &> ltest.log\n".format( cell, shells, part_h_res ) ) check_sh.write( "check_hkl --wilson --shell-file=wilson.dat *.hkl -p {0} --nshells={1} --highres={2} &> wilson.log\n".format( cell, shells, part_h_res ) ) check_sh.write( "compare_hkl --fom=Rsplit --shell-file=Rsplit.dat *.hkl1 *hkl2 -p {0} --nshells={1} --highres={2} &> Rsplit.log\n".format( cell, shells, part_h_res ) ) check_sh.write( "compare_hkl --fom=cc --shell-file=cc.dat *.hkl1 *hkl2 -p {0} --nshells={1} --highres={2} &> cc.log\n".format( cell, shells, part_h_res ) ) check_sh.write( "compare_hkl --fom=ccstar --shell-file=ccstar.dat *.hkl1 *hkl2 -p {0} --nshells={1} --highres={2} &> ccstar.log\n".format( cell, shells, part_h_res ) ) check_sh.close() # make file executable subprocess.call( [ "chmod", "+x", "{0}".format( check_run_file ) ] ) # add check script to log check_input = open( check_run_file, "r" ) logger.info( "check input file =\n{0}".format( check_input.read() ) ) check_input.close() # return check file name return check_run_file def make_process_dir( dir ): # make process directory try: os.makedirs( dir ) except OSError as e: if e.errno != errno.EEXIST: raise def summary_stats( cc_dat, ccstar_dat, mult_dat, rsplit_dat, wilson_dat ): # read all files into pd # function to sort out different column names def read_dat( dat, var ): # different columns names of each dat file if var == "cc": cols = [ "d(nm)", "cc", "nref", "d", "min", "max" ] elif var == "ccstar": cols = [ "1(nm)", "ccstar", "nref", "d", "min", "max" ] elif var == "mult": cols = [ "d(nm)", "nref", "poss", "comp", "obs", "mult", "snr", "I", "d", "min", "max" ] elif var == "rsplit": cols = [ "d(nm)", "rsplit", "nref", "d", "min", "max" ] elif var == "wilson": cols = [ "bin", "s2", "d", "lnI", "nref" ] df = pd.read_csv( dat, names=cols, skiprows=1, sep="\s+" ) return df # make df cc_df = read_dat( cc_dat, "cc" ) ccstar_df = read_dat( ccstar_dat, "ccstar" ) mult_df = read_dat( mult_dat, "mult" ) rsplit_df = read_dat( rsplit_dat, "rsplit" ) wilson_df = read_dat( wilson_dat, "wilson" ) # remove unwanted cols cc_df = cc_df[ [ "cc" ] ] ccstar_df = ccstar_df[ [ "ccstar" ] ] rsplit_df = rsplit_df[ [ "rsplit" ] ] wilson_df = wilson_df[ [ "lnI" ] ] # merge dfs stats_df = pd.concat( [ mult_df, cc_df, ccstar_df, rsplit_df, wilson_df ], axis=1, join="inner" ) # make 1/d, 1/d^2 column stats_df[ "1_d" ] = 1 / stats_df.d stats_df[ "1_d2" ] = 1 / stats_df.d**2 # change nan to 0 stats_df = stats_df.fillna(0) return stats_df def get_metric( d2_series, cc_series, cut_off ): # Define the tanh function from scitbx def tanh(x, r, s0): z = (x - s0)/r return 0.5 * ( 1 - np.tanh(z) ) def arctanh( y, r, s0 ): return r * np.arctanh( 1 - 2*y ) + s0 # Fit the tanh to the data params, covariance = curve_fit( tanh, d2_series, cc_series ) # Extract the fitted parameters r, s0 = params # calculate cut-off point cc_stat = arctanh( cut_off, r, s0 ) # covert back from 1/d2 to d cc_stat = np.sqrt( ( 1 / cc_stat ) ) # get curve for plotting cc_tanh = tanh( d2_series, r, s0 ) return round( cc_stat, 3 ), cc_tanh def get_overall_cc(): # open cc log file cc_log_file = open( "cc.log" ) cc_log = cc_log_file.read() # regex example = Overall CC = 0.5970865 overcc_pattern = r"Overall\sCC\s=\s(\d\.\d+)" try: overcc = re.search( overcc_pattern, cc_log ).group(1) except AttributeError as e: overcc = np.nan return overcc def get_overall_rsplit(): # open rsplit log file rsplit_log_file = open( "Rsplit.log" ) rsplit_log = rsplit_log_file.read() # regex example = Overall Rsplit = 54.58 % overrsplit_pattern = r"Overall\sRsplit\s=\s(\d+\.\d+)" try: overrsplit = re.search( overrsplit_pattern, rsplit_log ).group(1) except AttributeError as e: overrsplit = np.nan return overrsplit def get_b(): # open rsplit log file wilson_log_file = open( "wilson.log" ) wilson_log = wilson_log_file.read() # regex example = B = 41.63 A^2 b_factor_pattern = r"B\s=\s(\d+\.\d+)\sA" try: b_factor = re.search( b_factor_pattern, wilson_log ).group(1) except AttributeError as e: b_factor = np.nan return b_factor def get_globals(): # open rsplit log file check_log_file = open( "check_hkl.log" ) check_log = check_log_file.read() # regex example => Overall = 9.299098 # regex example => Overall redundancy = 577.663604 measurements # regex example => Overall completeness = 97.852126 % snr_pattern = r"Overall\s\\s=\s(\d+\.\d+)" mult_pattern = r"Overall\sredundancy\s=\s(\d+\.\d+)\smeasurements" comp_pattern = r"Overall\scompleteness\s=\s(\d+\.\d+)" try: snr = re.search( snr_pattern, check_log ).group(1) mult = re.search( mult_pattern, check_log ).group(1) comp = re.search( comp_pattern, check_log ).group(1) except AttributeError as e: snr = np.nan mult = np.nan comp = np.nan return mult, snr, comp def summary_fig( stats_df, cc_tanh, ccstar_tanh, cc_cut, ccstar_cut ): def dto1_d( x ): return 1/x # plot results cc_fig, axs = plt.subplots(2, 2) cc_fig.suptitle( "cc and cc* vs resolution" ) # cc plot color = "tab:red" axs[0,0].set_xlabel( "1/d (1/A)" ) axs[0,0].set_ylabel( "CC", color=color ) axs[0,0].set_ylim( 0, 1 ) axs[0,0].axhline( y = 0.3, color="black", linestyle = "dashed" ) # plot cc axs[0,0].plot( stats_df[ "1_d" ], stats_df.cc, color=color ) # plot fit axs[0,0].plot( stats_df[ "1_d" ], cc_tanh, color="tab:grey", linestyle = "dashed" ) sax1 = axs[0,0].secondary_xaxis( 'top', functions=( dto1_d, dto1_d ) ) sax1.set_xlabel('d (A)') axs[0,0].tick_params( axis="y", labelcolor=color ) axs[0,0].text( 0.1, 0.1, "CC0.5 @ 0.3 = {0}".format( cc_cut ), fontsize = 8 ) # cc* plot color = "tab:blue" axs[0,1].set_xlabel( "1/d (1/A)" ) axs[0,1].set_ylabel( "CC*", color=color ) axs[0,1].set_ylim( 0, 1 ) axs[0,1].axhline( y = 0.7, color="black", linestyle = "dashed" ) axs[0,1].plot( stats_df[ "1_d" ], stats_df.ccstar, color=color ) # plot fit axs[0,1].plot( stats_df[ "1_d" ], ccstar_tanh, color="tab:grey", linestyle = "dashed" ) sax2 = axs[0,1].secondary_xaxis( 'top', functions=( dto1_d, dto1_d ) ) sax2.set_xlabel('d (A)') axs[0,1].tick_params( axis='y', labelcolor=color ) axs[0,1].text( 0.1, 0.1, "CC* @ 0.7 = {0}".format( ccstar_cut ) , fontsize = 8 ) # rsplit plot color = "tab:green" axs[1,0].set_xlabel( "1/d (1/A)" ) axs[1,0].set_ylabel( "Rsplit", color=color ) axs[1,0].plot( stats_df[ "1_d" ], stats_df.rsplit, color=color ) sax3 = axs[1,0].secondary_xaxis( 'top', functions=( dto1_d, dto1_d ) ) sax3.set_xlabel( 'd (A)' ) axs[1,0].tick_params( axis='y', labelcolor=color ) # wilson plot color = "tab:purple" axs[1,1].set_xlabel( "1/d**2 (1/A**2)" ) axs[1,1].set_ylabel( "lnI", color=color ) axs[1,1].plot( stats_df[ "1_d2" ], stats_df.lnI, color=color ) axs[1,1].tick_params( axis='y', labelcolor=color ) # save figure plt.tight_layout() plt.savefig( "plots.png" ) def make_mtz( hklout_file, mtzout_file, cell_constants, spacegroup, residues, res_range ): # make_mtz file name mtz_run_file = "make_mtz.sh" # make F file name Fout_file = os.path.splitext( mtzout_file )[0] + "_F.mtz" # write file mtz_sh = open( mtz_run_file, "w" ) mtz_sh.write( "#!/bin/sh\n\n" ) mtz_sh.write( "module purge\n" ) mtz_sh.write( "module load ccp4/8.0\n\n" ) mtz_sh.write( "f2mtz HKLIN {0} HKLOUT {1} << EOF_hkl > f2mtz.log\n".format( hklout_file, mtzout_file ) ) mtz_sh.write( "TITLE Reflections from CrystFEL\n" ) mtz_sh.write( "NAME PROJECT {0} CRYSTAL {1} DATASET {2}\n".format( "SwissFEL", "XTAL", "DATA" ) ) mtz_sh.write( "CELL {0} {1} {2} {3} {4} {5}\n".format( cell_constants[0], cell_constants[1], cell_constants[2], cell_constants[3], cell_constants[4], cell_constants[5] ) ) mtz_sh.write( "SYMM {0}\n".format( spacegroup ) ) mtz_sh.write( "SKIP 3\n" ) mtz_sh.write( "LABOUT H K L I_stream SIGI_stream\n" ) mtz_sh.write( "CTYPE H H H J Q\n" ) mtz_sh.write( "FORMAT '(3(F4.0,1X),F10.2,10X,F10.2)'\n" ) mtz_sh.write( "SKIP 3\n" ) mtz_sh.write( "EOF_hkl\n\n\n" ) mtz_sh.write( "echo 'done'\n" ) mtz_sh.write( "echo 'I and SIGI from CrystFEL stream saved as I_stream and SIGI_stream'\n" ) mtz_sh.write( "echo 'I filename = {0}'\n\n\n".format( mtzout_file ) ) mtz_sh.write( "echo 'running truncate'\n" ) mtz_sh.write( "echo 'setting resolution range to {0}-{1}'\n".format( res_range[0], res_range[1] ) ) mtz_sh.write( "echo 'assuming that there are {0} residues in assymetric unit'\n\n\n".format( residues ) ) mtz_sh.write( "truncate HKLIN {0} HKLOUT {1} << EOF_F > truncate.log\n".format( mtzout_file, Fout_file ) ) mtz_sh.write( "truncate YES\n" ) mtz_sh.write( "anomalous NO\n" ) mtz_sh.write( "nresidue {0}\n".format( residues ) ) mtz_sh.write( "resolution {0} {1}\n".format( res_range[0], res_range[1] ) ) mtz_sh.write( "plot OFF\n" ) mtz_sh.write( "labin IMEAN=I_stream SIGIMEAN=SIGI_stream\n" ) mtz_sh.write( "labout F=F_stream SIGF=SIGF_stream\n" ) mtz_sh.write( "end\n" ) mtz_sh.write( "EOF_F\n\n\n" ) mtz_sh.write( "echo 'done'\n" ) mtz_sh.write( "echo 'I_stream and SIGI_stream from f2mtz converted to F_stream and F_stream'\n" ) mtz_sh.write( "echo 'F filename = {0} (contains both Is and Fs)'".format( Fout_file ) ) mtz_sh.close() # make file executable subprocess.call( [ "chmod", "+x", "{0}".format( mtz_run_file ) ] ) # run subprocess.call( [ "./{0}".format( mtz_run_file ) ] ) def cut_hkl_file( hklin_file, hklout_file ): # setup hklout = open( hklout_file, 'w') collect_lines = True # Open the input file for reading with open( hklin_file, 'r') as f: for line in f: if line.strip() == 'End of reflections': collect_lines = False # Stop collecting lines if collect_lines: hklout.write( line ) hklout.close() def write_mtz( hklin_file, hklout_file, mtzout, cell_constants, spacegroup, residues, res_range ): # remove final lines from crystfel hkl out print( "removing final lines from crystfel hklin" ) cut_hkl_file( hklin_file, hklout_file ) print( "done" ) # running make mtz print( "making mtz" ) print( "using cell constants\n{0} {1} {2} A {3} {4} {5} deg".format( cell_constants[0], cell_constants[1], cell_constants[2], cell_constants[3], cell_constants[4], cell_constants[5] )) make_mtz( hklout_file, mtzout, cell_constants, spacegroup, residues, res_range ) print( "done" ) # remove *cut.hkl out cwd = os.getcwd() files = os.listdir( cwd ) for file in files: if file.endswith( "cut.hkl" ): os.remove( file ) def read_cell( cell_file ): # function to get cell parameter def get_parameter( line, parameter ): # general parameter search pattern = r"{0}\s=\s(\d+\.\d+)".format( parameter ) value = re.search( pattern, line ).group(1) return value # loop through file and get parameters cell = open( cell_file, "r" ) for line in cell: if line.startswith( "a " ): a = float( get_parameter( line, "a" ) ) if line.startswith( "b " ): b = float( get_parameter( line, "b" ) ) if line.startswith( "c " ): c = float( get_parameter( line, "c" ) ) if line.startswith( "al " ): alpha = float( get_parameter( line, "al" ) ) if line.startswith( "be " ): beta = float( get_parameter( line, "be" ) ) if line.startswith( "ga " ): gamma = float( get_parameter( line, "ga" ) ) cell_constants = [ a, b, c, alpha, beta, gamma ] return cell_constants def main( cwd, name, stream, pointgroup, model, iterations, cell, shells, part_h_res, adu, spacegroup, residues, reservation ): # submitted job set submitted_job_ids = set() part_dir = "{0}/{1}".format( cwd, name ) # make part directories make_process_dir( part_dir ) # move to part directory os.chdir( part_dir ) print( "making partialator file" ) # make partialator run file part_run_file = run_partialator( part_dir, name, stream, pointgroup, model, iterations, adu ) check_run_file = run_compare_check( part_dir, name, cell, shells, part_h_res ) # submit job job_id = submit_job( part_run_file, reservation ) print( f"job submitted: {0}".format( job_id ) ) submitted_job_ids.add( job_id ) # use progress bar to track job completion time.sleep(10) wait_for_jobs( submitted_job_ids, 1 ) print( "slurm processing done" ) # now run the check and compare scripts print( "running check/compare" ) check_job_id = submit_job( check_run_file, reservation ) print( f"job submitted: {0}".format( check_job_id ) ) submitted_job_ids.add( check_job_id ) time.sleep(10) wait_for_jobs( submitted_job_ids, 1 ) print( "done" ) # stats files names cc_dat = "cc.dat" ccstar_dat = "ccstar.dat" mult_dat = "mult.dat" rsplit_dat = "Rsplit.dat" wilson_dat = "wilson.dat" # make summary data table stats_df = summary_stats( cc_dat, ccstar_dat, mult_dat, rsplit_dat, wilson_dat ) logger.info( "stats table from .dat file =\n{0}".format( stats_df.to_string() ) ) print_df = stats_df[ [ "1_d", "d", "min", "max", "nref", "poss", "comp", "obs", "mult", "snr", "I", "rsplit", "cc", "ccstar" ] ] print_df.to_csv( "summary_table.csv", sep="\t", index=False ) # calculate cc metrics cc_cut, cc_tanh = get_metric( stats_df[ "1_d2" ], stats_df.cc, 0.3 ) ccstar_cut, ccstar_tanh = get_metric( stats_df[ "1_d2" ], stats_df.ccstar, 0.7 ) print( "resolution at CC0.5 at 0.3 = {0}".format( cc_cut ) ) print( "resolution at CC* at 0.7 = {0}".format( ccstar_cut ) ) logger.info( "resolution at CC0.5 at 0.3 = {0}".format( cc_cut ) ) logger.info( "resolution at CC* at 0.7 = {0}".format( ccstar_cut ) ) # scrub other metrics overcc = get_overall_cc() overrsplit = get_overall_rsplit() b_factor = get_b() overall_mult, overall_snr, overall_comp = get_globals() logger.info( "overall CC0.5 = {0}".format( overcc ) ) logger.info( "overall Rsplit = {0}".format( overrsplit ) ) logger.info( "overall B = {0}".format( b_factor ) ) logger.info( "overall mult = {0}".format( overall_mult ) ) # show plots summary_fig( stats_df, cc_tanh, ccstar_tanh, cc_cut, ccstar_cut ) # make mtz hklin_file = "{0}.hkl".format( name ) hklout_file = "{0}_cut.hkl".format( name ) mtzout = "{0}.mtz".format( name ) res_range = ( 50.0, cc_cut ) cell_constants = read_cell( cell ) write_mtz( hklin_file, hklout_file, mtzout, cell_constants, spacegroup, residues, res_range ) # move back to top dir os.chdir( cwd ) def list_of_floats(arg): return list(map(float, arg.split(','))) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "-n", "--name", help="name of partialator run, also name of folder where data will be processed.", type=str, required=True ) parser.add_argument( "-s", "--stream_file", help="path to stream file", type=os.path.abspath, required=True ) parser.add_argument( "-p", "--pointgroup", help="pointgroup used by CrystFEL for partialator run", type=str, required=True ) parser.add_argument( "-m", "--model", help="model used partialator, e.g., unity or xsphere. Default = unity.", type=str, default="unity" ) parser.add_argument( "-i", "--iterations", help="number of iterations used for partialator run. Default = 1.", type=int, default=1 ) parser.add_argument( "-c", "--cell_file", help="path to CrystFEL cell file for partialator.", type=os.path.abspath, required=True ) parser.add_argument( "-b", "--bins", help="number of resolution bins to use. Should be more than 20. Default = 20.", type=int, default=20 ) parser.add_argument( "-r", "--resolution", help="high res limit - need something here. Default set to 1.3.", type=float, default=1.3 ) parser.add_argument( "-a", "--max_adu", help="maximum detector counts to allow. Default is 12000.", type=int, default=12000 ) parser.add_argument( "-g", "--spacegroup", help="spacegroup for making mtz, e.g P41212", type=str, required=True ) parser.add_argument( "-R", "--residues", help="number of residues for truncate, e.g., hewl = 129", type=int, required=True ) parser.add_argument( "-v", "--reservation", help="reservation name for ra cluster. Usually along the lines of P11111_2024-12-10", type=str, default=None ) parser.add_argument( "-d", "--debug", help="output debug to terminal.", type=bool, default=False ) args = parser.parse_args() # set loguru if not args.debug: logger.remove() logfile = "{0}.log".format( args.name ) logger.add( logfile, format="{message}", level="INFO") # run main cwd = os.getcwd() print( "top working directory = {0}".format( cwd ) ) main( cwd, args.name, args.stream_file, args.pointgroup, args.model, args.iterations, args.cell_file, args.bins, args.resolution, args.max_adu, args.spacegroup, args.residues, args.reservation )