#!/usr/bin/env python3
# author J.Beale
"""
# aim
to merge .stream files with partialator and calculate merging statistics
# usage
python partialator.py -s <path-to-stream-file>
-n name (name of job - required)
-p pointgroup
-m model (unity or xsphere - default is unity)
-i iterations - number of iterations in partialator (default = 1)
-c <path-to-cell-file>
-b number of resolution bins - should be at least 20 (default = 20)
-r high-res limit in Angstrom (default = 1.3)
-a max-adu - maximum detector counts to allow (default = 12000)
# output
- scaled/merged .hkl files
- useful plots
- useful summarized .dat files
"""
# modules
import pandas as pd
import numpy as np
import subprocess
import os
import time
import argparse
from tqdm import tqdm
import re
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
def submit_job( job_file ):
    # submit the job
    submit_cmd = [ "sbatch", "--cpus-per-task=32", "--", job_file ]
    job_output = subprocess.check_output( submit_cmd )
    # scrub job id from output - example: Submitted batch job 742403
    pattern = r"Submitted batch job (\d+)"
    match = re.search( pattern, job_output.decode().strip() )
    if match is None:
        raise RuntimeError( "could not parse sbatch output: {0}".format( job_output.decode() ) )
    return int( match.group(1) )
def wait_for_jobs( job_ids, total_jobs ):
    with tqdm( total=total_jobs, desc="Jobs Completed", unit="job" ) as pbar:
        while job_ids:
            completed_jobs = set()
            for job_id in job_ids:
                status_cmd = [ "squeue", "-h", "-j", str( job_id ) ]
                status = subprocess.check_output( status_cmd )
                # empty squeue output means the job has left the queue
                if not status:
                    completed_jobs.add( job_id )
                    pbar.update(1)
            job_ids.difference_update( completed_jobs )
            time.sleep(2)
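# wait_for_jobs() only detects that a job has left the queue, not whether it
# succeeded. a minimal sketch of a final-state check, assuming the cluster has
# SLURM accounting (sacct) enabled - illustrative only, not used below:
def job_state( job_id ):
    # query the accounting database for the job's final state, e.g. COMPLETED or FAILED
    state_cmd = [ "sacct", "-n", "-X", "-j", str( job_id ), "-o", "State" ]
    return subprocess.check_output( state_cmd ).decode().strip()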
def run_partialator( proc_dir, name, stream, pointgroup, model, iterations, cell, shells, part_h_res, adu ):
    # partialator file name
    part_run_file = "{0}/partialator_{1}.sh".format( proc_dir, name )
    # write file
    with open( part_run_file, "w" ) as part_sh:
        part_sh.write( "#!/bin/bash\n\n" )  # bash, not sh: the &> redirections below are bash syntax
        part_sh.write( "module purge\n" )
        part_sh.write( "module use MX unstable\n" )
        part_sh.write( "module load crystfel/0.10.2-rhel8\n" )
        part_sh.write( "partialator -i {0} \\\n".format( stream ) )
        part_sh.write( "            -o merged_{0}.hkl \\\n".format( name ) )
        part_sh.write( "            -y {0} \\\n".format( pointgroup ) )
        part_sh.write( "            --model={0} \\\n".format( model ) )
        part_sh.write( "            --max-adu={0} \\\n".format( adu ) )
        part_sh.write( "            --iterations={0}\n\n".format( iterations ) )
        part_sh.write( "check_hkl --shell-file=mult.dat *.hkl -p {0} --nshells={1} --highres={2} &> check_hkl.log\n".format( cell, shells, part_h_res ) )
        part_sh.write( "check_hkl --ltest --ignore-negs --shell-file=ltest.dat *.hkl -p {0} --nshells={1} --highres={2} &> ltest.log\n".format( cell, shells, part_h_res ) )
        part_sh.write( "check_hkl --wilson --shell-file=wilson.dat *.hkl -p {0} --nshells={1} --highres={2} &> wilson.log\n".format( cell, shells, part_h_res ) )
        part_sh.write( "compare_hkl --fom=Rsplit --shell-file=Rsplit.dat *.hkl1 *.hkl2 -p {0} --nshells={1} --highres={2} &> Rsplit.log\n".format( cell, shells, part_h_res ) )
        part_sh.write( "compare_hkl --fom=cc --shell-file=cc.dat *.hkl1 *.hkl2 -p {0} --nshells={1} --highres={2} &> cc.log\n".format( cell, shells, part_h_res ) )
        part_sh.write( "compare_hkl --fom=ccstar --shell-file=ccstar.dat *.hkl1 *.hkl2 -p {0} --nshells={1} --highres={2} &> ccstar.log\n".format( cell, shells, part_h_res ) )
    # make file executable
    subprocess.call( [ "chmod", "+x", part_run_file ] )
    # return partialator file name
    return part_run_file
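# note: partialator writes merged_<name>.hkl plus the two half-set files
# merged_<name>.hkl1 and merged_<name>.hkl2; the compare_hkl lines above use
# the half-sets to compute Rsplit, CC1/2 and CC*.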
def make_process_dir( dir_path ):
    # make process directory - do nothing if it already exists
    os.makedirs( dir_path, exist_ok=True )
def summary_stats( cc_dat, ccstar_dat, mult_dat, rsplit_dat ):
    # read all files into pd
    # function to sort out different column names
    def read_dat( dat, var ):
        # different column names for each dat file
        if var == "cc":
            cols = [ "d(nm)", "cc", "nref", "d", "min", "max" ]
        elif var == "ccstar":
            cols = [ "d(nm)", "ccstar", "nref", "d", "min", "max" ]
        elif var == "mult":
            cols = [ "d(nm)", "nref", "poss", "comp", "obs",
                     "mult", "snr", "I", "d", "min", "max" ]
        elif var == "rsplit":
            cols = [ "d(nm)", "rsplit", "nref", "d", "min", "max" ]
        # skiprows=1 drops the single header line of each shell file
        df = pd.read_csv( dat, names=cols, skiprows=1, sep=r"\s+" )
        return df
    # make df
    cc_df = read_dat( cc_dat, "cc" )
    ccstar_df = read_dat( ccstar_dat, "ccstar" )
    mult_df = read_dat( mult_dat, "mult" )
    rsplit_df = read_dat( rsplit_dat, "rsplit" )
    # keep only the statistic column from each fom table
    cc_df = cc_df[ [ "cc" ] ]
    ccstar_df = ccstar_df[ [ "ccstar" ] ]
    rsplit_df = rsplit_df[ [ "rsplit" ] ]
    # merge dfs
    stats_df = pd.concat( [ mult_df, cc_df, ccstar_df, rsplit_df ], axis=1, join="inner" )
    # make 1/d, 1/d^2 columns
    stats_df[ "1_d" ] = 1 / stats_df.d
    stats_df[ "1_d2" ] = 1 / stats_df.d**2
    # reorder cols
    stats_df = stats_df[ [ "1_d", "1_d2", "d", "min",
                           "max", "nref", "poss",
                           "comp", "obs", "mult",
                           "snr", "I", "cc", "ccstar", "rsplit" ] ]
    # change nan to 0
    stats_df = stats_df.fillna(0)
    return stats_df
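# stats_df holds one row per resolution shell; the 1_d2 column is what
# get_metric() below fits the tanh curve against.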
def get_metric( d2_series, cc_series, cut_off ):
    # define the tanh model from scitbx
    def tanh( x, r, s0 ):
        z = ( x - s0 ) / r
        return 0.5 * ( 1 - np.tanh(z) )
    def arctanh( y, r, s0 ):
        return r * np.arctanh( 1 - 2*y ) + s0
    # fit the tanh to the data
    params, covariance = curve_fit( tanh, d2_series, cc_series )
    # extract the fitted parameters
    r, s0 = params
    # calculate cut-off point
    cc_stat = arctanh( cut_off, r, s0 )
    # convert back from 1/d^2 to d
    cc_stat = np.sqrt( 1 / cc_stat )
    return cc_stat
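# the model fitted above is CC(s) = 0.5 * (1 - tanh((s - s0) / r)) with
# s = 1/d^2; arctanh() inverts it at the requested cut-off to give s_cut,
# and the reported resolution is d_cut = sqrt(1 / s_cut) in Angstrom.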
def summary_fig( stats_df ):
    # plot results
    cc_fig, axs = plt.subplots(2, 2)
    cc_fig.suptitle( "cc and cc* vs resolution" )
    # cc plot
    color = "tab:red"
    axs[0,0].set_xlabel( "1/d (1/A)" )
    axs[0,0].set_ylabel( "CC", color=color )
    axs[0,0].set_ylim( 0, 1 )
    axs[0,0].axhline( y=0.3, color="black", linestyle="dashed" )
    axs[0,0].plot( stats_df[ "1_d" ], stats_df.cc, color=color )
    axs[0,0].tick_params( axis="y", labelcolor=color )
    # cc* plot
    color = "tab:blue"
    axs[0,1].set_xlabel( "1/d (1/A)" )
    axs[0,1].set_ylabel( "CC*", color=color )
    axs[0,1].set_ylim( 0, 1 )
    axs[0,1].axhline( y=0.7, color="black", linestyle="dashed" )
    axs[0,1].plot( stats_df[ "1_d" ], stats_df.ccstar, color=color )
    axs[0,1].tick_params( axis="y", labelcolor=color )
    # rsplit plot
    color = "tab:green"
    axs[1,0].set_xlabel( "1/d (1/A)" )
    axs[1,0].set_ylabel( "Rsplit", color=color )
    axs[1,0].plot( stats_df[ "1_d" ], stats_df.rsplit, color=color )
    axs[1,0].tick_params( axis="y", labelcolor=color )
    # multiplicity plot
    color = "tab:purple"
    axs[1,1].set_xlabel( "1/d (1/A)" )
    axs[1,1].set_ylabel( "Multiplicity", color=color )
    axs[1,1].plot( stats_df[ "1_d" ], stats_df.mult, color=color )
    axs[1,1].tick_params( axis="y", labelcolor=color )
    # save figure
    plt.tight_layout()
    plt.savefig( "plots.png" )
def get_mean_cell( stream ):
    # get unit-cell values from stream file contents
    # example line - Cell parameters 7.71784 7.78870 3.75250 nm, 90.19135 90.77553 90.19243 deg
    # scrub the cell parameters and return the means - else nan
    pattern = r"Cell\sparameters\s(\d+\.\d+)\s(\d+\.\d+)\s(\d+\.\d+)\snm,\s(\d+\.\d+)\s(\d+\.\d+)\s(\d+\.\d+)\sdeg"
    cell_lst = re.findall( pattern, stream )
    # re.findall returns an empty list (not an exception) when nothing matches
    if not cell_lst:
        return np.nan
    # number of crystals contributing to the average
    xtals = len( cell_lst )
    cols = [ "a", "b", "c", "alpha", "beta", "gamma" ]
    # findall yields strings, so cast to float before averaging
    cell_df = pd.DataFrame( cell_lst, columns=cols ).astype( float )
    # axis lengths converted from nm to Angstrom
    mean_a = round( cell_df.a.mean()*10, 3 )
    mean_b = round( cell_df.b.mean()*10, 3 )
    mean_c = round( cell_df.c.mean()*10, 3 )
    mean_alpha = round( cell_df.alpha.mean(), 3 )
    mean_beta = round( cell_df.beta.mean(), 3 )
    mean_gamma = round( cell_df.gamma.mean(), 3 )
    return mean_a, mean_b, mean_c, mean_alpha, mean_beta, mean_gamma
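# get_mean_cell() expects the *contents* of a stream file, not its path, and
# is not called in main() below. a hypothetical use:
#     with open( args.stream_file ) as fh:
#         print( get_mean_cell( fh.read() ) )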
def main( cwd, name, stream, pointgroup, model, iterations, cell, shells, part_h_res, adu ):
    print( "begin job" )
    # submitted job set
    submitted_job_ids = set()
    part_dir = "{0}/{1}".format( cwd, name )
    # make part directories
    make_process_dir( part_dir )
    # move to part directory
    os.chdir( part_dir )
    print( "making partialator files" )
    # make partialator run file
    part_run_file = run_partialator( part_dir, name, stream, pointgroup, model, iterations, cell, shells, part_h_res, adu )
    # submit job
    job_id = submit_job( part_run_file )
    print( "job submitted: {0}".format( job_id ) )
    submitted_job_ids.add( job_id )
    print( "DONE" )
    # use progress bar to track job completion
    wait_for_jobs( submitted_job_ids, 1 )
    print( "slurm processing done" )
    # stats file names
    cc_dat = "cc.dat"
    ccstar_dat = "ccstar.dat"
    mult_dat = "mult.dat"
    rsplit_dat = "Rsplit.dat"
    # make summary data table
    stats_df = summary_stats( cc_dat, ccstar_dat, mult_dat, rsplit_dat )
    print( stats_df.to_string() )
    print_df = stats_df[ [ "1_d", "d", "min",
                           "max", "nref", "poss",
                           "comp", "obs", "mult",
                           "snr", "I", "cc", "ccstar" ] ]
    print_df.to_csv( "summary_table.csv", sep="\t", index=False )
    # calculate cc metrics
    cc_cut = get_metric( stats_df[ "1_d2" ], stats_df.cc, 0.3 )
    ccstar_cut = get_metric( stats_df[ "1_d2" ], stats_df.ccstar, 0.7 )
    print( "resolution where CC1/2 falls to 0.3 = {0}".format( cc_cut ) )
    print( "resolution where CC* falls to 0.7 = {0}".format( ccstar_cut ) )
    # make and save plots
    summary_fig( stats_df )
    # move back to top dir
    os.chdir( cwd )
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-n",
        "--name",
        help="name of partialator run, also name of folder where data will be processed",
        type=str,
        required=True
    )
    parser.add_argument(
        "-s",
        "--stream_file",
        help="path to stream file",
        type=os.path.abspath,
        required=True
    )
    parser.add_argument(
        "-p",
        "--pointgroup",
        help="pointgroup used by CrystFEL for partialator run",
        type=str,
        required=True
    )
    parser.add_argument(
        "-m",
        "--model",
        help="model used by partialator, e.g. unity or xsphere. Default = unity",
        type=str,
        default="unity"
    )
    parser.add_argument(
        "-i",
        "--iterations",
        help="number of iterations used for partialator run. Default = 1",
        type=int,
        default=1
    )
    parser.add_argument(
        "-c",
        "--cell_file",
        help="path to CrystFEL cell file for partialator",
        type=os.path.abspath,
        required=True
    )
    parser.add_argument(
        "-b",
        "--bins",
        help="number of resolution bins to use. Should be at least 20. Default = 20",
        type=int,
        default=20
    )
    parser.add_argument(
        "-r",
        "--resolution",
        help="high resolution limit in Angstrom. Default = 1.3",
        type=float,
        default=1.3
    )
    parser.add_argument(
        "-a",
        "--max_adu",
        help="maximum detector counts to allow. Default = 12000",
        type=int,
        default=12000
    )
    args = parser.parse_args()
    # run main
    cwd = os.getcwd()
    print( "top working directory = {0}".format( cwd ) )
    main( cwd, args.name, args.stream_file, args.pointgroup, args.model, args.iterations, args.cell_file, args.bins, args.resolution, args.max_adu )