Files
crystfel_tools/reduction_tools/push_res_scan.py
2024-02-01 13:29:49 +01:00

470 lines
16 KiB
Python

#!/usr/bin/python
# author J.Beale
"""
# aim
to provide a complete python script to merge/scale a .stream file, do a push-res scan and calculate statistics
# usage
python push_res_scan.py -s <path-to-stream-file>
-n name (name of the push-res scan folder - default = push-res)
-p pointgroup
-m model (unity or xsphere - default is unity)
-i iterations - number of iterations in partialator (default = 1)
-c <path-to-cell-file>
-b number of resolution bins - should be at least 20 (default = 20)
-r high-res limit (default = 1.3)
# output
- within the <name> folder, a series of folders containing individual partialator runs with different push-res values
- the script uses CC and CC* as metrics for which push-res value is best.
- if these are not the same - it looks at which values are the closest - this may not be correct!
- it then reruns partialator with the selected push-res and resolution cut-off in <name>/final (summary_table.csv, plots.png)
"""
# modules
import pandas as pd
import numpy as np
import subprocess
import os, errno
import time
import argparse
from tqdm import tqdm
import regex as re
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
def submit_job( job_file, partition="day", cpus_per_task=32 ):
    """
    Submit a job script to SLURM with sbatch and return its job id.

    Parameters
    ----------
    job_file : str
        path to the executable job script.
    partition : str
        SLURM partition passed to ``sbatch -p`` (default "day").
    cpus_per_task : int
        value passed as ``--cpus-per-task`` (default 32).

    Returns
    -------
    int
        job id parsed from the sbatch output.

    Raises
    ------
    subprocess.CalledProcessError
        if sbatch exits with a non-zero status.
    RuntimeError
        if the sbatch output does not contain a job id.
    """
    submit_cmd = [ "sbatch", "-p", partition,
                   "--cpus-per-task={0}".format( cpus_per_task ), "--", job_file ]
    job_output = subprocess.check_output( submit_cmd ).decode().strip()
    # sbatch reports e.g. "Submitted batch job 742403"
    match = re.search( r"Submitted batch job (\d+)", job_output )
    if match is None:
        raise RuntimeError( "could not find a job id in sbatch output: {0!r}".format( job_output ) )
    return int( match.group( 1 ) )
def wait_for_jobs( job_ids, total_jobs ):
    """
    Block until every SLURM job in ``job_ids`` is no longer listed by squeue.

    squeue is polled every 5 seconds. A job counts as finished when
    ``squeue -h -j <id>`` prints nothing, or when squeue fails with an
    "Invalid job id" error (reported once a finished job has been purged
    from the controller). Any other squeue failure is raised.

    Parameters
    ----------
    job_ids : iterable of int
        SLURM job ids to wait for. The caller's collection is not modified.
    total_jobs : int
        total shown by the progress bar.

    Raises
    ------
    subprocess.CalledProcessError
        if squeue fails for a reason other than an unknown job id.
    """
    pending = set( job_ids )
    with tqdm( total=total_jobs, desc="Jobs Completed", unit="job" ) as pbar:
        while pending:
            completed_jobs = set()
            for job_id in pending:
                status_cmd = [ "squeue", "-h", "-j", str( job_id ) ]
                result = subprocess.run( status_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE )
                if result.returncode != 0:
                    # a purged (finished) job makes squeue exit non-zero with
                    # "Invalid job id specified"; anything else is a real failure
                    if b"Invalid job id" not in result.stderr:
                        raise subprocess.CalledProcessError(
                            result.returncode, status_cmd, output=result.stdout, stderr=result.stderr )
                    finished = True
                else:
                    # -h drops the header, so empty output means the job is not listed
                    finished = not result.stdout.strip()
                if finished:
                    completed_jobs.add( job_id )
                    pbar.update( 1 )
            pending -= completed_jobs
            # no need to wait once everything is done
            if pending:
                time.sleep( 5 )
def run_partialator( proc_dir, stream, pointgroup, model, iterations, push_res, cell, bins, part_h_res, flag ):
    """
    Write an executable shell script that runs partialator and then the CrystFEL statistics tools.

    The script loads the CrystFEL module and runs partialator, writing
    ``push-res_<push_res>.hkl``. It then runs check_hkl, for
    multiplicity/completeness, L-test and Wilson statistics. Finally it runs
    compare_hkl on the half-data-sets, for Rsplit, CC and CC*. Every
    statistics call writes a ``*.dat`` shell file and a ``*.log``. All paths
    in the script are relative, so it must run from ``proc_dir``.

    Parameters
    ----------
    proc_dir : str
        folder the script is written into.
    stream : str
        input .stream file for partialator.
    pointgroup, model, iterations, push_res :
        partialator -y / --model / --iterations / --push-res values.
    cell : str
        cell file passed to check_hkl / compare_hkl as ``-p``.
    bins : int
        number of resolution shells (``--nshells``).
    part_h_res : float
        high-resolution limit (``--highres``).
    flag : str
        "push-res" -> script named ``push-res_<push_res>.sh``;
        "final" -> script named ``final.sh``.

    Returns
    -------
    str
        path of the written script.

    Raises
    ------
    ValueError
        if ``flag`` is neither "push-res" nor "final".
    """
    if flag == "push-res":
        part_run_file = "{0}/push-res_{1}.sh".format( proc_dir, push_res )
    elif flag == "final":
        part_run_file = "{0}/final.sh".format( proc_dir )
    else:
        raise ValueError( "flag must be 'push-res' or 'final', got {0!r}".format( flag ) )
    # resolution options shared by every check_hkl / compare_hkl call
    stat_opts = "-p {0} --nshells={1} --highres={2}".format( cell, bins, part_h_res )
    # "> log 2>&1" instead of "&> log": "&>" is a bash-ism that a POSIX /bin/sh
    # (e.g. dash) parses as "run in background, then truncate the log"
    lines = [
        "#!/bin/sh\n\n",
        "module purge\n",
        "module use MX unstable\n",
        "module load crystfel/0.10.2-rhel8\n",
        "partialator -i {0} \\\n".format( stream ),
        " -o push-res_{0}.hkl \\\n".format( push_res ),
        " -y {0} \\\n".format( pointgroup ),
        " --model={0} \\\n".format( model ),
        " --iterations={0} \\\n".format( iterations ),
        " --push-res={0}\n\n".format( push_res ),
        "check_hkl --shell-file=mult.dat *.hkl {0} > check_hkl.log 2>&1\n".format( stat_opts ),
        "check_hkl --ltest --ignore-negs --shell-file=ltest.dat *.hkl {0} > ltest.log 2>&1\n".format( stat_opts ),
        "check_hkl --wilson --shell-file=wilson.dat *.hkl {0} > wilson.log 2>&1\n".format( stat_opts ),
        "compare_hkl --fom=Rsplit --shell-file=Rsplit.dat *.hkl1 *.hkl2 {0} > Rsplit.log 2>&1\n".format( stat_opts ),
        "compare_hkl --fom=cc --shell-file=cc.dat *.hkl1 *.hkl2 {0} > cc.log 2>&1\n".format( stat_opts ),
        "compare_hkl --fom=ccstar --shell-file=ccstar.dat *.hkl1 *.hkl2 {0} > ccstar.log 2>&1\n".format( stat_opts ),
    ]
    with open( part_run_file, "w" ) as part_sh:
        part_sh.writelines( lines )
    # make file executable (add the x bits, like "chmod +x")
    os.chmod( part_run_file, os.stat( part_run_file ).st_mode | 0o111 )
    return part_run_file
def make_process_dir( dir ):
    """
    Create the directory ``dir`` and any missing parents; do nothing if it already exists.

    Parameters
    ----------
    dir : str
        directory to create.

    Raises
    ------
    FileExistsError
        if ``dir`` exists but is not a directory (previously this was silently ignored).
    OSError
        for any other failure, e.g. missing permissions.
    """
    os.makedirs( dir, exist_ok=True )
def run_push_res( cwd, name, stream, pointgroup, model, iterations, cell, bins, part_h_res ):
    """
    Submit one partialator job per push-res value and wait until they have all finished.

    Push-res values 0.0, 0.2, ..., 3.0 are tried. Each gets its own folder
    ``<cwd>/<name>/push-res_<value>``. Folders that already exist are treated
    as an earlier run and skipped.

    Parameters
    ----------
    cwd : str
        top working directory; restored as the process working directory after every submission.
    name : str
        name of the scan folder below ``cwd``.
    stream, pointgroup, model, iterations, cell, bins, part_h_res :
        passed through to run_partialator.
    """
    # round to 1 dp so folder names are clean (np.arange yields e.g. 0.6000000000000001)
    push_res_list = np.around( np.arange( 0.0, 3.2, 0.2 ), 1 )
    submitted_job_ids = set()
    for res in push_res_list:
        push_res_dir = "{0}/{1}/push-res_{2}".format( cwd, name, res )
        # skip folders left by an earlier run
        if os.path.exists( push_res_dir ):
            continue
        make_process_dir( push_res_dir )
        # submit from inside the folder: sbatch runs the job in the submission
        # directory, and the job script uses relative output paths
        os.chdir( push_res_dir )
        try:
            part_run_file = run_partialator( push_res_dir, stream, pointgroup, model, iterations, res, cell, bins, part_h_res, flag="push-res" )
            job_id = submit_job( part_run_file )
        finally:
            os.chdir( cwd )
        print( "push-res {0} job submitted: {1}".format( res, job_id ) )
        submitted_job_ids.add( job_id )
    if submitted_job_ids:
        # presumably gives the scheduler time to register the new jobs before polling - TODO confirm
        time.sleep(30)
        # progress total must match the jobs actually submitted, not all push-res values
        wait_for_jobs( submitted_job_ids, len( submitted_job_ids ) )
    print("slurm processing done")
def summary_stats( cc_dat, ccstar_dat, mult_dat, rsplit_dat ):
    """
    Merge the per-shell statistics files of one partialator run into one table.

    Each file is read as whitespace-separated columns after skipping its
    first (header) line. Rows of the four files are matched by position.

    Parameters
    ----------
    cc_dat, ccstar_dat, rsplit_dat : str
        compare_hkl shell files (--fom=cc, ccstar and Rsplit).
    mult_dat : str
        check_hkl shell file.

    Returns
    -------
    pandas.DataFrame
        columns 1_d, 1_d2, d, min, max, nref, poss, comp, obs, mult, snr, I,
        cc, ccstar, rsplit; missing values are replaced by 0.
    """
    # column names of each shell file, keyed by statistic
    dat_cols = {
        "cc": [ "d(nm)", "cc", "nref", "d", "min", "max" ],
        "ccstar": [ "1(nm)", "ccstar", "nref", "d", "min", "max" ],
        "mult": [ "d(nm)", "nref", "poss", "comp", "obs",
                  "mult", "snr", "I", "d", "min", "max" ],
        "rsplit": [ "d(nm)", "rsplit", "nref", "d", "min", "max" ],
    }

    def read_dat( dat, var ):
        # raw string: "\s" in a plain literal is an invalid escape sequence
        return pd.read_csv( dat, names=dat_cols[ var ], skiprows=1, sep=r"\s+" )

    # keep only the statistic itself from the compare_hkl files; mult_df supplies the shell columns
    cc_df = read_dat( cc_dat, "cc" )[ [ "cc" ] ]
    ccstar_df = read_dat( ccstar_dat, "ccstar" )[ [ "ccstar" ] ]
    mult_df = read_dat( mult_dat, "mult" )
    rsplit_df = read_dat( rsplit_dat, "rsplit" )[ [ "rsplit" ] ]
    stats_df = pd.concat( [ mult_df, cc_df, ccstar_df, rsplit_df ], axis=1, join="inner" )
    # 1/d and 1/d^2 for plotting and curve fitting
    stats_df[ "1_d" ] = 1 / stats_df.d
    stats_df[ "1_d2" ] = 1 / stats_df.d**2
    stats_df = stats_df[ [ "1_d", "1_d2", "d", "min",
                           "max", "nref", "poss",
                           "comp", "obs", "mult",
                           "snr", "I", "cc", "ccstar", "rsplit" ] ]
    return stats_df.fillna( 0 )
def get_metric( d2_series, cc_series, cut_off ):
    """
    Estimate the resolution d at which a correlation statistic falls to ``cut_off``.

    The statistic is fitted against 1/d^2 with the falling curve
    y = 0.5 * (1 - tanh((x - s0) / r)), which the original comment attributes
    to scitbx. The fitted curve is inverted at y = cut_off, and the resulting
    1/d^2 is converted back to d.

    Parameters
    ----------
    d2_series : array-like
        1/d^2 of each resolution shell.
    cc_series : array-like
        statistic (e.g. CC or CC*) of each shell.
    cut_off : float
        statistic value at which the resolution is read off.

    Returns
    -------
    float
        d at the cut-off; nan (with a numpy warning) when the inverted 1/d^2 is negative.
    """
    def falling_tanh( x, r, s0 ):
        z = (x - s0)/r
        return 0.5 * (1 - np.tanh(z))

    # least-squares fit of the two curve parameters; the covariance is not needed
    ( r, s0 ), _ = curve_fit( falling_tanh, d2_series, cc_series )
    # solve 0.5 * (1 - tanh((x - s0) / r)) = cut_off for x (= 1/d^2)
    inv_d2_at_cut = r * np.arctanh( 1 - 2*cut_off ) + s0
    return np.sqrt( ( 1 / inv_d2_at_cut ) )
def summary_fig( stats_df, file_name ):
    """
    Save a 2x2 summary plot of CC, CC*, Rsplit and multiplicity against 1/d.

    The figure is written to ``<file_name>/plots.png`` and then closed.

    Parameters
    ----------
    stats_df : pandas.DataFrame
        table from summary_stats; needs the columns 1_d, cc, ccstar, rsplit and mult.
    file_name : str
        directory the figure is written into (despite the parameter name).
    """
    cc_fig, axs = plt.subplots(2, 2)
    cc_fig.suptitle( "cc and cc* vs resolution" )
    # (axis, column, y label, colour, dashed reference line or None, fix y range to 0-1)
    panels = [
        ( axs[0,0], "cc", "CC", "tab:red", 0.3, True ),
        ( axs[0,1], "ccstar", "CC*", "tab:blue", 0.7, True ),
        ( axs[1,0], "rsplit", "Rsplit", "tab:green", None, False ),
        ( axs[1,1], "mult", "Multiplicity", "tab:purple", None, False ),
    ]
    for ax, col, label, color, ref_line, unit_range in panels:
        ax.set_xlabel( "1/d (1/A)" )
        ax.set_ylabel( label, color=color )
        if unit_range:
            ax.set_ylim( 0, 1 )
        if ref_line is not None:
            ax.axhline( y=ref_line, color="black", linestyle="dashed" )
        ax.plot( stats_df[ "1_d" ], stats_df[ col ], color=color )
        ax.tick_params( axis="y", labelcolor=color )
    cc_fig.tight_layout()
    cc_fig.savefig( "{0}/plots.png".format( file_name ) )
    # release the figure so it does not stay open in pyplot's global state
    plt.close( cc_fig )
def cc_cut_fig( analysis_df, cc_push_res, ccstar_push_res, plot_pwd ):
    """
    Plot the CC and CC* resolution cut-offs against push-res and save as ``<plot_pwd>/push-res.png``.

    Parameters
    ----------
    analysis_df : pandas.DataFrame
        indexed by push-res value with columns cc_cut and ccstar_cut.
    cc_push_res, ccstar_push_res : float
        push-res values marked with dashed vertical lines (best CC and best CC*).
    plot_pwd : str
        directory the figure is written into.
    """
    # own figure, so lines are never drawn onto whichever pyplot figure happens to be current
    fig, ax = plt.subplots()
    fig.suptitle( "CC and CC* against push-res" )
    ax.plot( analysis_df.index, analysis_df.cc_cut, label="cc", color="tab:red" )
    ax.plot( analysis_df.index, analysis_df.ccstar_cut, label="cc*", color="tab:blue" )
    ax.set_xlabel( "push_res" )
    ax.set_ylabel( "resolution" )
    ax.axvline( x=cc_push_res, color="tab:red", linestyle="dashed" )
    ax.axvline( x=ccstar_push_res, color="tab:blue", linestyle="dashed" )
    ax.legend()
    plot_file = "{0}/push-res.png".format( plot_pwd )
    fig.savefig( plot_file )
    plt.close( fig )
def push_res_analysis( cwd, name ):
    """
    Compare the push-res scan runs and choose the push-res value to use.

    For every push-res folder, get_metric estimates two resolutions: where CC
    drops to 0.3 and where CC* drops to 0.7. The table is saved as
    ``<cwd>/<name>/push-res-summary-table.csv`` and plotted to
    ``<cwd>/<name>/push-res.png``.

    If the lowest CC cut-off and the lowest CC* cut-off occur at the same
    push-res, that value is used. Otherwise, for each candidate, the gap
    between its best cut-off and the other statistic's cut-off at the same
    push-res is computed, and the candidate with the smaller gap is used.
    Ties and NaN gaps fall back to the CC choice.

    Parameters
    ----------
    cwd : str
        top working directory.
    name : str
        name of the scan folder below ``cwd``.

    Returns
    -------
    tuple
        (push_res, high_res): chosen push-res value and its resolution cut-off.
    """
    # same push-res values as the scan, rounded to 1 dp to match the folder names
    push_res_list = np.around( np.arange( 0.0, 3.2, 0.2 ), 1 )
    analysis_df = pd.DataFrame( columns=[ "cc_cut", "ccstar_cut" ], index=push_res_list, dtype=float )
    for push_res_val in push_res_list:
        run_dir = "{0}/{1}/push-res_{2}".format( cwd, name, push_res_val )
        stats_df = summary_stats( "{0}/cc.dat".format( run_dir ),
                                  "{0}/ccstar.dat".format( run_dir ),
                                  "{0}/mult.dat".format( run_dir ),
                                  "{0}/Rsplit.dat".format( run_dir ) )
        analysis_df.at[ push_res_val, "cc_cut" ] = float( get_metric( stats_df[ "1_d2" ], stats_df.cc, 0.3 ) )
        analysis_df.at[ push_res_val, "ccstar_cut" ] = float( get_metric( stats_df[ "1_d2" ], stats_df.ccstar, 0.7 ) )
    # saved under <cwd>/<name>, next to the push-res plot
    analysis_file_name = "{0}/{1}/push-res-summary-table.csv".format( cwd, name )
    analysis_df.to_csv( analysis_file_name, sep="\t" )
    # smallest cut-off = highest resolution for each statistic
    cc_high = analysis_df.cc_cut.min()
    ccstar_high = analysis_df.ccstar_cut.min()
    # Series.idxmin returns the push-res label directly (no positional [0] lookup)
    cc_push_res = analysis_df[ "cc_cut" ].idxmin()
    ccstar_push_res = analysis_df[ "ccstar_cut" ].idxmin()
    plot_pwd = "{0}/{1}".format( cwd, name )
    cc_cut_fig( analysis_df, cc_push_res, ccstar_push_res, plot_pwd )
    # how far the other statistic is from each candidate's best cut-off
    cc_at_ccstar_cut = analysis_df.at[ ccstar_push_res, "cc_cut" ]
    ccstar_at_cc_cut = analysis_df.at[ cc_push_res, "ccstar_cut" ]
    cc_high_diff = abs( cc_high - ccstar_at_cc_cut )
    ccstar_high_diff = abs( ccstar_high - cc_at_ccstar_cut )
    if cc_push_res == ccstar_push_res:
        push_res, high_res = cc_push_res, cc_high
    elif cc_high_diff > ccstar_high_diff:
        push_res, high_res = ccstar_push_res, ccstar_high
    else:
        # covers ccstar_high_diff > cc_high_diff, exact ties and NaN gaps,
        # all of which previously left push_res unassigned
        push_res, high_res = cc_push_res, cc_high
    print( "use push-res = {0}".format( push_res ) )
    return push_res, high_res
def main( cwd, name, stream, pointgroup, model, iterations, cell, bins, part_h_res ):
    """
    Run the whole workflow.

    The steps are: push-res scan, choice of push-res, final partialator run at
    the chosen cut-off, statistics and plots.

    Parameters
    ----------
    cwd : str
        top working directory; all output goes below ``<cwd>/<name>``.
    name : str
        name of the scan folder.
    stream, pointgroup, model, iterations, cell, bins :
        passed through to run_partialator.
    part_h_res : float
        high-resolution limit for the statistics of the push-res scan. The
        final run uses the cut-off returned by push_res_analysis instead.
    """
    # run push-res scan
    print( "begin push-res scan" )
    run_push_res( cwd, name, stream, pointgroup, model, iterations, cell, bins, part_h_res )
    # run analysis
    push_res, high_res = push_res_analysis( cwd, name )
    print( "done" )
    print( "best push-res = {0} and a resolution of {1}".format( push_res, high_res ) )
    final_dir = "{0}/{1}/final".format( cwd, name )
    print( "rerunning partialator with {0} A cut off".format( high_res ) )
    # check for an earlier final run BEFORE creating the folder; creating it
    # first made this check always false, so the final job was never submitted
    if not os.path.exists( final_dir ):
        make_process_dir( final_dir )
        # submit from inside the final folder so the job's relative outputs land there
        os.chdir( final_dir )
        try:
            final_file = run_partialator( final_dir, stream, pointgroup, model, iterations, push_res, cell, bins, part_h_res=high_res, flag="final" )
            job_id = submit_job( final_file )
        finally:
            os.chdir( cwd )
        print( "final job submitted: {0}".format( job_id ) )
        # use progress bar to track job completion
        time.sleep(30)
        wait_for_jobs( { job_id }, 1 )
        print("slurm processing done")
        print( "done" )
    # stats files of the final run
    cc_dat = "{0}/cc.dat".format( final_dir )
    ccstar_dat = "{0}/ccstar.dat".format( final_dir )
    mult_dat = "{0}/mult.dat".format( final_dir )
    rsplit_dat = "{0}/Rsplit.dat".format( final_dir )
    # make summary data table
    print( "calculating statistics" )
    stats_df = summary_stats( cc_dat, ccstar_dat, mult_dat, rsplit_dat )
    print( stats_df.to_string() )
    print_df = stats_df[ [ "1_d", "d", "min",
                           "max", "nref", "poss",
                           "comp", "obs", "mult",
                           "snr", "I", "rsplit", "cc", "ccstar"] ]
    print_df.to_csv( "{0}/summary_table.csv".format( final_dir ), sep="\t", index=False )
    # calculate cc metrics
    cc_cut = get_metric( stats_df[ "1_d2" ], stats_df.cc, 0.3 )
    ccstar_cut = get_metric( stats_df[ "1_d2" ], stats_df.ccstar, 0.7 )
    print( "resolution at CC0.5 at 0.3 = {0}".format( cc_cut ) )
    print( "resolution at CC* at 0.7 = {0}".format( ccstar_cut ) )
    # save plots into the final folder
    summary_fig( stats_df, final_dir )
    print( "done" )
def build_arg_parser():
    """
    Build the command-line parser for this script.

    Returns
    -------
    argparse.ArgumentParser
        parser whose arguments map onto the parameters of main().
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-n",
        "--name",
        help="name of push-scan folder, also name of folder where data will be processed. Default = push-res",
        type=str,
        default="push-res"
    )
    parser.add_argument(
        "-s",
        "--stream_file",
        help="path to stream file",
        type=os.path.abspath,
        required=True
    )
    parser.add_argument(
        "-p",
        "--pointgroup",
        help="pointgroup used by CrystFEL for partialator run",
        type=str,
        required=True
    )
    parser.add_argument(
        "-m",
        "--model",
        help="model used partialator, e.g., unity or xsphere, unity = default",
        type=str,
        default="unity"
    )
    parser.add_argument(
        "-i",
        "--iterations",
        help="number of iterations used for partialator run. Default = 1",
        type=int,
        default=1
    )
    parser.add_argument(
        "-c",
        "--cell_file",
        help="path to CrystFEL cell file for partialator",
        type=os.path.abspath,
        # required: the path is written into every check_hkl / compare_hkl call
        # as "-p <cell>"; without it the job scripts received "-p None"
        required=True
    )
    parser.add_argument(
        "-b",
        "--bins",
        help="number of resolution bins to use. Should be more than 20. Default = 20",
        type=int,
        default=20
    )
    parser.add_argument(
        "-r",
        "--resolution",
        help="high res limit - need something here. Default set to 1.3",
        type=float,
        default=1.3
    )
    return parser

if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    # run main from the directory the script was launched in
    cwd = os.getcwd()
    print( "top working directory = {0}".format( cwd ) )
    main( cwd, args.name, args.stream_file, args.pointgroup, args.model, args.iterations, args.cell_file, args.bins, args.resolution )