crystfel_tools/reduction_tools/crystfel_split.py
#!/usr/bin/env python3
# author J.Beale
"""
# aim
to process a batch of data very fast by splitting it into a number of chunks and submitting
these jobs separately to the cluster
# usage
python crystfel_split.py -l <path-to-list-file>
-k <chunk-size>
-g <path-to-geom-file>
-c <path-to-cell-file>
-n <name-of-job>
-p photons or
# crystfel parameter may need some editing in the function - write_crystfel_run
# output
a series of stream files from crystfel in the current working directory
a concatenated .stream file in cwd
a log file with .geom and evalation of indexing, cell etc
"""
# modules
import argparse
import pandas as pd
import subprocess
import os
import time
from tqdm import tqdm
import re  # stdlib re suffices for the patterns used here
from loguru import logger
def scrub_cells( line ):
    # get unit-cell values from a stream-file line
    # example - Cell parameters 7.71784 7.78870 3.75250 nm, 90.19135 90.77553 90.19243 deg
    pattern = r"Cell\sparameters\s(\d+\.\d+)\s(\d+\.\d+)\s(\d+\.\d+)\snm,\s(\d+\.\d+)\s(\d+\.\d+)\s(\d+\.\d+)\sdeg"
    # search once and return all six groups: a, b, c, alpha, beta, gamma
    match = re.search( pattern, line )
    return list( match.groups() )
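# e.g. scrub_cells( "Cell parameters 7.71784 7.78870 3.75250 nm, 90.19135 90.77553 90.19243 deg" )
# returns [ '7.71784', '7.78870', '3.75250', '90.19135', '90.77553', '90.19243' ] (strings, nm and deg)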
def scrub_res( stream ):
# get diffraction limit
# example - diffraction_resolution_limit = 4.07 nm^-1 or 2.46 A
pattern = r"diffraction_resolution_limit\s=\s\d\.\d+\snm\^-1\sor\s(\d+\.\d+)\sA"
res = re.search( pattern, stream ).group(1)
return res
def scrub_obs( stream ):
# get number of reflections
# example - num_reflections = 308
pattern = r"num_reflections\s=\s(\d+)"
obs = re.search( pattern, stream ).group(1)
return obs
def calculate_stats( stream_pwd ):
chunks = 0
xtals = 0
cells = []
obs_list = []
res_list = []
print( "scrubing data" )
# open stream file
with open( stream_pwd ) as stream:
for line in stream:
# count chunks
if line.startswith( "----- Begin chunk -----" ):
chunks = chunks + 1
# get cell
if line.startswith( "Cell parameters" ):
cell = scrub_cells( line )
cells.append( cell )
xtals = xtals + 1
# get res
if line.startswith( "diffraction_resolution_limit" ):
res = scrub_res( line )
res_list.append( res )
# get obs
if line.startswith( "num_reflections" ):
obs = scrub_obs( line )
obs_list.append( obs )
    # build results dataframe
cols = [ "a", "b", "c", "alpha", "beta", "gamma" ]
df = pd.DataFrame( cells, columns=cols )
df[ "resolution" ] = res_list
df[ "obs" ] = obs_list
# convert all to floats
df = df.astype(float)
return df, xtals, chunks
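# e.g. calculate_stats( "run0001.stream" ) (file name hypothetical) returns a DataFrame with
# columns a, b, c, alpha, beta, gamma, resolution, obs plus the crystal and image (chunk) counts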
def h5_split( lst, chunk_size ):
    # read h5.lst - note - splits each line on the " //" separator
    cols = [ "h5", "image" ]
    df = pd.read_csv( lst, sep=r"\s//", engine="python", names=cols )
    # re-add // to image column and keep only the combined path column
    df[ "h5_path" ] = df.h5 + " //" + df.image.astype( str )
    df = df[ [ "h5_path" ] ]
    # split df into a list of chunk-sized DataFrames
    list_df = [ df[ i:i + chunk_size ] for i in range( 0, len( df ), chunk_size ) ]
    return list_df
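# e.g. a 1200-line lst with chunk_size=500 yields chunks of 500, 500 and 200 images;
# each input line looks like (values hypothetical): /data/run0001.h5 //3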
def write_crystfel_run( proc_dir, name, chunk, chunk_lst_file, geom_file, cell_file, threshold ):
"""
crystfel run file - spot-finding and indexing parameters may need some editing
only change from inside the quote ("")
"""
# stream file name
stream_file = "{0}_{1}.stream".format( name, chunk )
# crystfel file name
cryst_run_file = "{0}/{1}_{2}.sh".format( proc_dir, name, chunk )
    # write file
    with open( cryst_run_file, "w" ) as run_sh:
        run_sh.write( "#!/bin/sh\n\n" )
        run_sh.write( "module purge\n" )
        run_sh.write( "module use MX unstable\n" )
        run_sh.write( "module load crystfel/0.10.2-rhel8\n" )
        run_sh.write( "indexamajig -i {0} \\\n".format( chunk_lst_file ) )
        run_sh.write( " --output={0} \\\n".format( stream_file ) )
        run_sh.write( " --geometry={0} \\\n".format( geom_file ) )
        run_sh.write( " --pdb={0} \\\n".format( cell_file ) )
        run_sh.write( " --push-res=0.5 \\\n" )
        run_sh.write( " --indexing=xgandalf-latt-cell \\\n" )
        run_sh.write( " --peaks=peakfinder8 \\\n" )
        run_sh.write( " --integration=rings-nocen-nograd \\\n" )
        run_sh.write( " --tolerance=10.0,10.0,10.0,2,3,2 \\\n" )
        run_sh.write( " --threshold={0} \\\n".format( threshold ) )
        run_sh.write( " --min-snr=5 \\\n" )
        run_sh.write( " --int-radius=4,6,7 \\\n" )
        run_sh.write( " -j 32 \\\n" )
        run_sh.write( " --multi \\\n" )
        run_sh.write( " --check-peaks \\\n" )
        run_sh.write( " --retry \\\n" )
        run_sh.write( " --max-res=3000 \\\n" )
        run_sh.write( " --min-pix-count=2 \\\n" )
        run_sh.write( " --local-bg-radius=4 \\\n" )
        run_sh.write( " --min-res=85" )
# make file executable
    subprocess.call( [ "chmod", "+x", cryst_run_file ] )
# return crystfel file name
return cryst_run_file, stream_file
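# sketch of the generated script (values depend on the arguments):
#   #!/bin/sh
#   module load crystfel/0.10.2-rhel8
#   indexamajig -i <name>_<chunk>.lst --output=<name>_<chunk>.stream --geometry=... --pdb=... ...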
def make_process_dir( proc_dir ):
    # make process directory (a pre-existing directory is not an error)
    os.makedirs( proc_dir, exist_ok=True )
def submit_job( job_file, reservation ):
    # submit the job
    if reservation:
        print( "using ra beamtime reservation = {0}".format( reservation ) )
        logger.info( "using ra reservation to process data = {0}".format( reservation ) )
        submit_cmd = [ "sbatch", "-p", "hour", "--reservation={0}".format( reservation ), "--cpus-per-task=32", "--", job_file ]
    else:
        submit_cmd = [ "sbatch", "-p", "hour", "--cpus-per-task=32", "--", job_file ]
    logger.info( "using slurm command = {0}".format( submit_cmd ) )
    try:
        job_output = subprocess.check_output( submit_cmd )
        logger.info( "submitted job = {0}".format( job_output ) )
    except subprocess.CalledProcessError as e:
        logger.debug( e )
        print( "please give the correct ra reservation or remove the -v from the arguments" )
        raise SystemExit( 1 )
# scrub job id from - example Submitted batch job 742403
pattern = r"Submitted batch job (\d+)"
job_id = re.search( pattern, job_output.decode().strip() ).group(1)
return int( job_id )
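# e.g. sbatch replies "Submitted batch job 742403" and submit_job returns 742403 as an int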
def wait_for_jobs( job_ids, total_jobs ):
    # poll squeue every 2 s until every submitted job has left the queue
    with tqdm( total=total_jobs, desc="Jobs Completed", unit="job" ) as pbar:
        while job_ids:
            completed_jobs = set()
            for job_id in job_ids:
                status_cmd = [ "squeue", "-h", "-j", str( job_id ) ]
                try:
                    status = subprocess.check_output( status_cmd )
                except subprocess.CalledProcessError:
                    # squeue errors once a job record has expired - treat it as completed
                    status = b""
                if not status:
                    completed_jobs.add( job_id )
                    pbar.update( 1 )
            job_ids.difference_update( completed_jobs )
            time.sleep( 2 )
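# e.g. wait_for_jobs( {742403, 742404}, 2 ) blocks until squeue no longer lists either job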
def run_splits( list_df, cwd, name, geom_file, cell_file, threshold, reservation ):
# set chunk counter
chunk = 0
# submitted job set
submitted_job_ids = set()
# stream file list
stream_lst = []
for chunk_lst in list_df:
logger.info( "chunk {0} = {1} images".format( chunk, len( chunk_lst ) ) )
# define process directory
proc_dir = "{0}/{1}/{1}_{2}".format( cwd, name, chunk )
# make process directory
make_process_dir(proc_dir)
# move to process directory
os.chdir( proc_dir )
# write list to file
chunk_lst_file = "{0}/{1}_{2}.lst".format( proc_dir, name, chunk )
chunk_lst.to_csv( chunk_lst_file, index=False, header=False )
# write crystfel file and append path to list
cryst_run_file, stream_file = write_crystfel_run( proc_dir, name, chunk, chunk_lst_file, geom_file, cell_file, threshold )
stream_lst.append( "{0}/{1}".format( proc_dir, stream_file ) )
# submit jobs
job_id = submit_job( cryst_run_file, reservation )
submitted_job_ids.add( job_id )
        # increase chunk counter
        chunk += 1
# move back to top dir
os.chdir( cwd )
return submitted_job_ids, chunk, stream_lst
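# resulting layout for name="split" with two chunks (paths relative to cwd):
#   split/split_0/  ->  split_0.lst, split_0.sh, split_0.stream
#   split/split_1/  ->  split_1.lst, split_1.sh, split_1.stream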
def main( cwd, name, lst, chunk_size, geom_file, cell_file, threshold, reservation ):
print( "reading SwissFEL lst file" )
print( "creating {0} image chunks of lst".format( chunk_size ) )
list_df = h5_split( lst, chunk_size )
print( "DONE" )
# run crystfel runs on individual splits
print( "submitting jobs to cluster" )
submitted_job_ids, chunk, stream_lst = run_splits( list_df, cwd, name, geom_file, cell_file, threshold, reservation )
    # monitor progress of jobs - give slurm time to register them before polling
    time.sleep( 30 )
wait_for_jobs( submitted_job_ids, chunk )
print( "done" )
# make composite .stream file
output_file = "{0}.stream".format( name )
print( "concatenating .streams from separate runs." )
try:
        # open the output file fresh so a rerun does not append to an old composite
        with open(output_file, "w") as output:
for file_name in stream_lst:
try:
with open(file_name, "r") as input_file:
# Read the contents of the input file and append to the output file
output.write(input_file.read())
except FileNotFoundError:
logger.debug(f"File {file_name} not found. Skipping.")
except IOError as e:
logger.debug(f"An error occurred while appending files: {e}")
print( "done" )
df, xtals, chunks = calculate_stats( output_file )
# stats
index_rate = round( xtals/chunks*100, 2 )
mean_res, std_res = round( df.resolution.mean(), 2 ), round( df.resolution.std(), 2 )
    median_res = round( df.resolution.median(), 2 )
mean_obs, std_obs = round( df.obs.mean(), 2 ), round( df.obs.std(), 2)
    # cell lengths converted from nm to Angstrom (x10)
    mean_a, std_a = round( df.a.mean()*10, 2 ), round( df.a.std()*10, 2 )
mean_b, std_b = round( df.b.mean()*10, 2 ), round( df.b.std()*10, 2 )
mean_c, std_c = round( df.c.mean()*10, 2 ), round( df.c.std()*10, 2 )
mean_alpha, std_alpha = round( df.alpha.mean(), 2 ), round( df.alpha.std(), 2 )
mean_beta, std_beta = round(df.beta.mean(), 2 ), round( df.beta.std(), 2 )
mean_gamma, std_gamma = round( df.gamma.mean(), 2 ), round( df.gamma.std(), 2 )
logger.info( "images = {0}".format( chunks ) )
logger.info( "crystals = {0}".format( xtals ) )
logger.info( "indexing rate = {0} %".format( index_rate ) )
logger.info( "mean resolution = {0} +/- {1} A".format( mean_res, std_res ) )
logger.info( "median resolution = {0} A".format( median_res ) )
logger.info( "mean observations = {0} +/- {1}".format( mean_obs, std_obs ) )
logger.info( "mean a = {0} +/- {1} A".format( mean_a, std_a ) )
logger.info( "mean b = {0} +/- {1} A".format( mean_b, std_b ) )
logger.info( "mean c = {0} +/- {1} A".format( mean_c, std_c ) )
logger.info( "mean alpha = {0} +/- {1} deg".format( mean_alpha, std_alpha ) )
logger.info( "mean beta = {0} +/- {1} deg".format( mean_beta, std_beta ) )
logger.info( "mean gamma = {0} +/- {1} deg".format( mean_gamma, std_gamma ) )
print( "printing stats" )
print( "images = {0}".format( chunks ) )
print( "crystals = {0}".format( xtals ) )
print( "indexing rate = {0} %".format( index_rate ) )
print( "mean resolution = {0} +/- {1} A".format( mean_res, std_res ) )
print( "median resolution = {0} A".format( median_res ) )
print( "mean observations = {0} +/- {1}".format( mean_obs, std_obs ) )
print( "mean a = {0} +/- {1} A".format( mean_a, std_a ) )
print( "mean b = {0} +/- {1} A".format( mean_b, std_b ) )
print( "mean c = {0} +/- {1} A".format( mean_c, std_c ) )
print( "mean alpha = {0} +/- {1} deg".format( mean_alpha, std_alpha ) )
print( "mean beta = {0} +/- {1} deg".format( mean_beta, std_beta ) )
print( "mean gamma = {0} +/- {1} deg".format( mean_gamma, std_gamma ) )
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-l",
"--lst_file",
help="file from SwissFEL output to be processed quickly. Requried.",
type=os.path.abspath,
required=True
)
parser.add_argument(
"-k",
"--chunk_size",
help="how big should each chunk be? - the bigger the chunk, the fewer jobs, the slower it will be. Default = 500.",
type=int,
default=500
)
parser.add_argument(
"-g",
"--geom_file",
help="path to geom file to be used in the refinement. Required.",
type=os.path.abspath,
required=True
)
parser.add_argument(
"-c",
"--cell_file",
help="path to cell file of the crystals used in the refinement. Required.",
type=os.path.abspath,
required=True
)
parser.add_argument(
"-n",
"--job_name",
help="the name of the job to be done. Default = 'split_###'",
type=str,
default="split"
)
parser.add_argument(
"-v",
"--reservation",
help="reservation name for ra cluster. Usually along the lines of P11111_2024-12-10",
type=str,
default=None
)
    parser.add_argument(
        "-p",
        "--photons_or_energy",
        help="determines the threshold to use for CrystFEL. Photon counts have always been used at Cristallina and are used at Alvra from 01.11.2024. Please use 'energy' for Alvra data before this date.",
        type=str,
        choices=[ "photons", "energy" ],
        default="photons"
    )
    parser.add_argument(
        "-d",
        "--debug",
        help="output debug to terminal.",
        action="store_true"
    )
args = parser.parse_args()
# set current working directory
cwd = os.getcwd()
# set loguru
if not args.debug:
logger.remove()
logfile = "{0}.log".format( args.job_name )
logger.add( logfile, format="{message}", level="INFO")
    # log geometry file
    with open( args.geom_file, "r" ) as geom_fh:
        geom = geom_fh.read()
    logger.info( geom )
# set threshold based on detector
if args.photons_or_energy == "energy":
threshold = 3000
elif args.photons_or_energy == "photons":
threshold = 15
main( cwd, args.job_name, args.lst_file, args.chunk_size, args.geom_file, args.cell_file, threshold, args.reservation )