updated bool in argparse

This commit is contained in:
Beale John Henry
2025-01-16 11:54:46 +01:00
parent 287780cc2c
commit 58db2c33da


@@ -27,12 +27,99 @@ a series of stream files from crystfel in the current working directory
# modules
import pandas as pd
import numpy as np
import subprocess
import os, errno
import time
import argparse
from tqdm import tqdm
import regex as re
from loguru import logger
def count_chunks( stream ):
# get number of chunks
# example - ----- Begin chunk -----
# count them
try:
pattern = r"-----\sBegin\schunk\s-----"
chunks = re.findall( pattern, stream )
return len( chunks )
except AttributeError:
logger.debug( "count_chunks error" )
return np.nan
def scrub_cells( stream ):
# get uc values from stream file
# example - Cell parameters 7.71784 7.78870 3.75250 nm, 90.19135 90.77553 90.19243 deg
# scrub clen and return - else nan
try:
pattern = r"Cell\sparameters\s(\d+\.\d+)\s(\d+\.\d+)\s(\d+\.\d+)\snm,\s(\d+\.\d+)\s(\d+\.\d+)\s(\d+\.\d+)\sdeg"
cell_lst = re.findall( pattern, stream )
xtals = len( cell_lst )
return cell_lst, xtals
except AttributeError:
logger.debug( "scrub_cells error" )
return np.nan, np.nan
def scrub_res( stream ):
# get diffraction limit
# example - diffraction_resolution_limit = 4.07 nm^-1 or 2.46 A
# scrub res_lst or return np.nan
try:
pattern = r"diffraction_resolution_limit\s=\s\d+\.\d+\snm\^-1\sor\s(\d+\.\d+)\sA"
res_lst = re.findall( pattern, stream )
return res_lst
except AttributeError:
logger.debug( "scrub_res error" )
return np.nan
def scrub_obs( stream ):
# get number of reflections
# example - num_reflections = 308
# scrub reflections or return np.nan
try:
pattern = r"num_reflections\s=\s(\d+)"
obs_lst = re.findall( pattern, stream )
return obs_lst
except AttributeError:
logger.debug( "scrub_obs error" )
return np.nan
def calculate_stats( stream_pwd ):
# open stream file
stream = open( stream_pwd, "r" ).read()
# get total number chunks
chunks = count_chunks( stream )
# get list of cells
cell_lst, xtals = scrub_cells( stream )
# get list of resolution limits
res_lst = scrub_res( stream )
# get list of reflection counts
obs_lst = scrub_obs( stream )
# build results dataframe
cols = [ "a", "b", "c", "alpha", "beta", "gamma" ]
df = pd.DataFrame( cell_lst, columns=cols )
df[ "resolution" ] = res_lst
df[ "obs" ] = obs_lst
# convert all to floats
df = df.astype(float)
return df, xtals, chunks
def h5_split( lst, chunk_size ):
@@ -53,7 +140,8 @@ def h5_split( lst, chunk_size ):
return list_df
def write_crystfel_run( proc_dir, name, chunk, chunk_lst_file,
geom_file, cell_file, threshold, min_snr,
geom_file, cell_file, indexer, peakfinder,
integrator, tolerance, threshold, min_snr,
int_rad, multi, retry, min_pix, bg_rad, min_res ):
# stream file name
@@ -71,10 +159,11 @@ def write_crystfel_run( proc_dir, name, chunk, chunk_lst_file,
run_sh.write( " --output={0} \\\n".format( stream_file ) )
run_sh.write( " --geometry={0} \\\n".format( geom_file ) )
run_sh.write( " --pdb={0} \\\n".format( cell_file ) )
run_sh.write( " --indexing=xgandalf-latt-cell \\\n" )
run_sh.write( " --peaks=peakfinder8 \\\n" )
run_sh.write( " --integration=rings-grad \\\n" )
run_sh.write( " --tolerance=10.0,10.0,10.0,2,3,2 \\\n" )
run_sh.write( " --indexing={0} \\\n".format( indexer ) )
run_sh.write( " --peaks={0} \\\n".format( peakfinder ) )
run_sh.write( " --integration={0} \\\n".format( integrator ) )
run_sh.write( " --tolerance={0},{1},{2},{3},{4},{5} \\\n".format( tolerance[0], tolerance[1], tolerance[2],
tolerance[3], tolerance[4], tolerance[5] ) )
run_sh.write( " --threshold={0} \\\n".format( threshold ) )
run_sh.write( " --min-snr={0} \\\n".format( min_snr ) )
run_sh.write( " --int-radius={0},{1},{2} \\\n".format( int_rad[0], int_rad[1], int_rad[2] ) )
@@ -100,42 +189,51 @@ def make_process_dir( proc_dir ):
os.makedirs( proc_dir )
except OSError as e:
if e.errno != errno.EEXIST:
logger.debug( "making directory error" )
raise
def submit_job( job_file ):
def submit_job( job_file, reservation ):
# submit the job
submit_cmd = ["sbatch", "--cpus-per-task=32", "--" ,job_file]
job_output = subprocess.check_output(submit_cmd)
if reservation:
print( "using a ra beamtime reservation = {0}".format( reservation ) )
logger.info( "using ra reservation to process data = {0}".format( reservation ) )
submit_cmd = [ "sbatch", "-p", "hour", "--reservation={0}".format( reservation ), "--cpus-per-task=32", "--" , job_file ]
else:
submit_cmd = [ "sbatch", "-p", "hour", "--cpus-per-task=32", "--" , job_file ]
logger.info( "using slurm command = {0}".format( submit_cmd ) )
try:
job_output = subprocess.check_output( submit_cmd )
logger.info( "submited job = {0}".format( job_output ) )
except subprocess.CalledProcessError as e:
print( "please give the correct ra reservation or remove the -v from the arguements" )
exit()
# scrub job id from - example Submitted batch job 742403
pattern = r"Submitted batch job (\d+)"
job_id = re.search( pattern, job_output.decode().strip() ).group(1)
return int(job_id)
return int( job_id )
def wait_for_jobs( job_ids, total_jobs ):
with tqdm(total=total_jobs, desc="Jobs Completed", unit="job") as pbar:
with tqdm( total=total_jobs, desc="Jobs Completed", unit="job" ) as pbar:
while job_ids:
completed_jobs = set()
for job_id in job_ids:
status_cmd = ["squeue", "-h", "-j", str(job_id)]
status_cmd = [ "squeue", "-h", "-j", str( job_id ) ]
status = subprocess.check_output(status_cmd)
if not status:
completed_jobs.add(job_id)
pbar.update(1)
job_ids.difference_update(completed_jobs)
time.sleep(10)
time.sleep(2)
def run_splits( cwd, name, lst, chunk_size, geom_file,
cell_file, progress, threshold, min_snr,
int_rad, multi, retry, min_pix ):
print( "reading SwissFEL lst file" )
print( "creating {0} image chunks of lst".format( chunk_size ) )
list_df = h5_split( lst, chunk_size )
print( "DONE" )
def run_splits( list_df, cwd, name, geom_file, cell_file,
indexer, peakfinder, integrator, tolerance, threshold,
min_snr, int_rad, multi, retry, min_pix, bg_rad,
min_res, reservation ):
# set chunk counter
chunk = 0
@@ -146,10 +244,9 @@ def run_splits( cwd, name, lst, chunk_size, geom_file,
# stream file list
stream_lst = []
print( "creating crystfel jobs for individual chunks" )
for chunk_lst in list_df:
print( "chunk {0} = {1} images".format( chunk, len( chunk_lst ) ) )
logger.info( "chunk {0} = {1} images".format( chunk, len( chunk_lst ) ) )
# define process directory
proc_dir = "{0}/{1}/{1}_{2}".format( cwd, name, chunk )
@@ -165,13 +262,13 @@ def run_splits( cwd, name, lst, chunk_size, geom_file,
# write crystfel file and append path to list
cryst_run_file, stream_file = write_crystfel_run( proc_dir, name, chunk, chunk_lst_file,
geom_file, cell_file, threshold, min_snr,
int_rad, multi, retry, min_pix )
geom_file, cell_file, indexer, peakfinder,
integrator, tolerance, threshold, min_snr,
int_rad, multi, retry, min_pix, bg_rad, min_res )
stream_lst.append( "{0}/{1}".format( proc_dir, stream_file ) )
# submit jobs
job_id = submit_job( cryst_run_file )
print(f"Job submitted: { job_id }")
job_id = submit_job( cryst_run_file, reservation )
submitted_job_ids.add( job_id )
# increase chunk counter
@@ -180,15 +277,34 @@ def run_splits( cwd, name, lst, chunk_size, geom_file,
# move back to top dir
os.chdir( cwd )
return submitted_job_ids, chunk, stream_lst
def main( cwd, name, lst, chunk_size, geom_file, cell_file,
indexer, peakfinder, integrator, tolerance, threshold,
min_snr, int_rad, multi, retry, min_pix, bg_rad,
min_res, reservation ):
print( "reading SwissFEL lst file" )
print( "creating {0} image chunks of lst".format( chunk_size ) )
list_df = h5_split( lst, chunk_size )
print( "DONE" )
wait_for_jobs(submitted_job_ids, chunk)
print("slurm processing done")
# run crystfel runs on individual splits
print( "submitting jobs to cluster" )
submitted_job_ids, chunk, stream_lst = run_splits( list_df, cwd, name, geom_file, cell_file,
indexer, peakfinder, integrator, tolerance, threshold,
min_snr, int_rad, multi, retry, min_pix, bg_rad,
min_res, reservation )
# monitor progress of jobs
time.sleep( 30 )
wait_for_jobs( submitted_job_ids, chunk )
print( "done" )
# make composite .stream file
output_file = "{0}.stream".format( name )
print( "comp" )
print( "concatenating .streams from separate runs." )
try:
# Open the output file in 'append' mode
with open(output_file, "a") as output:
@@ -197,110 +313,206 @@ def run_splits( cwd, name, lst, chunk_size, geom_file,
with open(file_name, "r") as input_file:
# Read the contents of the input file and append to the output file
output.write(input_file.read())
print(f"Appended contents from {file_name} to {output_file}")
except FileNotFoundError:
print(f"File {file_name} not found. Skipping.")
logger.debug(f"File {file_name} not found. Skipping.")
except IOError as e:
print(f"An error occurred while appending files: {e}")
logger.debug(f"An error occurred while appending files: {e}")
print( "done" )
print( "DONE" )
df, xtals, chunks = calculate_stats( output_file )
# stats
index_rate = round( xtals/chunks*100, 2 )
mean_res, std_res = round( df.resolution.mean(), 2 ), round( df.resolution.std(), 2 )
median_res = df.resolution.median()
mean_obs, std_obs = round( df.obs.mean(), 2 ), round( df.obs.std(), 2)
mean_a, std_a = round( df.a.mean()*10, 2 ), round( df.a.std()*10, 2 )
mean_b, std_b = round( df.b.mean()*10, 2 ), round( df.b.std()*10, 2 )
mean_c, std_c = round( df.c.mean()*10, 2 ), round( df.c.std()*10, 2 )
mean_alpha, std_alpha = round( df.alpha.mean(), 2 ), round( df.alpha.std(), 2 )
mean_beta, std_beta = round(df.beta.mean(), 2 ), round( df.beta.std(), 2 )
mean_gamma, std_gamma = round( df.gamma.mean(), 2 ), round( df.gamma.std(), 2 )
logger.info( "image = {0}".format( chunks ) )
logger.info( "crystals = {0}".format( xtals ) )
logger.info( "indexing rate = {0} %".format( index_rate ) )
logger.info( "mean resolution = {0} +/- {1} A".format( mean_res, std_res ) )
logger.info( "median resolution = {0} A".format( median_res ) )
logger.info( "mean observations = {0} +/- {1}".format( mean_obs, std_obs ) )
logger.info( "mean a = {0} +/- {1} A".format( mean_a, std_a ) )
logger.info( "mean b = {0} +/- {1} A".format( mean_b, std_b ) )
logger.info( "mean c = {0} +/- {1} A".format( mean_c, std_c ) )
logger.info( "mean alpha = {0} +/- {1} A".format( mean_alpha, std_alpha ) )
logger.info( "mean beta = {0} +/- {1} A".format( mean_beta, std_beta ) )
logger.info( "mean gamma = {0} +/- {1} A".format( mean_gamma, std_gamma ) )
print( "printing stats" )
print( "image = {0}".format( chunks ) )
print( "crystals = {0}".format( xtals ) )
print( "indexing rate = {0} %".format( index_rate ) )
print( "mean resolution = {0} +/- {1} A".format( mean_res, std_res ) )
print( "median resolution = {0} A".format( median_res ) )
print( "mean observations = {0} +/- {1}".format( mean_obs, std_obs ) )
print( "mean a = {0} +/- {1} A".format( mean_a, std_a ) )
print( "mean b = {0} +/- {1} A".format( mean_b, std_b ) )
print( "mean c = {0} +/- {1} A".format( mean_c, std_c ) )
print( "mean alpha = {0} +/- {1} A".format( mean_alpha, std_alpha ) )
print( "mean beta = {0} +/- {1} A".format( mean_beta, std_beta ) )
print( "mean gamma = {0} +/- {1} A".format( mean_gamma, std_gamma ) )
def list_of_ints(arg):
return list(map(int, arg.split(',')))
def list_of_floats(arg):
return list(map(float, arg.split(',')))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-l",
"--lst_file",
help="file from SwissFEL output to be processed quickly",
type=os.path.abspath
)
parser.add_argument(
"-k",
"--chunk_size",
help="how big should each chunk be? - the bigger the chunk, the fewer jobs, the slower it will be",
type=int,
default=1000
)
parser.add_argument(
"-g",
"--geom_file",
help="path to geom file to be used in the refinement",
type=os.path.abspath
)
parser.add_argument(
"-c",
"--cell_file",
help="path to cell file of the crystals used in the refinement",
type=os.path.abspath
)
parser.add_argument(
"-n",
"--job_name",
help="the name of the job to be done",
help="the name of the job to be done. Default = split",
type=str,
default="split"
)
parser.add_argument(
"-l",
"--lst_file",
help="file from SwissFEL output to be processed quickly. Requried.",
type=os.path.abspath,
required=True
)
parser.add_argument(
"-k",
"--chunk_size",
help="how big should each image split be? Default = 500. Fewer will be faster.",
type=int,
default=500
)
parser.add_argument(
"-g",
"--geom_file",
help="path to geom file to be used in the refinement. Requried.",
type=os.path.abspath,
required=True
)
parser.add_argument(
"-c",
"--cell_file",
help="path to cell file of the crystals used in the refinement. Requried.",
type=os.path.abspath,
required=True
)
parser.add_argument(
"-x",
"--indexer",
help="indexer to use. Default = xgandalf-latt-cell",
type=str,
default="xgandalf-latt-cell"
)
parser.add_argument(
"-f",
"--peakfinder",
help="peakfinder to use. Default = peakfinder8",
type=str,
default="peakfinder8"
)
parser.add_argument(
"-a",
"--integrator",
help="integrator to use. Default = rings-nocen-nograd",
type=str,
default="rings-nocen-nograd"
)
parser.add_argument(
"-y",
"--tolerance",
help="tolerance to use. Default = 10.0,10.0,10.0,2.0,3.0,2.0",
type=list_of_floats,
default=[10.0,10.0,10.0,2.0,3.0,2.0]
)
parser.add_argument(
"-t",
"--threshold",
help="threshold for crystfel run - peaks must be above this to be found",
help="peaks must be above this to be found during spot-finding. Default = 20",
type=int,
default=10
default=20
)
parser.add_argument(
"-s",
"--min_snr",
help="min-snr for crystfel run - peaks must to above this to be counted",
help="peaks must to above this to be counted. Default = 5.",
type=int,
default=5
)
parser.add_argument(
"-i",
"--int_radius",
help="int_rad for crystfel run - peaks must to above this to be counted",
help="integration ring radii. Default = 2,3,5 = 2 for spot and then 3 and 5 to calculate background.",
type=list_of_ints,
default=[3,5,9]
default=[2,3,5]
)
parser.add_argument(
"-m",
"--multi",
help="multi crystfel flag, do you wnat to look for multiple lattices",
help="do you wnat to look for multiple lattices. Default = True",
type=bool,
default=False
default=True
)
parser.add_argument(
"-r",
"--retry",
help="retry crystfel flag, do you want to retry failed indexing patterns",
help="do you want to retry failed indexing patterns. Default = False",
type=bool,
default=False
)
parser.add_argument(
"-x",
"-p",
"--min_pix",
help="min-pix-count for crystfel runs, minimum number of pixels a spot should contain in peak finding",
help="minimum number of pixels a spot should contain in peak finding.Default = 2",
type=int,
default=2
)
parser.add_argument(
"-b",
"--bg_rad",
help="crystfel background radius flag, radius (in pixels) used for the estimation of the local background",
help="radius (in pixels) used for the estimation of the local background. Default = 4",
type=int,
default=2
default=4
)
parser.add_argument(
"-q",
"--min-res",
help="m",
help="min-res for spot-finding in pixels. Default = 85.",
type=int,
default=2
default=85
)
parser.add_argument(
"-v",
"--reservation",
help="reservation name for ra cluster. Usually along the lines of P11111_2024-12-10",
type=str,
default=None
)
parser.add_argument(
"-d",
"--debug",
help="output debug to terminal.",
type=bool,
default=False
)
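# NOTE: a minimal sketch, assuming strict true/false parsing is wanted -
# argparse applies type=bool to the raw command-line string, so any non-empty
# value is truthy and e.g. "--multi False" still parses as True. A converter
# along these lines could be passed as type= for the --multi, --retry and
# --debug flags; "str2bool" is a hypothetical helper name, not part of this script.
def str2bool( arg ):
    # map common true/false spellings onto real booleans
    value = str( arg ).strip().lower()
    if value in ( "true", "t", "yes", "y", "1" ):
        return True
    if value in ( "false", "f", "no", "n", "0" ):
        return False
    raise argparse.ArgumentTypeError( "expected a boolean value, got {0}".format( arg ) )
# usage sketch: parser.add_argument( "-m", "--multi", type=str2bool, default=True )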
args = parser.parse_args()
# get current working directory
cwd = os.getcwd()
# set loguru
if not args.debug:
logger.remove()
logfile = "{0}.log".format( args.job_name )
logger.add( logfile, format="{message}", level="INFO")
# log geometry file
geom = open( args.geom_file, "r" ).read()
logger.info( geom )
if args.multi:
multi = "multi"
else:
@@ -309,7 +521,8 @@ if __name__ == "__main__":
retry = "retry"
else:
retry = "no-retry"
run_splits( cwd, args.job_name, args.lst_file, args.chunk_size,
args.geom_file, args.cell_file,
args.threshold, args.min_snr, args.int_radius,
multi, retry, args.min_pix )
main( cwd, args.job_name, args.lst_file, args.chunk_size,
args.geom_file, args.cell_file, args.indexer, args.peakfinder,
args.integrator, args.tolerance, args.threshold,
args.min_snr, args.int_radius, multi, retry, args.min_pix, args.bg_rad,
args.min_res, args.reservation )