crystfel_tools/reduction_tools/partialator_summary.py
#!/usr/bin/python
# author J.Beale
"""
# aim
to run check_hkl/compare_hkl on the merged .hkl files from partialator
and summarise the statistics
# usage
python partialator_summary.py
-n name (name of job/check)
-c <path-to-cell-file>
-b number of resolution bins - should be at least 20 (default = 20)
-r high-res limit in Angstrom (default = 1.3)
-d output debug to terminal
# output
- useful plots
- useful summarised .dat files and a summary table (.csv)
- log file of output
"""
# modules
import argparse
import os
import subprocess
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import regex as re
from loguru import logger
from scipy.optimize import curve_fit
from tqdm import tqdm

warnings.filterwarnings( "ignore", category=RuntimeWarning )

def submit_job( job_file ):
    """submit a script to slurm with sbatch and return the job id."""
    submit_cmd = [ "sbatch", "--", job_file ]
    logger.info( "using slurm command = {0}".format( submit_cmd ) )
    job_output = subprocess.check_output( submit_cmd )
    logger.info( "submitted job = {0}".format( job_output ) )
    # scrub job id from sbatch output - example: Submitted batch job 742403
    pattern = r"Submitted batch job (\d+)"
    job_id = re.search( pattern, job_output.decode().strip() ).group(1)
    return int( job_id )

def wait_for_jobs( job_ids, total_jobs ):
    """poll squeue until every submitted job id has left the queue."""
    with tqdm( total=total_jobs, desc="Jobs Completed", unit="job" ) as pbar:
        while job_ids:
            completed_jobs = set()
            for job_id in job_ids:
                status_cmd = [ "squeue", "-h", "-j", str( job_id ) ]
                status = subprocess.check_output( status_cmd )
                # an empty squeue response means the job is no longer queued/running
                if not status:
                    completed_jobs.add( job_id )
                    pbar.update( 1 )
            job_ids.difference_update( completed_jobs )
            time.sleep( 2 )
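# note: this polling assumes "squeue -h -j <id>" returns empty output once a
# job has finished; on Slurm configurations where an expired job id makes
# squeue exit non-zero, check_output would raise CalledProcessError instead.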

def run_compare_check( proc_dir, name, cell, shells, part_h_res ):
    """write and return a shell script that runs check_hkl/compare_hkl."""
    # check file name
    check_run_file = "{0}/check_{1}.sh".format( proc_dir, name )
    # write file
    with open( check_run_file, "w" ) as check_sh:
        check_sh.write( "#!/bin/sh\n\n" )
        check_sh.write( "module purge\n" )
        check_sh.write( "module use MX unstable\n" )
        check_sh.write( "module load crystfel/0.10.2-rhel8\n" )
        check_sh.write( "check_hkl --shell-file=mult.dat *.hkl -p {0} --nshells={1} --highres={2} &> check_hkl.log\n".format( cell, shells, part_h_res ) )
        check_sh.write( "check_hkl --ltest --ignore-negs --shell-file=ltest.dat *.hkl -p {0} --nshells={1} --highres={2} &> ltest.log\n".format( cell, shells, part_h_res ) )
        check_sh.write( "check_hkl --wilson --shell-file=wilson.dat *.hkl -p {0} --nshells={1} --highres={2} &> wilson.log\n".format( cell, shells, part_h_res ) )
        check_sh.write( "compare_hkl --fom=Rsplit --shell-file=Rsplit.dat *.hkl1 *.hkl2 -p {0} --nshells={1} --highres={2} &> Rsplit.log\n".format( cell, shells, part_h_res ) )
        check_sh.write( "compare_hkl --fom=cc --shell-file=cc.dat *.hkl1 *.hkl2 -p {0} --nshells={1} --highres={2} &> cc.log\n".format( cell, shells, part_h_res ) )
        check_sh.write( "compare_hkl --fom=ccstar --shell-file=ccstar.dat *.hkl1 *.hkl2 -p {0} --nshells={1} --highres={2} &> ccstar.log\n".format( cell, shells, part_h_res ) )
    # make file executable
    subprocess.call( [ "chmod", "+x", check_run_file ] )
    # add check script to log
    with open( check_run_file, "r" ) as check_input:
        logger.info( "check input file =\n{0}".format( check_input.read() ) )
    # return check file name
    return check_run_file
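# for illustration, with cell="lyso.cell", shells=20 and part_h_res=1.6
# (placeholder values), the generated script contains lines like:
#   check_hkl --shell-file=mult.dat *.hkl -p lyso.cell --nshells=20 --highres=1.6 &> check_hkl.log
#   compare_hkl --fom=Rsplit --shell-file=Rsplit.dat *.hkl1 *.hkl2 -p lyso.cell --nshells=20 --highres=1.6 &> Rsplit.log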

def summary_stats( cc_dat, ccstar_dat, mult_dat, rsplit_dat, wilson_dat ):
    """read all the .dat files into a single summary dataframe."""
    # helper to sort out the different column names of each dat file
    def read_dat( dat, var ):
        if var == "cc":
            cols = [ "d(nm)", "cc", "nref", "d", "min", "max" ]
        elif var == "ccstar":
            cols = [ "d(nm)", "ccstar", "nref", "d", "min", "max" ]
        elif var == "mult":
            cols = [ "d(nm)", "nref", "poss", "comp", "obs",
                     "mult", "snr", "I", "d", "min", "max" ]
        elif var == "rsplit":
            cols = [ "d(nm)", "rsplit", "nref", "d", "min", "max" ]
        elif var == "wilson":
            cols = [ "bin", "s2", "d", "lnI", "nref" ]
        df = pd.read_csv( dat, names=cols, skiprows=1, sep=r"\s+" )
        return df
    # make df
    cc_df = read_dat( cc_dat, "cc" )
    ccstar_df = read_dat( ccstar_dat, "ccstar" )
    mult_df = read_dat( mult_dat, "mult" )
    rsplit_df = read_dat( rsplit_dat, "rsplit" )
    wilson_df = read_dat( wilson_dat, "wilson" )
    # remove unwanted cols
    cc_df = cc_df[ [ "cc" ] ]
    ccstar_df = ccstar_df[ [ "ccstar" ] ]
    rsplit_df = rsplit_df[ [ "rsplit" ] ]
    wilson_df = wilson_df[ [ "lnI" ] ]
    # merge dfs
    stats_df = pd.concat( [ mult_df, cc_df, ccstar_df, rsplit_df, wilson_df ], axis=1, join="inner" )
    # make 1/d and 1/d^2 columns
    stats_df[ "1_d" ] = 1 / stats_df.d
    stats_df[ "1_d2" ] = 1 / stats_df.d**2
    # change nan to 0
    stats_df = stats_df.fillna(0)
    return stats_df
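# for reference, the merged stats_df carries the mult.dat columns
# (nref, poss, comp, obs, mult, snr, I, d, min, max) plus cc, ccstar,
# rsplit and lnI, and the derived 1_d and 1_d2 columns.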

def get_metric( d2_series, cc_series, cut_off ):
    """fit a tanh curve to a cc-style statistic and find a resolution cut-off."""
    # define the tanh function from scitbx
    def tanh( x, r, s0 ):
        z = ( x - s0 ) / r
        return 0.5 * ( 1 - np.tanh(z) )
    # and its inverse, to read a resolution back off the fitted curve
    def arctanh( y, r, s0 ):
        return r * np.arctanh( 1 - 2*y ) + s0
    # fit the tanh to the data
    params, covariance = curve_fit( tanh, d2_series, cc_series )
    # extract the fitted parameters
    r, s0 = params
    # calculate cut-off point
    cc_stat = arctanh( cut_off, r, s0 )
    # convert back from 1/d^2 to d
    cc_stat = np.sqrt( 1 / cc_stat )
    # get curve for plotting
    cc_tanh = tanh( d2_series, r, s0 )
    return round( cc_stat, 3 ), cc_tanh
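# in other words, the fitted model is CC(x) = 0.5 * (1 - tanh((x - s0)/r))
# with x = 1/d^2, so the point where the curve crosses a chosen cut-off y is
# x = r * arctanh(1 - 2y) + s0, reported as d = sqrt(1/x) in Angstrom.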

def get_overall_cc():
    """scrub the overall CC from the compare_hkl log."""
    # open cc log file
    with open( "cc.log" ) as cc_log_file:
        cc_log = cc_log_file.read()
    # regex example = Overall CC = 0.5970865
    overcc_pattern = r"Overall\sCC\s=\s(\d\.\d+)"
    try:
        overcc = re.search( overcc_pattern, cc_log ).group(1)
    except AttributeError:
        overcc = np.nan
    return overcc

def get_overall_rsplit():
    """scrub the overall Rsplit from the compare_hkl log."""
    # open rsplit log file
    with open( "Rsplit.log" ) as rsplit_log_file:
        rsplit_log = rsplit_log_file.read()
    # regex example = Overall Rsplit = 54.58 %
    overrsplit_pattern = r"Overall\sRsplit\s=\s(\d+\.\d+)"
    try:
        overrsplit = re.search( overrsplit_pattern, rsplit_log ).group(1)
    except AttributeError:
        overrsplit = np.nan
    return overrsplit

def get_b():
    """scrub the Wilson B factor from the check_hkl log."""
    # open wilson log file
    with open( "wilson.log" ) as wilson_log_file:
        wilson_log = wilson_log_file.read()
    # regex example = B = 41.63 A^2
    b_factor_pattern = r"B\s=\s(\d+\.\d+)\sA"
    try:
        b_factor = re.search( b_factor_pattern, wilson_log ).group(1)
    except AttributeError:
        b_factor = np.nan
    return b_factor

def summary_fig( name, stats_df, cc_tanh, ccstar_tanh, cc_cut, ccstar_cut ):
    def dto1_d( x ):
        return 1 / x
    # plot results
    cc_fig, axs = plt.subplots( 2, 2 )
    cc_fig.suptitle( "cc and cc* vs resolution" )
    # cc plot
    color = "tab:red"
    axs[0,0].set_xlabel( "1/d (1/A)" )
    axs[0,0].set_ylabel( "CC", color=color )
    axs[0,0].set_ylim( 0, 1 )
    axs[0,0].axhline( y=0.3, color="black", linestyle="dashed" )
    # plot cc
    axs[0,0].plot( stats_df[ "1_d" ], stats_df.cc, color=color )
    # plot fit
    axs[0,0].plot( stats_df[ "1_d" ], cc_tanh, color="tab:grey", linestyle="dashed" )
    sax1 = axs[0,0].secondary_xaxis( "top", functions=( dto1_d, dto1_d ) )
    sax1.set_xlabel( "d (A)" )
    axs[0,0].tick_params( axis="y", labelcolor=color )
    axs[0,0].text( 0.1, 0.1, "CC0.5 @ 0.3 = {0}".format( cc_cut ), fontsize=8 )
    # cc* plot
    color = "tab:blue"
    axs[0,1].set_xlabel( "1/d (1/A)" )
    axs[0,1].set_ylabel( "CC*", color=color )
    axs[0,1].set_ylim( 0, 1 )
    axs[0,1].axhline( y=0.7, color="black", linestyle="dashed" )
    axs[0,1].plot( stats_df[ "1_d" ], stats_df.ccstar, color=color )
    # plot fit
    axs[0,1].plot( stats_df[ "1_d" ], ccstar_tanh, color="tab:grey", linestyle="dashed" )
    sax2 = axs[0,1].secondary_xaxis( "top", functions=( dto1_d, dto1_d ) )
    sax2.set_xlabel( "d (A)" )
    axs[0,1].tick_params( axis="y", labelcolor=color )
    axs[0,1].text( 0.1, 0.1, "CC* @ 0.7 = {0}".format( ccstar_cut ), fontsize=8 )
    # rsplit plot
    color = "tab:green"
    axs[1,0].set_xlabel( "1/d (1/A)" )
    axs[1,0].set_ylabel( "Rsplit", color=color )
    axs[1,0].plot( stats_df[ "1_d" ], stats_df.rsplit, color=color )
    sax3 = axs[1,0].secondary_xaxis( "top", functions=( dto1_d, dto1_d ) )
    sax3.set_xlabel( "d (A)" )
    axs[1,0].tick_params( axis="y", labelcolor=color )
    # wilson plot
    color = "tab:purple"
    axs[1,1].set_xlabel( "1/d**2 (1/A**2)" )
    axs[1,1].set_ylabel( "lnI", color=color )
    axs[1,1].plot( stats_df[ "1_d2" ], stats_df.lnI, color=color )
    axs[1,1].tick_params( axis="y", labelcolor=color )
    # save figure
    plt.tight_layout()
    plt.savefig( "{0}_plots.png".format( name ) )
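# the saved figure has CC (top left), CC* (top right), Rsplit (bottom left)
# and the Wilson plot (bottom right), each plotted against resolution.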

def main( cwd, name, cell, shells, part_h_res ):
    # submitted job set
    submitted_job_ids = set()
    # now run the check and compare scripts
    print( "running check/compare" )
    check_run_file = run_compare_check( cwd, name, cell, shells, part_h_res )
    check_job_id = submit_job( check_run_file )
    print( "job submitted: {0}".format( check_job_id ) )
    submitted_job_ids.add( check_job_id )
    time.sleep( 10 )
    wait_for_jobs( submitted_job_ids, 1 )
    print( "done" )
    # stats file names
    cc_dat = "cc.dat"
    ccstar_dat = "ccstar.dat"
    mult_dat = "mult.dat"
    rsplit_dat = "Rsplit.dat"
    wilson_dat = "wilson.dat"
    # make summary data table
    stats_df = summary_stats( cc_dat, ccstar_dat, mult_dat, rsplit_dat, wilson_dat )
    logger.info( "stats table from .dat files =\n{0}".format( stats_df.to_string() ) )
    print_df = stats_df[ [ "1_d", "d", "min",
                           "max", "nref", "poss",
                           "comp", "obs", "mult",
                           "snr", "I", "rsplit", "cc", "ccstar" ] ]
    print_df.to_csv( "{0}_summary_table.csv".format( name ), sep="\t", index=False )
    # calculate cc metrics
    cc_cut, cc_tanh = get_metric( stats_df[ "1_d2" ], stats_df.cc, 0.3 )
    ccstar_cut, ccstar_tanh = get_metric( stats_df[ "1_d2" ], stats_df.ccstar, 0.7 )
    print( "resolution at CC0.5 of 0.3 = {0}".format( cc_cut ) )
    print( "resolution at CC* of 0.7 = {0}".format( ccstar_cut ) )
    logger.info( "resolution at CC0.5 of 0.3 = {0}".format( cc_cut ) )
    logger.info( "resolution at CC* of 0.7 = {0}".format( ccstar_cut ) )
    # scrub other metrics from the logs
    overcc = get_overall_cc()
    overrsplit = get_overall_rsplit()
    b_factor = get_b()
    logger.info( "overall CC0.5 = {0}".format( overcc ) )
    logger.info( "overall Rsplit = {0}".format( overrsplit ) )
    logger.info( "overall B = {0}".format( b_factor ) )
    # make the summary plots
    summary_fig( name, stats_df, cc_tanh, ccstar_tanh, cc_cut, ccstar_cut )
    # move back to top dir
    os.chdir( cwd )

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-n",
        "--name",
        help="name of check.",
        type=str,
        required=True
    )
    parser.add_argument(
        "-c",
        "--cell_file",
        help="path to CrystFEL cell file for partialator.",
        type=os.path.abspath,
        required=True
    )
    parser.add_argument(
        "-b",
        "--bins",
        help="number of resolution bins to use. Should be at least 20. Default = 20.",
        type=int,
        default=20
    )
    parser.add_argument(
        "-r",
        "--resolution",
        help="high resolution limit in Angstrom. Default = 1.3.",
        type=float,
        default=1.3
    )
    parser.add_argument(
        "-d",
        "--debug",
        help="output debug to terminal.",
        action="store_true"
    )
    args = parser.parse_args()
    # set loguru - drop the terminal sink unless debugging
    if not args.debug:
        logger.remove()
    logfile = "{0}.log".format( args.name )
    logger.add( logfile, format="{message}", level="INFO" )
    # run main
    cwd = os.getcwd()
    print( "top working directory = {0}".format( cwd ) )
    main( cwd, args.name, args.cell_file, args.bins, args.resolution )