#!/usr/bin/env python3

# author J.Beale

"""
# aim
randomly select a series of crystals from a stream file and
then compile them into the correctly formatted .stream

# usage
python stream_random.py -s <input stream file>

NOTE(review): the original usage text also listed -o (output file names),
-n (sample size) and -r (number of repeat random samples). The argument
parser in this file only defines -s/--stream, and the code here only parses
the stream and prints counts - confirm whether sampling/writing lives
elsewhere or is still to be added.

# output
.stream file with random sample of xtals
"""

# modules
import re
import argparse
import pandas as pd
import numpy as np
import os


def scrub_cells( stream ):
    """
    Find every unit-cell line in a block of stream text.

    example line -
    Cell parameters 7.71784 7.78870 3.75250 nm, 90.19135 90.77553 90.19243 deg

    stream: str holding stream-file text
    returns: ( cell_lst, xtals ) where cell_lst is a list of 6-tuples of
             strings (a, b, c, alpha, beta, gamma) and xtals is len( cell_lst )
    """
    # re.findall returns an empty list when nothing matches, so no
    # exception handling is needed for the "no cells" case
    pattern = r"Cell\sparameters\s(\d+\.\d+)\s(\d+\.\d+)\s(\d+\.\d+)\snm,\s(\d+\.\d+)\s(\d+\.\d+)\s(\d+\.\d+)\sdeg"
    cell_lst = re.findall( pattern, stream )
    return cell_lst, len( cell_lst )


def extract_chunks( input_file ):
    """
    Split a stream file into its chunks.

    input_file: path to the stream file
    returns: DataFrame indexed by image_no (sorted ascending) with columns
             "chunks" (list of the chunk's lines) and
             "hit" (True when the chunk has a "Cell parameters" line)
    raises: ValueError if a chunk has no "Event: //<number>" line
    """
    event_pattern = re.compile( r"Event: //(\d+)" )

    # one record per chunk so the three columns can never get out of step
    records = []
    chunk_lines = None      # None means "not currently inside a chunk"
    image = None
    hit = False

    with open( input_file, 'r' ) as f:
        for line in f:

            # Check for the start condition
            if line.startswith( '----- Begin chunk -----' ):
                chunk_lines = []
                image = None
                hit = False

            # ignore everything outside a chunk (e.g. the stream preamble)
            if chunk_lines is None:
                continue

            chunk_lines.append( line )

            if line.startswith( "Event:" ):
                # find image_no
                found = event_pattern.search( line )
                if found:
                    image = int( found.group( 1 ) )
            elif line.startswith( "Cell parameters" ):
                # is there a hit in chunk
                hit = True
            elif line.startswith( '----- End chunk -----' ):
                if image is None:
                    raise ValueError(
                        "chunk {0} in {1} has no 'Event: //<number>' line".format(
                            len( records ) + 1, input_file
                        )
                    )
                records.append( ( chunk_lines, image, hit ) )
                chunk_lines = None      # Stop collecting lines

    chunk_df = pd.DataFrame( records, columns=[ "chunks", "image_no", "hit" ] )

    # sort values and set image_no as index
    chunk_df = chunk_df.sort_values( "image_no" )
    chunk_df = chunk_df.set_index( "image_no" )

    return chunk_df


def extract_xtals( chunk ):
    """
    Pull the crystal blocks out of one chunk.

    chunk: iterable of the chunk's lines
    returns: list with one entry per crystal, each entry being the list of
             lines from "--- Begin crystal" to "--- End crystal" inclusive
    """
    xtals = []
    xtal_lines = None       # None means "not currently inside a crystal"
    for line in chunk:

        # Check for the xtals start condition
        if line.startswith( '--- Begin crystal' ):
            xtal_lines = []

        if xtal_lines is None:
            continue

        xtal_lines.append( line )

        # prefix match only - the line may end in "\n", "\r\n" or nothing
        if line.startswith( '--- End crystal' ):
            xtals.append( xtal_lines )
            xtal_lines = None       # Stop collecting lines

    return xtals


def extract_header( chunk ):
    """
    Pull the chunk header out of one chunk.

    chunk: iterable of the chunk's lines
    returns: list of headers found (one per "Begin chunk" marker), each being
             the list of lines from "----- Begin chunk -----" to the
             "End of peak list" line inclusive
    """
    header = []
    header_lines = None     # None means "not currently inside a header"
    for line in chunk:

        # Check for the header start condition
        if line.startswith( '----- Begin chunk -----' ):
            header_lines = []

        if header_lines is None:
            continue

        header_lines.append( line )

        if line.startswith( 'End of peak list' ):
            header.append( header_lines )
            header_lines = None     # Stop collecting lines

    return header


def main( input_file ):
    """
    Read a stream file and tabulate every crystal with its chunk header.

    input_file: path to the stream file
    returns: DataFrame with one row per crystal and columns
             "header" (chunk header lines), "xtals" (crystal lines) and
             "image_no" (event number of the chunk the crystal came from)
    """
    # extract chunks
    print( "finding chucks" )
    chunk_df = extract_chunks( input_file )
    # display no. of chunks
    print( "found {0} chunks".format( len(chunk_df) ) )
    # NOTE: "hit" is per chunk, so this is the number of chunks that
    # contain at least one crystal
    print( "found {0} crystals".format( chunk_df[ "hit" ].sum() ) )
    print( "done" )

    # extract xtals
    print( "geting xtal data from from chunks" )
    rows = []
    counter = 0
    next_report = 1000
    # image_no is the index of chunk_df, so it arrives as the iterrows key
    for image_no, row in chunk_df.iterrows():

        if not row[ "hit" ]:
            continue

        chunk = row[ "chunks" ]

        # find xtals and header
        header = extract_header( chunk )
        xtals = extract_xtals( chunk )

        # every xtal in the chunk shares the same header lines
        header_lines = header[0] if header else []
        rows.extend( ( header_lines, xtal, image_no ) for xtal in xtals )

        # add count and print each time another 1000 xtals has been passed
        counter = counter + len( xtals )
        if counter >= next_report:
            print( counter, end='\r' )
            next_report = ( counter // 1000 + 1 ) * 1000

    # build the frame once rather than concatenating inside the loop
    xtal_df = pd.DataFrame( rows, columns=[ "header", "xtals", "image_no" ] )
    print( "done" )

    return xtal_df


def list_of_floats(arg):
    """
    argparse helper - turn a comma separated string, e.g. "1.5,2,3",
    into a list of floats
    """
    return list(map(float, arg.split(',')))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-s",
        "--stream",
        help="input stream file",
        required=True,
        type=os.path.abspath
    )
    args = parser.parse_args()
    # run main
    main( args.stream )