diff --git a/reduction_tools/stream_random.py b/reduction_tools/stream_random.py index 2a77110..ca2f97e 100644 --- a/reduction_tools/stream_random.py +++ b/reduction_tools/stream_random.py @@ -24,21 +24,6 @@ import pandas as pd import numpy as np import os -def scrub_cells( stream ): - - # get uc values from stream file - # example - Cell parameters 7.71784 7.78870 3.75250 nm, 90.19135 90.77553 90.19243 deg - # scrub clen and return - else nan - try: - pattern = r"Cell\sparameters\s(\d+\.\d+)\s(\d+\.\d+)\s(\d+\.\d+)\snm,\s(\d+\.\d+)\s(\d+\.\d+)\s(\d+\.\d+)\sdeg" - cell_lst = re.findall( pattern, stream ) - xtals = len( cell_lst ) - if AttributeError: - return cell_lst, xtals - except AttributeError: - logger.debug( "scrub_cells error" ) - return np.nan - def extract_chunks( input_file ): # setup @@ -78,10 +63,6 @@ def extract_chunks( input_file ): chunk_df[ "image_no" ] = image_no chunk_df[ "hit" ] = hits - # sort values and set image_no as index - chunk_df = chunk_df.sort_values( "image_no" ) - chunk_df = chunk_df.set_index( "image_no" ) - return chunk_df def extract_xtals( chunk ): @@ -124,19 +105,69 @@ def extract_header( chunk ): return header +def get_header( header, input_file ): -def main( input_file ): + if header == "geom": + start_keyword = "----- Begin geometry file -----" + end_keyword = "----- End geometry file -----" + if header == "cell": + start_keyword = "----- Begin unit cell -----" + end_keyword = "----- End unit cell -----" + # setup + collect_lines = False + headers = [] + + # Open the input file for reading + with open(input_file, 'r') as f: + for line in f: + # Check for the start condition + if line.strip() == start_keyword: + collect_lines = True + headers_lines = [] + # Collect lines between start and end conditions + if collect_lines: + headers_lines.append(line) + # Check for the end condition + if line.strip() == end_keyword: + collect_lines = False # Stop collecting lines + headers.append(headers_lines) + + return headers[0] + +def write_to_file( geom, cell, chunk_header, crystals, output_file ): + + # Write sections with matching cell parameters to the output file + with open(output_file, 'w') as out_file: + out_file.write('CrystFEL stream format 2.3\n') + out_file.write('Generated by CrystFEL 0.10.2\n') + out_file.writelines(geom) + out_file.writelines(cell) + for crystal, header in zip( crystals, chunk_header ): + out_file.writelines( header ) + out_file.writelines( crystal ) + out_file.writelines( "----- End chunk -----\n" ) + +def main( input_file, samples, output, repeat ): + + # get geom and cell file headers + print( "getting header info from .stream file" ) + geom = get_header( "geom", input_file ) + cell = get_header( "cell", input_file ) + print( "done" ) + # extract chunks print( "finding chucks" ) chunk_df = extract_chunks( input_file ) # display no. of chunks print( "found {0} chunks".format( len(chunk_df) ) ) - print( "found {0} crystals".format( chunk_df.hits.sum() ) ) + # remove rows without xtals + chunk_df = chunk_df.loc[chunk_df.hit, :] + print( "found {0} hits (not including multiples)".format( len(chunk_df) ) ) print( "done" ) # extract xtals - print( "geting xtal data from from chunks" ) + print( "get xtals from chunks" ) xtal_df = pd.DataFrame() counter = 0 for index, row in chunk_df.iterrows(): @@ -165,6 +196,28 @@ def main( input_file ): print( counter, end='\r' ) print( "done" ) + # sort by image no and reindex + xtal_df = xtal_df.sort_values( by=[ "image_no" ] ) + xtal_df = xtal_df.reset_index( drop=True ) + + # randomly n number of sample of xtals + for sample in samples: + print( "taking {0} {1} sample".format( repeat, sample ) ) + for x in range( 0, repeat ): + + try: + sample_df = xtal_df.sample( sample ) + except ValueError: + print( "input image sample larger than number of hits. Sample should be less than {0}".format( len(xtal_df) ) ) + + # rebuild .stream from sample + print( "writing {0} to output file".format( x ) ) + crystals = sample_df.xtals.to_list() + chunk_header = sample_df.header.to_list() + output_file = "{0}_{1}_{2}.stream".format( output, sample ,x ) + write_to_file( geom, cell, chunk_header, crystals, output_file ) + print( "done {0}".format( x ) ) + print( "done" ) def list_of_floats(arg): return list(map(int, arg.split(','))) @@ -178,6 +231,31 @@ if __name__ == "__main__": required=True, type=os.path.abspath ) + parser.add_argument( + "-o", + "--output", + help="output stream file with sampled xtals", + type=str, + default="sample" + ) + parser.add_argument( + "-n", + "--sample", + help="size of sample to take from input.stream", + type=list_of_floats, + required=True + ) + parser.add_argument( + "-r", + "--repeat", + help="how many samples would you like?", + type=int + ) args = parser.parse_args() + # does if need to be run multiple times? + if args.repeat is None: + repeat = 1 + else: + repeat = args.repeat # run main - main( args.stream ) + main( args.stream, args.sample, args.output, repeat) \ No newline at end of file