
#!/usr/bin/env python3
# author J.Beale
"""
# aim
to merge .stream files and calculate statistics
# usage
python partialator.py -s <path-to-stream-file>
-n name of job (required)
-p pointgroup (required)
-m model (unity or xsphere - default = unity)
-i iterations - number of iterations in partialator (default = 1)
-c <path-to-cell-file> (required)
-b number of resolution bins (at least 20; default = 20)
-r high-resolution limit in A (default = 1.3)
-a max-adu (default = 12000)
-R ra reservation name, if available
# output
- scaled/merged files
- an mtz file
- useful plots
- useful summarized .dat files
- log file of output
"""
# modules
from sys import exit
import pandas as pd
import numpy as np
import subprocess
import os, errno
import time
import argparse
from tqdm import tqdm
import regex as re
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import warnings
warnings.filterwarnings( "ignore", category=RuntimeWarning )
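# note: the RuntimeWarning filter above is presumably for the benign warnings
# that np.arctanh and the 1/d conversions emit on empty or out-of-range shells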
from loguru import logger
def submit_job( job_file, reservation ):
    # submit the job
    if reservation:
        print( "using a ra beamtime reservation = {0}".format( reservation ) )
        logger.info( "using ra reservation to process data = {0}".format( reservation ) )
        submit_cmd = [ "sbatch", "--reservation={0}".format( reservation ), "--cpus-per-task=32", "--", job_file ]
    else:
        submit_cmd = [ "sbatch", "--cpus-per-task=32", "--", job_file ]
    logger.info( "using slurm command = {0}".format( submit_cmd ) )
    try:
        job_output = subprocess.check_output( submit_cmd )
        logger.info( "submitted job = {0}".format( job_output ) )
    except subprocess.CalledProcessError:
        print( "please give the correct ra reservation or remove the -R from the arguments" )
        exit(1)
    # scrub job id from sbatch output - example: Submitted batch job 742403
    pattern = r"Submitted batch job (\d+)"
    job_id = re.search( pattern, job_output.decode().strip() ).group(1)
    return int(job_id)
def wait_for_jobs( job_ids, total_jobs ):
    with tqdm( total=total_jobs, desc="Jobs Completed", unit="job" ) as pbar:
        while job_ids:
            completed_jobs = set()
            for job_id in job_ids:
                # an empty squeue response means the job has left the queue
                status_cmd = [ "squeue", "-h", "-j", str( job_id ) ]
                status = subprocess.check_output( status_cmd )
                if not status:
                    completed_jobs.add( job_id )
                    pbar.update( 1 )
            job_ids.difference_update( completed_jobs )
            time.sleep( 2 )
def run_partialator( proc_dir, name, stream, pointgroup, model, iterations, cell, shells, part_h_res, adu ):
    # partialator file name
    part_run_file = "{0}/partialator_{1}.sh".format( proc_dir, name )
    # write file
    part_sh = open( part_run_file, "w" )
    part_sh.write( "#!/bin/sh\n\n" )
    part_sh.write( "module purge\n" )
    part_sh.write( "module use MX unstable\n" )
    part_sh.write( "module load crystfel/0.10.2-rhel8\n" )
    part_sh.write( "partialator -i {0} \\\n".format( stream ) )
    part_sh.write( " -o merged_{0}.hkl \\\n".format( name ) )
    part_sh.write( " -y {0} \\\n".format( pointgroup ) )
    part_sh.write( " --model={0} \\\n".format( model ) )
    part_sh.write( " --max-adu={0} \\\n".format( adu ) )
    part_sh.write( " -j 32 \\\n" )
    part_sh.write( " --iterations={0}\n\n".format( iterations ) )
    part_sh.write( "check_hkl --shell-file=mult.dat *.hkl -p {0} --nshells={1} --highres={2} &> check_hkl.log\n".format( cell, shells, part_h_res ) )
    part_sh.write( "check_hkl --ltest --ignore-negs --shell-file=ltest.dat *.hkl -p {0} --nshells={1} --highres={2} &> ltest.log\n".format( cell, shells, part_h_res ) )
    part_sh.write( "check_hkl --wilson --shell-file=wilson.dat *.hkl -p {0} --nshells={1} --highres={2} &> wilson.log\n".format( cell, shells, part_h_res ) )
    part_sh.write( "compare_hkl --fom=Rsplit --shell-file=Rsplit.dat *.hkl1 *.hkl2 -p {0} --nshells={1} --highres={2} &> Rsplit.log\n".format( cell, shells, part_h_res ) )
    part_sh.write( "compare_hkl --fom=cc --shell-file=cc.dat *.hkl1 *.hkl2 -p {0} --nshells={1} --highres={2} &> cc.log\n".format( cell, shells, part_h_res ) )
    part_sh.write( "compare_hkl --fom=ccstar --shell-file=ccstar.dat *.hkl1 *.hkl2 -p {0} --nshells={1} --highres={2} &> ccstar.log\n".format( cell, shells, part_h_res ) )
    part_sh.close()
    # make file executable
    subprocess.call( [ "chmod", "+x", "{0}".format( part_run_file ) ] )
    # add partialator script to log
    part_input = open( part_run_file, "r" )
    logger.info( "partialator input file =\n{0}".format( part_input.read() ) )
    part_input.close()
    # return partialator file name
    return part_run_file
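# for reference, a sketch of the generated shell script - the values shown
# assume name="test", a cell file test.cell, and default arguments, and are
# illustrative rather than captured output:
#   #!/bin/sh
#   module purge
#   module use MX unstable
#   module load crystfel/0.10.2-rhel8
#   partialator -i test.stream -o merged_test.hkl -y <pointgroup> --model=unity --max-adu=12000 -j 32 --iterations=1
#   check_hkl --shell-file=mult.dat *.hkl -p test.cell --nshells=20 --highres=1.3 &> check_hkl.log
#   compare_hkl --fom=Rsplit --shell-file=Rsplit.dat *.hkl1 *.hkl2 -p test.cell --nshells=20 --highres=1.3 &> Rsplit.log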
def make_process_dir( dir ):
    # make process directory
    try:
        os.makedirs( dir )
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise
def summary_stats( cc_dat, ccstar_dat, mult_dat, rsplit_dat, wilson_dat ):
    # read all files into pd
    # function to sort out different column names
    def read_dat( dat, var ):
        # different column names of each dat file
        if var == "cc":
            cols = [ "d(nm)", "cc", "nref", "d", "min", "max" ]
        elif var == "ccstar":
            cols = [ "d(nm)", "ccstar", "nref", "d", "min", "max" ]
        elif var == "mult":
            cols = [ "d(nm)", "nref", "poss", "comp", "obs",
                     "mult", "snr", "I", "d", "min", "max" ]
        elif var == "rsplit":
            cols = [ "d(nm)", "rsplit", "nref", "d", "min", "max" ]
        elif var == "wilson":
            cols = [ "bin", "s2", "d", "lnI", "nref" ]
        df = pd.read_csv( dat, names=cols, skiprows=1, sep=r"\s+" )
        return df
    # make df
    cc_df = read_dat( cc_dat, "cc" )
    ccstar_df = read_dat( ccstar_dat, "ccstar" )
    mult_df = read_dat( mult_dat, "mult" )
    rsplit_df = read_dat( rsplit_dat, "rsplit" )
    wilson_df = read_dat( wilson_dat, "wilson" )
    # remove unwanted cols
    cc_df = cc_df[ [ "cc" ] ]
    ccstar_df = ccstar_df[ [ "ccstar" ] ]
    rsplit_df = rsplit_df[ [ "rsplit" ] ]
    wilson_df = wilson_df[ [ "lnI" ] ]
    # merge dfs
    stats_df = pd.concat( [ mult_df, cc_df, ccstar_df, rsplit_df, wilson_df ], axis=1, join="inner" )
    # make 1/d, 1/d^2 columns
    stats_df[ "1_d" ] = 1 / stats_df.d
    stats_df[ "1_d2" ] = 1 / stats_df.d**2
    # change nan to 0
    stats_df = stats_df.fillna(0)
    return stats_df
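# note: the pd.concat with join="inner" above lines the five tables up by row
# index, which is safe here because every .dat file comes from the same
# --nshells binning written by run_partialator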
def get_metric( d2_series, cc_series, cut_off ):
    # Define the tanh function from scitbx
    def tanh(x, r, s0):
        z = (x - s0)/r
        return 0.5 * ( 1 - np.tanh(z) )
    def arctanh( y, r, s0 ):
        return r * np.arctanh( 1 - 2*y ) + s0
    # Fit the tanh to the data
    params, covariance = curve_fit( tanh, d2_series, cc_series )
    # Extract the fitted parameters
    r, s0 = params
    # calculate cut-off point
    cc_stat = arctanh( cut_off, r, s0 )
    # convert back from 1/d2 to d
    cc_stat = np.sqrt( 1 / cc_stat )
    # get curve for plotting
    cc_tanh = tanh( d2_series, r, s0 )
    return round( cc_stat, 3 ), cc_tanh
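# worked example: if the fitted curve crosses the cut_off (e.g. CC = 0.3) at
# 1/d^2 = 0.25 1/A^2, the reported resolution is sqrt(1/0.25) = 2.0 A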
def summary_fig( stats_df, cc_tanh, ccstar_tanh, cc_cut, ccstar_cut ):
    def dto1_d( x ):
        return 1/x
    def dto1_d2( x ):
        return 1/x**2
    def d2_to_d( x ):
        # inverse of dto1_d2, needed for the secondary axis on the Wilson plot
        return 1/np.sqrt( x )
    # plot results
    cc_fig, axs = plt.subplots(2, 2)
    cc_fig.suptitle( "cc and cc* vs resolution" )
    # cc plot
    color = "tab:red"
    axs[0,0].set_xlabel( "1/d (1/A)" )
    axs[0,0].set_ylabel( "CC", color=color )
    axs[0,0].set_ylim( 0, 1 )
    axs[0,0].axhline( y=0.3, color="black", linestyle="dashed" )
    # plot cc
    axs[0,0].plot( stats_df[ "1_d" ], stats_df.cc, color=color )
    # plot fit
    axs[0,0].plot( stats_df[ "1_d" ], cc_tanh, color="tab:grey", linestyle="dashed" )
    # 1/x is its own inverse, so the same function maps both ways
    sax1 = axs[0,0].secondary_xaxis( 'top', functions=( dto1_d, dto1_d ) )
    sax1.set_xlabel( 'd (A)' )
    axs[0,0].tick_params( axis="y", labelcolor=color )
    axs[0,0].text( 0.1, 0.1, "CC @ 0.3 = {0}".format( cc_cut ), fontsize=8 )
    # cc* plot
    color = "tab:blue"
    axs[0,1].set_xlabel( "1/d (1/A)" )
    axs[0,1].set_ylabel( "CC*", color=color )
    axs[0,1].set_ylim( 0, 1 )
    axs[0,1].axhline( y=0.7, color="black", linestyle="dashed" )
    axs[0,1].plot( stats_df[ "1_d" ], stats_df.ccstar, color=color )
    # plot fit
    axs[0,1].plot( stats_df[ "1_d" ], ccstar_tanh, color="tab:grey", linestyle="dashed" )
    sax2 = axs[0,1].secondary_xaxis( 'top', functions=( dto1_d, dto1_d ) )
    sax2.set_xlabel( 'd (A)' )
    axs[0,1].tick_params( axis='y', labelcolor=color )
    axs[0,1].text( 0.1, 0.1, "CC* @ 0.7 = {0}".format( ccstar_cut ), fontsize=8 )
    # rsplit plot
    color = "tab:green"
    axs[1,0].set_xlabel( "1/d (1/A)" )
    axs[1,0].set_ylabel( "Rsplit", color=color )
    axs[1,0].plot( stats_df[ "1_d" ], stats_df.rsplit, color=color )
    sax3 = axs[1,0].secondary_xaxis( 'top', functions=( dto1_d, dto1_d ) )
    sax3.set_xlabel( 'd (A)' )
    axs[1,0].tick_params( axis='y', labelcolor=color )
    # wilson plot
    color = "tab:purple"
    axs[1,1].set_xlabel( "1/d^2 (1/A^2)" )
    axs[1,1].set_ylabel( "lnI", color=color )
    axs[1,1].plot( stats_df[ "1_d2" ], stats_df.lnI, color=color )
    sax4 = axs[1,1].secondary_xaxis( 'top', functions=( d2_to_d, dto1_d2 ) )
    sax4.set_xlabel( 'd (A)' )
    axs[1,1].tick_params( axis='y', labelcolor=color )
    # save figure
    plt.tight_layout()
    plt.savefig( "plots.png" )
def main( cwd, name, stream, pointgroup, model, iterations, cell, shells, part_h_res, adu, reservation ):
    # submitted job set
    submitted_job_ids = set()
    part_dir = "{0}/{1}".format( cwd, name )
    # make part directories
    make_process_dir( part_dir )
    # move to part directory
    os.chdir( part_dir )
    print( "making partialator file" )
    # make partialator run file
    part_run_file = run_partialator( part_dir, name, stream, pointgroup, model, iterations, cell, shells, part_h_res, adu )
    # submit job
    job_id = submit_job( part_run_file, reservation )
    print( "job submitted: {0}".format( job_id ) )
    submitted_job_ids.add( job_id )
    # use progress bar to track job completion
    time.sleep(10)
    wait_for_jobs( submitted_job_ids, 1 )
    print( "slurm processing done" )
    # stats file names
    cc_dat = "cc.dat"
    ccstar_dat = "ccstar.dat"
    mult_dat = "mult.dat"
    rsplit_dat = "Rsplit.dat"
    wilson_dat = "wilson.dat"
    # make summary data table
    stats_df = summary_stats( cc_dat, ccstar_dat, mult_dat, rsplit_dat, wilson_dat )
    logger.info( "stats table from .dat file =\n{0}".format( stats_df.to_string() ) )
    print_df = stats_df[ [ "1_d", "d", "min",
                           "max", "nref", "poss",
                           "comp", "obs", "mult",
                           "snr", "I", "rsplit", "cc", "ccstar" ] ]
    print_df.to_csv( "summary_table.csv", sep="\t", index=False )
    # calculate cc metrics
    cc_cut, cc_tanh = get_metric( stats_df[ "1_d2" ], stats_df.cc, 0.3 )
    ccstar_cut, ccstar_tanh = get_metric( stats_df[ "1_d2" ], stats_df.ccstar, 0.7 )
    print( "resolution where CC falls to 0.3 = {0}".format( cc_cut ) )
    print( "resolution where CC* falls to 0.7 = {0}".format( ccstar_cut ) )
    logger.info( "resolution where CC falls to 0.3 = {0}".format( cc_cut ) )
    logger.info( "resolution where CC* falls to 0.7 = {0}".format( ccstar_cut ) )
    # make summary plots
    summary_fig( stats_df, cc_tanh, ccstar_tanh, cc_cut, ccstar_cut )
    # move back to top dir
    os.chdir( cwd )
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-n",
        "--name",
        help="name of partialator run, also the name of the folder where data will be processed.",
        type=str,
        required=True
    )
    parser.add_argument(
        "-s",
        "--stream_file",
        help="path to stream file",
        type=os.path.abspath,
        required=True
    )
    parser.add_argument(
        "-p",
        "--pointgroup",
        help="pointgroup used by CrystFEL for partialator run",
        type=str,
        required=True
    )
    parser.add_argument(
        "-m",
        "--model",
        help="model used by partialator, e.g., unity or xsphere. Default = unity.",
        type=str,
        default="unity"
    )
    parser.add_argument(
        "-i",
        "--iterations",
        help="number of iterations used for partialator run. Default = 1.",
        type=int,
        default=1
    )
    parser.add_argument(
        "-c",
        "--cell_file",
        help="path to CrystFEL cell file for partialator.",
        type=os.path.abspath,
        required=True
    )
    parser.add_argument(
        "-b",
        "--bins",
        help="number of resolution bins to use (at least 20). Default = 20.",
        type=int,
        default=20
    )
    parser.add_argument(
        "-r",
        "--resolution",
        help="high resolution limit (A). Default = 1.3.",
        type=float,
        default=1.3
    )
    parser.add_argument(
        "-a",
        "--max_adu",
        help="maximum detector counts to allow. Default = 12000.",
        type=int,
        default=12000
    )
    parser.add_argument(
        "-R",
        "--reservation",
        help="reservation name for ra cluster. Usually along the lines of P11111_2024-12-10",
        type=str,
        default=None
    )
    parser.add_argument(
        "-d",
        "--debug",
        help="output debug to terminal.",
        action="store_true"
    )
    args = parser.parse_args()
    # set loguru - drop the stderr sink unless debugging, always keep a log file
    if not args.debug:
        logger.remove()
    logfile = "{0}.log".format( args.name )
    logger.add( logfile, format="{message}", level="INFO" )
    # run main
    cwd = os.getcwd()
    print( "top working directory = {0}".format( cwd ) )
    main( cwd, args.name, args.stream_file, args.pointgroup, args.model, args.iterations, args.cell_file, args.bins, args.resolution, args.max_adu, args.reservation )