#!/usr/bin/python

# author J.Beale

"""
Process a batch of SwissFEL data quickly by splitting its image list into
chunks and submitting each chunk to the cluster as a separate CrystFEL job.

usage
    python crystfel_split.py -l <file.lst> -g <file.geom> -c <cell file>
        [-k <chunk_size>] [-n <job_name>] [-v <ra reservation>]
        [-p photons|energy] [-d]

CrystFEL parameters may need some editing in the function write_crystfel_run.

output
    per-chunk .lst, .sh and .stream files in ./<job_name>/<job_name>_<chunk>/
    a concatenated <job_name>.stream file in the current working directory
    a <job_name>.log file with the .geom and an evaluation of indexing,
    resolution, unit cell, etc.
"""

# modules
import argparse
import errno
import os
import shutil
import subprocess
import sys
import time

import numpy as np
import pandas as pd
import regex as re
from loguru import logger
from tqdm import tqdm


# example - Cell parameters 7.71784 7.78870 3.75250 nm, 90.19135 90.77553 90.19243 deg
CELL_PATTERN = re.compile(
    r"Cell\sparameters\s(\d+\.\d+)\s(\d+\.\d+)\s(\d+\.\d+)\snm,"
    r"\s(\d+\.\d+)\s(\d+\.\d+)\s(\d+\.\d+)\sdeg"
)

# example - diffraction_resolution_limit = 4.07 nm^-1 or 2.46 A
# \d+ (not \d) before the nm^-1 decimal point: limits better than 1 A are >= 10 nm^-1
RES_PATTERN = re.compile(
    r"diffraction_resolution_limit\s=\s\d+\.\d+\snm\^-1\sor\s(\d+\.\d+)\sA"
)

# example - num_reflections = 308
OBS_PATTERN = re.compile( r"num_reflections\s=\s(\d+)" )

# example - Submitted batch job 742403
JOB_ID_PATTERN = re.compile( r"Submitted batch job (\d+)" )

# CrystFEL --threshold for each -p/--photons_or_energy choice
THRESHOLDS = { "energy": 3000, "photons": 15 }


def scrub_cells( line ):
    """Return [a, b, c, alpha, beta, gamma] from a stream 'Cell parameters' line.

    Values are returned as strings exactly as written in the stream (lengths
    in nm, angles in deg); callers convert them to float. If the line does not
    match the expected format, six NaNs are returned so that the per-crystal
    columns built by calculate_stats stay aligned.
    """
    match = CELL_PATTERN.search( line )
    if match is None:
        logger.debug( "could not parse cell line: {0}".format( line.strip() ) )
        return [ np.nan ] * 6
    return list( match.groups() )


def scrub_res( stream ):
    """Return the diffraction resolution limit in A (string) from a stream line.

    ``stream`` is a single 'diffraction_resolution_limit' line. Returns NaN if
    the line does not match the expected format.
    """
    match = RES_PATTERN.search( stream )
    if match is None:
        logger.debug( "could not parse resolution line: {0}".format( stream.strip() ) )
        return np.nan
    return match.group( 1 )


def scrub_obs( stream ):
    """Return the number of reflections (string) from a 'num_reflections' line.

    Returns NaN if the line does not match the expected format.
    """
    match = OBS_PATTERN.search( stream )
    if match is None:
        logger.debug( "could not parse num_reflections line: {0}".format( stream.strip() ) )
        return np.nan
    return match.group( 1 )


def calculate_stats( stream_pwd ):
    """Scan a CrystFEL stream file and collect per-crystal statistics.

    Parameters
    ----------
    stream_pwd : str
        Path of the .stream file to read.

    Returns
    -------
    (df, xtals, chunks)
        df     - float DataFrame with one row per crystal and columns
                 a, b, c (nm), alpha, beta, gamma (deg), resolution (A), obs
        xtals  - number of 'Cell parameters' lines (one per crystal)
        chunks - number of '----- Begin chunk -----' lines (one per image)

    Assumes every crystal block contains exactly one cell, resolution and
    num_reflections line, otherwise the column lengths differ and pandas
    raises ValueError.
    """
    chunks = 0
    xtals = 0
    cells = []
    obs_list = []
    res_list = []

    print( "scrubing data" )
    with open( stream_pwd ) as stream:
        for line in stream:
            if line.startswith( "----- Begin chunk -----" ):
                chunks += 1
                # report progress only when a new chunk starts, not on every
                # line while the counter sits on a multiple of 1000
                if chunks % 1000 == 0:
                    print( "scrubbed {0} chunks".format( chunks ), end='\r' )
            elif line.startswith( "Cell parameters" ):
                cells.append( scrub_cells( line ) )
                xtals += 1
            elif line.startswith( "diffraction_resolution_limit" ):
                res_list.append( scrub_res( line ) )
            elif line.startswith( "num_reflections" ):
                obs_list.append( scrub_obs( line ) )

    cols = [ "a", "b", "c", "alpha", "beta", "gamma" ]
    df = pd.DataFrame( cells, columns=cols )
    df[ "resolution" ] = res_list
    df[ "obs" ] = obs_list
    # parsed values are strings (or NaN) - convert everything to floats
    df = df.astype( float )

    return df, xtals, chunks


def h5_split( lst, chunk_size ):
    """Read a SwissFEL .lst file and split it into chunks of ``chunk_size`` rows.

    Each line is expected to look like ``<file.h5> //<event>``; the ``//`` is
    used as the separator when reading and is re-added to the event id.
    Lines without a ``//<event>`` part are kept as the bare file path.

    Returns a list of single-column ('h5_path') DataFrames; the last chunk may
    be shorter than ``chunk_size``.
    """
    cols = [ "h5", "image" ]
    # raw string: "\s" is an invalid escape in a normal string literal;
    # dtype=str keeps event ids verbatim (no int/float reformatting)
    df = pd.read_csv( lst, sep=r"\s//", engine="python", names=cols, dtype=str )

    # re-add // to image column; keep the bare path when there is no event id
    df[ "h5_path" ] = df.h5.where( df.image.isna(), df.h5 + " //" + df.image )
    df = df[ [ "h5_path" ] ]

    return [ df[ i:i + chunk_size ] for i in range( 0, len( df ), chunk_size ) ]


def write_crystfel_run( proc_dir, name, chunk, chunk_lst_file, geom_file, cell_file, threshold ):
    """
    crystfel run file - spot-finding and indexing parameters may need some editing
    only change from inside the quote ("")

    Writes an executable shell script ``<proc_dir>/<name>_<chunk>.sh`` that
    loads CrystFEL and runs indexamajig on ``chunk_lst_file``.

    Returns (cryst_run_file, stream_file) where stream_file is the *relative*
    --output name; it is written in the job's working directory.
    """
    stream_file = "{0}_{1}.stream".format( name, chunk )
    cryst_run_file = "{0}/{1}_{2}.sh".format( proc_dir, name, chunk )

    script_lines = [
        "#!/bin/sh\n\n",
        "module purge\n",
        "module use MX unstable\n",
        "module load crystfel/0.10.2-rhel8\n",
        "indexamajig -i {0} \\\n".format( chunk_lst_file ),
        " --output={0} \\\n".format( stream_file ),
        " --geometry={0} \\\n".format( geom_file ),
        " --pdb={0} \\\n".format( cell_file ),
        " --push-res=0.5 \\\n",
        " --indexing=xgandalf-latt-cell \\\n",
        " --peaks=peakfinder8 \\\n",
        " --integration=rings-nocen-nograd \\\n",
        " --tolerance=10.0,10.0,10.0,2,3,2 \\\n",
        " --threshold={0} \\\n".format( threshold ),
        " --min-snr=5 \\\n",
        " --int-radius=4,6,7 \\\n",
        # keep in step with --cpus-per-task in submit_job
        " -j 32 \\\n",
        " --multi \\\n",
        " --check-peaks \\\n",
        " --retry \\\n",
        " --max-res=3000 \\\n",
        " --min-pix-count=2 \\\n",
        " --local-bg-radius=4 \\\n",
        " --min-res=85\n",
    ]
    with open( cryst_run_file, "w" ) as run_sh:
        run_sh.writelines( script_lines )

    # make file executable (equivalent of chmod +x)
    os.chmod( cryst_run_file, os.stat( cryst_run_file ).st_mode | 0o111 )

    return cryst_run_file, stream_file


def make_process_dir( proc_dir ):
    """Create ``proc_dir`` (and parents); an existing directory is not an error.

    Any other OSError (including a non-directory already at that path) is
    logged and re-raised.
    """
    try:
        os.makedirs( proc_dir, exist_ok=True )
    except OSError:
        logger.debug( "making directory error" )
        raise


def submit_job( job_file, reservation ):
    """Submit ``job_file`` with sbatch to the 'hour' partition; return the job id.

    If ``reservation`` is truthy the job is placed in that ra reservation.
    If sbatch fails the script exits with status 1.
    """
    submit_cmd = [ "sbatch", "-p", "hour" ]
    if reservation:
        print( "using a ra beamtime reservation = {0}".format( reservation ) )
        logger.info( "using ra reservation to process data = {0}".format( reservation ) )
        submit_cmd.append( "--reservation={0}".format( reservation ) )
    # 32 cpus to match "-j 32" in the indexamajig command
    submit_cmd += [ "--cpus-per-task=32", "--", job_file ]
    logger.info( "using slurm command = {0}".format( submit_cmd ) )

    try:
        job_output = subprocess.check_output( submit_cmd )
    except subprocess.CalledProcessError as e:
        logger.error( "sbatch failed: {0}".format( e ) )
        print( "please give the correct ra reservation or remove the -v from the arguements" )
        # non-zero status so callers/shells can see the submission failed
        sys.exit( 1 )
    logger.info( "submited job = {0}".format( job_output ) )

    # scrub job id from - example Submitted batch job 742403
    job_id = JOB_ID_PATTERN.search( job_output.decode().strip() ).group( 1 )
    return int( job_id )


def _job_in_queue( job_id ):
    """Return True while squeue still reports ``job_id`` as queued/running."""
    status_cmd = [ "squeue", "-h", "-j", str( job_id ) ]
    result = subprocess.run( status_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE )
    if result.returncode != 0:
        # once slurm has purged a finished job, squeue rejects its id
        if b"Invalid job id" in result.stderr:
            return False
        # any other failure: assume still running and poll again later
        logger.debug( "squeue error for job {0}: {1}".format( job_id, result.stderr ) )
        return True
    return bool( result.stdout.strip() )


def wait_for_jobs( job_ids, total_jobs ):
    """Block until none of ``job_ids`` is listed by squeue, showing a progress bar.

    Polls every 2 s. The caller's ``job_ids`` collection is not modified.
    """
    pending = set( job_ids )
    with tqdm( total=total_jobs, desc="Jobs Completed", unit="job" ) as pbar:
        while pending:
            completed_jobs = { job_id for job_id in pending if not _job_in_queue( job_id ) }
            pbar.update( len( completed_jobs ) )
            pending -= completed_jobs
            time.sleep( 2 )


def run_splits( list_df, cwd, name, geom_file, cell_file, threshold, reservation ):
    """Write and submit one CrystFEL job per chunk of ``list_df``.

    Each chunk gets its own directory ``<cwd>/<name>/<name>_<chunk>`` holding
    the chunk .lst, the run script and (once the job runs) the stream file.

    Returns (submitted_job_ids, number_of_chunks, stream_paths).
    The working directory is always restored to ``cwd``.
    """
    submitted_job_ids = set()
    stream_lst = []

    for chunk, chunk_lst in enumerate( list_df ):
        logger.info( "chunk {0} = {1} images".format( chunk, len( chunk_lst ) ) )

        proc_dir = "{0}/{1}/{1}_{2}".format( cwd, name, chunk )
        make_process_dir( proc_dir )

        # sbatch runs the job in the submitting directory by default, so the
        # relative --output stream file is written into proc_dir
        os.chdir( proc_dir )
        try:
            chunk_lst_file = "{0}/{1}_{2}.lst".format( proc_dir, name, chunk )
            chunk_lst.to_csv( chunk_lst_file, index=False, header=False )

            cryst_run_file, stream_file = write_crystfel_run(
                proc_dir, name, chunk, chunk_lst_file, geom_file, cell_file, threshold
            )
            stream_lst.append( "{0}/{1}".format( proc_dir, stream_file ) )

            submitted_job_ids.add( submit_job( cryst_run_file, reservation ) )
        finally:
            # move back to top dir even if writing/submitting failed
            os.chdir( cwd )

    return submitted_job_ids, len( stream_lst ), stream_lst


def _concatenate_streams( stream_lst, output_file ):
    """Copy the per-chunk stream files, in order, into ``output_file``.

    The output is truncated first so that re-running a job does not append a
    second copy of every stream (which would double-count images and crystals).
    Missing chunk streams (e.g. failed jobs) are skipped with a warning.
    """
    try:
        with open( output_file, "w" ) as output:
            for file_name in stream_lst:
                try:
                    with open( file_name, "r" ) as input_file:
                        # copy in blocks instead of reading whole stream files into memory
                        shutil.copyfileobj( input_file, output )
                except FileNotFoundError:
                    logger.warning( f"File {file_name} not found. Skipping." )
                    print( f"File {file_name} not found. Skipping." )
    except IOError as e:
        logger.error( f"An error occurred while appending files: {e}" )


def _summary_lines( df, xtals, chunks ):
    """Build the statistics lines that are both logged and printed by main."""

    def mean_std( column, scale=1 ):
        # scale=10 converts stream cell lengths from nm to A
        series = df[ column ]
        return round( series.mean()*scale, 2 ), round( series.std()*scale, 2 )

    # xtals counts crystals, so with --multi the rate can exceed 100 %;
    # an empty stream (no chunks) reports 0 instead of dividing by zero
    index_rate = round( xtals/chunks*100, 2 ) if chunks else 0.0
    mean_res, std_res = mean_std( "resolution" )
    median_res = df.resolution.median()
    mean_obs, std_obs = mean_std( "obs" )
    mean_a, std_a = mean_std( "a", 10 )
    mean_b, std_b = mean_std( "b", 10 )
    mean_c, std_c = mean_std( "c", 10 )
    mean_alpha, std_alpha = mean_std( "alpha" )
    mean_beta, std_beta = mean_std( "beta" )
    mean_gamma, std_gamma = mean_std( "gamma" )

    return [
        "images = {0}".format( chunks ),
        "crystals = {0}".format( xtals ),
        "indexing rate = {0} %".format( index_rate ),
        "mean resolution = {0} +/- {1} A".format( mean_res, std_res ),
        "median resolution = {0} A".format( median_res ),
        "mean observations = {0} +/- {1}".format( mean_obs, std_obs ),
        "mean a = {0} +/- {1} A".format( mean_a, std_a ),
        "mean b = {0} +/- {1} A".format( mean_b, std_b ),
        "mean c = {0} +/- {1} A".format( mean_c, std_c ),
        "mean alpha = {0} +/- {1} deg".format( mean_alpha, std_alpha ),
        "mean beta = {0} +/- {1} deg".format( mean_beta, std_beta ),
        "mean gamma = {0} +/- {1} deg".format( mean_gamma, std_gamma ),
    ]


def main( cwd, name, lst, chunk_size, geom_file, cell_file, threshold, reservation ):
    """Split ``lst``, run CrystFEL per chunk on slurm, then merge and summarise.

    Writes ``<name>.stream`` (relative to the current directory, which
    run_splits restores to ``cwd``) and logs/prints indexing statistics.
    """
    print( "reading SwissFEL lst file" )
    print( "creating {0} image chunks of lst".format( chunk_size ) )
    list_df = h5_split( lst, chunk_size )
    print( "DONE" )

    # run crystfel runs on individual splits
    print( "submitting jobs to cluster" )
    submitted_job_ids, chunk, stream_lst = run_splits(
        list_df, cwd, name, geom_file, cell_file, threshold, reservation
    )

    # monitor progress of jobs (fixed 30 s grace period before the first poll)
    time.sleep( 30 )
    wait_for_jobs( submitted_job_ids, chunk )
    print( "done" )

    # make composite .stream file
    output_file = "{0}.stream".format( name )
    print( "concatenating .streams from separate runs." )
    _concatenate_streams( stream_lst, output_file )
    print( "done" )

    df, xtals, chunks = calculate_stats( output_file )
    summary = _summary_lines( df, xtals, chunks )

    for text in summary:
        logger.info( text )

    print( "printing stats" )
    for text in summary:
        print( text )


def _str2bool( value ):
    """argparse type: parse true/false style strings (``type=bool`` treats "False" as True)."""
    if isinstance( value, bool ):
        return value
    lowered = value.strip().lower()
    if lowered in ( "1", "true", "t", "yes", "y", "on" ):
        return True
    if lowered in ( "0", "false", "f", "no", "n", "off", "" ):
        return False
    raise argparse.ArgumentTypeError( "expected a boolean value, got {0!r}".format( value ) )


def _positive_int( value ):
    """argparse type: an integer >= 1 (a chunk size of 0 or less cannot split the list)."""
    number = int( value )
    if number < 1:
        raise argparse.ArgumentTypeError( "chunk size must be a positive integer, got {0}".format( value ) )
    return number


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-l",
        "--lst_file",
        help="file from SwissFEL output to be processed quickly. Requried.",
        type=os.path.abspath,
        required=True
    )
    parser.add_argument(
        "-k",
        "--chunk_size",
        help="how big should each chunk be? - the bigger the chunk, the fewer jobs, the slower it will be. Default = 500.",
        type=_positive_int,
        default=500
    )
    parser.add_argument(
        "-g",
        "--geom_file",
        help="path to geom file to be used in the refinement. Required.",
        type=os.path.abspath,
        required=True
    )
    parser.add_argument(
        "-c",
        "--cell_file",
        help="path to cell file of the crystals used in the refinement. Required.",
        type=os.path.abspath,
        required=True
    )
    parser.add_argument(
        "-n",
        "--job_name",
        help="the name of the job to be done. Default = 'split_###'",
        type=str,
        default="split"
    )
    parser.add_argument(
        "-v",
        "--reservation",
        help="reservation name for ra cluster. Usually along the lines of P11111_2024-12-10",
        type=str,
        default=None
    )
    parser.add_argument(
        "-p",
        "--photons_or_energy",
        help="determines the threshold to use for CrystFEL. Photons counts have always been used in Cristallina and are now used on Alvra from 01.11.2024. Please use 'energy' for Alvra before this.",
        type=str,
        choices=sorted( THRESHOLDS ),
        default="photons"
    )
    parser.add_argument(
        "-d",
        "--debug",
        help="output debug to terminal.",
        # accepts "-d", "-d True" and "-d False" (type=bool made "-d False" True)
        type=_str2bool,
        nargs="?",
        const=True,
        default=False
    )
    args = parser.parse_args()

    # set current working directory
    cwd = os.getcwd()

    # set loguru - drop the default terminal sink unless debugging
    if not args.debug:
        logger.remove()
    logfile = "{0}.log".format( args.job_name )
    logger.add( logfile, format="{message}", level="INFO" )

    # log geometry file
    with open( args.geom_file, "r" ) as geom:
        logger.info( geom.read() )

    # set threshold based on detector
    threshold = THRESHOLDS[ args.photons_or_energy ]

    main( cwd, args.job_name, args.lst_file, args.chunk_size, args.geom_file, args.cell_file, threshold, args.reservation )