now works with improved stats and does not overwrite

This commit is contained in:
Beale John Henry
2024-02-01 13:29:49 +01:00
parent 9b408bd2e4
commit 9ffee8ff7d

View File

@@ -38,7 +38,7 @@ from scipy.optimize import curve_fit
def submit_job( job_file ): def submit_job( job_file ):
# submit the job # submit the job
submit_cmd = ["sbatch", "--cpus-per-task=32", "--" ,job_file] submit_cmd = ["sbatch", "-p", "day" ,"--cpus-per-task=32", "--" ,job_file]
job_output = subprocess.check_output(submit_cmd) job_output = subprocess.check_output(submit_cmd)
# scrub job id from - example Submitted batch job 742403 # scrub job id from - example Submitted batch job 742403
@@ -59,7 +59,7 @@ def wait_for_jobs( job_ids, total_jobs ):
completed_jobs.add(job_id) completed_jobs.add(job_id)
pbar.update(1) pbar.update(1)
job_ids.difference_update(completed_jobs) job_ids.difference_update(completed_jobs)
time.sleep(2) time.sleep(5)
def run_partialator( proc_dir, stream, pointgroup, model, iterations, push_res, cell, bins, part_h_res, flag ): def run_partialator( proc_dir, stream, pointgroup, model, iterations, push_res, cell, bins, part_h_res, flag ):
@@ -116,6 +116,10 @@ def run_push_res( cwd, name, stream, pointgroup, model, iterations, cell, bins,
for res in push_res_list: for res in push_res_list:
push_res_dir = "{0}/{1}/push-res_{2}".format( cwd, name, res ) push_res_dir = "{0}/{1}/push-res_{2}".format( cwd, name, res )
# check to see if directories already made
if not os.path.exists( push_res_dir ):
# make push-res directories # make push-res directories
make_process_dir( push_res_dir ) make_process_dir( push_res_dir )
@@ -133,13 +137,12 @@ def run_push_res( cwd, name, stream, pointgroup, model, iterations, cell, bins,
# move back to top dir # move back to top dir
os.chdir( cwd ) os.chdir( cwd )
print( "DONE" )
# use progress bar to track job completion # use progress bar to track job completion
time.sleep(30)
wait_for_jobs(submitted_job_ids, len(push_res_list) ) wait_for_jobs(submitted_job_ids, len(push_res_list) )
print("slurm processing done") print("slurm processing done")
def summary_stats( cc_dat, ccstar_dat, mult_dat ): def summary_stats( cc_dat, ccstar_dat, mult_dat, rsplit_dat ):
# read all files into pd # read all files into pd
# function to sort out different column names # function to sort out different column names
@@ -153,21 +156,27 @@ def summary_stats( cc_dat, ccstar_dat, mult_dat ):
elif var == "mult": elif var == "mult":
cols = [ "d(nm)", "nref", "poss", "comp", "obs", cols = [ "d(nm)", "nref", "poss", "comp", "obs",
"mult", "snr", "I", "d", "min", "max" ] "mult", "snr", "I", "d", "min", "max" ]
elif var == "rsplit":
cols = [ "d(nm)", "rsplit", "nref", "d", "min", "max" ]
df = pd.read_csv( dat, names=cols, skiprows=1, sep="\s+" ) df = pd.read_csv( dat, names=cols, skiprows=1, sep="\s+" )
return df return df
# make df # make df
cc_df = read_dat( cc_dat, "cc" ) cc_df = read_dat( cc_dat, "cc" )
ccstar_df = read_dat( ccstar_dat, "ccstar" ) ccstar_df = read_dat( ccstar_dat, "ccstar" )
mult_df = read_dat( mult_dat, "mult" ) mult_df = read_dat( mult_dat, "mult" )
rsplit_df = read_dat( rsplit_dat, "rsplit" )
# remove unwanted cols # remove unwanted cols
cc_df = cc_df[ [ "cc" ] ] cc_df = cc_df[ [ "cc" ] ]
ccstar_df = ccstar_df[ [ "ccstar" ] ] ccstar_df = ccstar_df[ [ "ccstar" ] ]
rsplit_df = rsplit_df[ [ "rsplit" ] ]
# merge dfs # merge dfs
stats_df = pd.concat( [ mult_df, cc_df, ccstar_df], axis=1, join="inner" ) stats_df = pd.concat( [ mult_df, cc_df, ccstar_df, rsplit_df ], axis=1, join="inner" )
# make 1/d, 1/d^2 column # make 1/d, 1/d^2 column
stats_df[ "1_d" ] = 1 / stats_df.d stats_df[ "1_d" ] = 1 / stats_df.d
@@ -177,7 +186,7 @@ def summary_stats( cc_dat, ccstar_dat, mult_dat ):
stats_df = stats_df[ [ "1_d", "1_d2", "d", "min", stats_df = stats_df[ [ "1_d", "1_d2", "d", "min",
"max", "nref", "poss", "max", "nref", "poss",
"comp", "obs", "mult", "comp", "obs", "mult",
"snr", "I", "cc", "ccstar"] ] "snr", "I", "cc", "ccstar", "rsplit" ] ]
# change nan to 0 # change nan to 0
stats_df = stats_df.fillna(0) stats_df = stats_df.fillna(0)
@@ -208,30 +217,48 @@ def get_metric( d2_series, cc_series, cut_off ):
return cc_stat return cc_stat
def summary_cc_fig( stats_df ): def summary_fig( stats_df, file_name ):
# plot results # plot results
cc_fig, (ax1, ax2) = plt.subplots(1, 2) cc_fig, axs = plt.subplots(2, 2)
cc_fig.suptitle( "cc and cc* vs resolution" ) cc_fig.suptitle( "cc and cc* vs resolution" )
# CC # cc plot
color = "tab:red" color = "tab:red"
ax1.set_xlabel( "1/d2 (1/A)" ) axs[0,0].set_xlabel( "1/d (1/A)" )
ax1.set_ylabel("CC" ) axs[0,0].set_ylabel("CC" )
ax1.axhline(y = 0.3, color="black", linestyle = "dashed") axs[0,0].set_ylim( 0, 1 )
ax1.plot(stats_df[ "1_d" ], stats_df.cc, color=color) axs[0,0].axhline(y = 0.3, color="black", linestyle = "dashed")
ax1.tick_params(axis="y", labelcolor=color) axs[0,0].plot(stats_df[ "1_d" ], stats_df.cc, color=color)
axs[0,0].tick_params(axis="y", labelcolor=color)
# CC* # cc* plot
color = "tab:blue" color = "tab:blue"
ax2.set_xlabel( "1/d (1/A)" ) axs[0,1].set_xlabel( "1/d (1/A)" )
ax2.set_ylabel("CC*", color=color) axs[0,1].set_ylabel("CC*", color=color)
ax2.axhline(y = 0.7, color="black", linestyle = "dashed") axs[0,1].set_ylim( 0, 1 )
ax2.plot(stats_df[ "1_d" ], stats_df.ccstar, color=color) axs[0,1].axhline(y = 0.7, color="black", linestyle = "dashed")
ax2.tick_params(axis='y', labelcolor=color) axs[0,1].plot(stats_df[ "1_d" ], stats_df.ccstar, color=color)
axs[0,1].tick_params(axis='y', labelcolor=color)
plot_file = "final_cc.png" # rsplit plot
plt.savefig( plot_file ) color = "tab:green"
axs[1,0].set_xlabel( "1/d (1/A)" )
axs[1,0].set_ylabel("Rsplit", color=color)
axs[1,0].plot(stats_df[ "1_d" ], stats_df.rsplit, color=color)
axs[1,0].tick_params(axis='y', labelcolor=color)
# rsplit plot
color = "tab:purple"
axs[1,1].set_xlabel( "1/d (1/A)" )
axs[1,1].set_ylabel("Multiplicity", color=color)
axs[1,1].plot(stats_df[ "1_d" ], stats_df.mult, color=color)
axs[1,1].tick_params(axis='y', labelcolor=color)
# save figure
plt.tight_layout()
plt.savefig( "{0}/plots.png".format( file_name ) )
def cc_cut_fig( analysis_df, cc_push_res, ccstar_push_res, plot_pwd ): def cc_cut_fig( analysis_df, cc_push_res, ccstar_push_res, plot_pwd ):
@@ -267,9 +294,10 @@ def push_res_analysis( cwd, name ):
cc_dat = "{0}/cc.dat".format( dir ) cc_dat = "{0}/cc.dat".format( dir )
ccstar_dat = "{0}/ccstar.dat".format( dir ) ccstar_dat = "{0}/ccstar.dat".format( dir )
mult_dat = "{0}/mult.dat".format( dir ) mult_dat = "{0}/mult.dat".format( dir )
rsplit_dat = "{0}/Rsplit.dat".format( dir )
# make summary data table # make summary data table
stats_df = summary_stats( cc_dat, ccstar_dat, mult_dat ) stats_df = summary_stats( cc_dat, ccstar_dat, mult_dat, rsplit_dat )
# calculate cc metrics # calculate cc metrics
cc_cut = get_metric( stats_df[ "1_d2" ], stats_df.cc, 0.3 ) cc_cut = get_metric( stats_df[ "1_d2" ], stats_df.cc, 0.3 )
@@ -314,15 +342,22 @@ def push_res_analysis( cwd, name ):
def main( cwd, name, stream, pointgroup, model, iterations, cell, bins, part_h_res ): def main( cwd, name, stream, pointgroup, model, iterations, cell, bins, part_h_res ):
# run push-res scan # run push-res scan
print( "begin push-res scan" )
run_push_res( cwd, name, stream, pointgroup, model, iterations, cell, bins, part_h_res ) run_push_res( cwd, name, stream, pointgroup, model, iterations, cell, bins, part_h_res )
# run analysis # run analysis
push_res, high_res = push_res_analysis( cwd, name ) push_res, high_res = push_res_analysis( cwd, name )
print( "done" )
print( "best push-res = {0} and a resolution of {1}".format( push_res, high_res ) )
# make final run file # make final run file
final_dir = "{0}/{1}/final".format( cwd, name ) final_dir = "{0}/{1}/final".format( cwd, name )
make_process_dir( final_dir ) make_process_dir( final_dir )
print( "rerunning partialator with {0} A cut off".format( high_res ) )
# check to see if directories already made
if not os.path.exists( final_dir ):
# move to push-res directory # move to push-res directory
os.chdir( final_dir ) os.chdir( final_dir )
@@ -336,22 +371,26 @@ def main( cwd, name, stream, pointgroup, model, iterations, cell, bins, part_h_r
submitted_job_ids.add( job_id ) submitted_job_ids.add( job_id )
# use progress bar to track job completion # use progress bar to track job completion
time.sleep(30)
wait_for_jobs(submitted_job_ids, 1 ) wait_for_jobs(submitted_job_ids, 1 )
print("slurm processing done") print("slurm processing done")
print( "done" )
# stats files names # stats files names
cc_dat = "cc.dat".format( dir ) cc_dat = "{0}/cc.dat".format( final_dir )
ccstar_dat = "ccstar.dat".format( dir ) ccstar_dat = "{0}/ccstar.dat".format( final_dir )
mult_dat = "mult.dat".format( dir ) mult_dat = "{0}/mult.dat".format( final_dir )
rsplit_dat = "{0}/Rsplit.dat".format( final_dir )
# make summary data table # make summary data table
stats_df = summary_stats( cc_dat, ccstar_dat, mult_dat ) print( "calculating statistics" )
stats_df = summary_stats( cc_dat, ccstar_dat, mult_dat, rsplit_dat )
print( stats_df.to_string() ) print( stats_df.to_string() )
print_df = stats_df[ [ "1_d", "d", "min", print_df = stats_df[ [ "1_d", "d", "min",
"max", "nref", "poss", "max", "nref", "poss",
"comp", "obs", "mult", "comp", "obs", "mult",
"snr", "I", "cc", "ccstar"] ] "snr", "I", "rsplit", "cc", "ccstar"] ]
print_df.to_csv( "summary_table.csv", sep="\t", index=False ) print_df.to_csv( "{0}/summary_table.csv".format( final_dir ), sep="\t", index=False )
# calculate cc metrics # calculate cc metrics
cc_cut = get_metric( stats_df[ "1_d2" ], stats_df.cc, 0.3 ) cc_cut = get_metric( stats_df[ "1_d2" ], stats_df.cc, 0.3 )
@@ -360,10 +399,11 @@ def main( cwd, name, stream, pointgroup, model, iterations, cell, bins, part_h_r
print( "resolution at CC* at 0.7 = {0}".format( ccstar_cut ) ) print( "resolution at CC* at 0.7 = {0}".format( ccstar_cut ) )
# show plots # show plots
summary_cc_fig( stats_df ) summary_fig( stats_df, final_dir )
# move back to top dir # move back to top dir
os.chdir( cwd ) os.chdir( cwd )
print( "done" )
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@@ -378,13 +418,15 @@ if __name__ == "__main__":
"-s", "-s",
"--stream_file", "--stream_file",
help="path to stream file", help="path to stream file",
type=os.path.abspath type=os.path.abspath,
required=True
) )
parser.add_argument( parser.add_argument(
"-p", "-p",
"--pointgroup", "--pointgroup",
help="pointgroup used by CrystFEL for partialator run", help="pointgroup used by CrystFEL for partialator run",
type=str type=str,
required=True
) )
parser.add_argument( parser.add_argument(
"-m", "-m",