script to select images by resolution
This commit is contained in:
329
reduction_tools/stream_select_res.py
Normal file
329
reduction_tools/stream_select_res.py
Normal file
@@ -0,0 +1,329 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# author J.Beale
|
||||
|
||||
"""
|
||||
# aim
|
||||
analyses and selects crystals based on their reported resolution from crystfel
|
||||
|
||||
# usage
|
||||
python stream_random.py -s <path to stream>
|
||||
-o output file names
|
||||
-p plots histogram of images resolution
|
||||
-r selects all images with higher resolution than value
|
||||
|
||||
# output
|
||||
either
|
||||
- histogram of images by resolution
|
||||
- .stream file of selected images
|
||||
"""
|
||||
|
||||
# modules
|
||||
import re
|
||||
import argparse
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import os
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def extract_chunks( input_file ):
|
||||
|
||||
# setup
|
||||
chunk_df = pd.DataFrame()
|
||||
image_no = []
|
||||
chunks = []
|
||||
hits = []
|
||||
collect_lines = False
|
||||
# Open the input file for reading
|
||||
with open(input_file, 'r') as f:
|
||||
for line in f:
|
||||
|
||||
# Check for the start condition
|
||||
if line.startswith('----- Begin chunk -----'):
|
||||
hit = False
|
||||
collect_lines = True
|
||||
chunk_lines = []
|
||||
if collect_lines:
|
||||
chunk_lines.append(line)
|
||||
|
||||
# find image_no
|
||||
if line.startswith( "Event:" ):
|
||||
image_search = re.findall( r"Event: //(\d+)", line )
|
||||
image = int(image_search[0])
|
||||
image_no.append( image )
|
||||
|
||||
# is there a hit in chunk
|
||||
if line.startswith( "Cell parameters" ):
|
||||
hit = True
|
||||
|
||||
if line.startswith('----- End chunk -----'):
|
||||
collect_lines = False # Stop collecting lines
|
||||
chunks.append( chunk_lines )
|
||||
hits.append( hit )
|
||||
|
||||
chunk_df[ "chunks" ] = chunks
|
||||
chunk_df[ "image_no" ] = image_no
|
||||
chunk_df[ "hit" ] = hits
|
||||
|
||||
return chunk_df
|
||||
|
||||
def scrub_res( line ):
|
||||
|
||||
# get resolution
|
||||
try:
|
||||
pattern = r"diffraction_resolution_limit\s=\s\d\.\d+\snm\^-1\sor\s(\d+\.\d+)\sA"
|
||||
res = re.search( pattern, line ).group(1)
|
||||
except AttributeError as e:
|
||||
res = np.nan
|
||||
|
||||
return float( res )
|
||||
|
||||
def extract_xtals( chunk ):
|
||||
|
||||
# setup
|
||||
xtals = []
|
||||
resolutions = []
|
||||
collect_crystal_lines = False
|
||||
# Open the input file for reading
|
||||
for line in chunk:
|
||||
|
||||
# Check for the xtals start condition
|
||||
if line.startswith('--- Begin crystal'):
|
||||
collect_crystal_lines = True
|
||||
xtal_lines = []
|
||||
if collect_crystal_lines:
|
||||
xtal_lines.append(line)
|
||||
if line.startswith('--- End crystal\n'):
|
||||
collect_crystal_lines = False # Stop collecting lines
|
||||
xtals.append( xtal_lines )
|
||||
if line.startswith( "diffraction_resolution_limit" ):
|
||||
res = scrub_res( line )
|
||||
resolutions.append( res )
|
||||
|
||||
return xtals, resolutions
|
||||
|
||||
def extract_header( chunk ):
|
||||
|
||||
# setup
|
||||
header = []
|
||||
collect_header_lines = False
|
||||
# Open the input file for reading
|
||||
for line in chunk:
|
||||
|
||||
# Check for the xtals start condition
|
||||
if line.startswith('----- Begin chunk -----'):
|
||||
collect_header_lines = True
|
||||
header_lines = []
|
||||
if collect_header_lines:
|
||||
header_lines.append(line)
|
||||
if line.startswith('End of peak list'):
|
||||
collect_header_lines = False # Stop collecting lines
|
||||
header.append( header_lines )
|
||||
|
||||
return header
|
||||
|
||||
def get_header( header, input_file ):
|
||||
|
||||
if header == "geom":
|
||||
start_keyword = "----- Begin geometry file -----"
|
||||
end_keyword = "----- End geometry file -----"
|
||||
if header == "cell":
|
||||
start_keyword = "----- Begin unit cell -----"
|
||||
end_keyword = "----- End unit cell -----"
|
||||
|
||||
# setup
|
||||
collect_lines = False
|
||||
headers = []
|
||||
|
||||
# Open the input file for reading
|
||||
with open(input_file, 'r') as f:
|
||||
for line in f:
|
||||
# Check for the start condition
|
||||
if line.strip() == start_keyword:
|
||||
collect_lines = True
|
||||
headers_lines = []
|
||||
# Collect lines between start and end conditions
|
||||
if collect_lines:
|
||||
headers_lines.append(line)
|
||||
# Check for the end condition
|
||||
if line.strip() == end_keyword:
|
||||
collect_lines = False # Stop collecting lines
|
||||
headers.append(headers_lines)
|
||||
|
||||
return headers[0]
|
||||
|
||||
def write_to_file( geom, cell, chunk_header, crystals, output_file ):
|
||||
|
||||
# Write sections with matching cell parameters to the output file
|
||||
with open(output_file, 'w') as out_file:
|
||||
out_file.write('CrystFEL stream format 2.3\n')
|
||||
out_file.write('Generated by CrystFEL 0.10.2\n')
|
||||
out_file.writelines(geom)
|
||||
out_file.writelines(cell)
|
||||
for crystal, header in zip( crystals, chunk_header ):
|
||||
out_file.writelines( header )
|
||||
out_file.writelines( crystal )
|
||||
out_file.writelines( "----- End chunk -----\n" )
|
||||
|
||||
def sort_xtals( chunk_df ):
|
||||
|
||||
# extract xtals
|
||||
xtal_df = pd.DataFrame()
|
||||
counter = 0
|
||||
for index, row in chunk_df.iterrows():
|
||||
|
||||
chunk, hit, image_no = row[ "chunks" ], row[ "hit" ], row[ "image_no" ]
|
||||
|
||||
if hit:
|
||||
|
||||
# find xtals and header
|
||||
header = extract_header( chunk )
|
||||
xtals, resolutions = extract_xtals( chunk )
|
||||
|
||||
# make header same length as xtals
|
||||
header = header*len(xtals)
|
||||
|
||||
# concat results
|
||||
xtal_df_1 = pd.DataFrame()
|
||||
xtal_df_1[ "header" ] = header
|
||||
xtal_df_1[ "xtals" ] = xtals
|
||||
xtal_df_1[ "image_no" ] = image_no
|
||||
xtal_df_1[ "resolution" ] = resolutions
|
||||
xtal_df = pd.concat( ( xtal_df, xtal_df_1 ) )
|
||||
|
||||
# add count and print every 1000s
|
||||
counter = counter + len(xtals)
|
||||
if counter % 1000 == 0:
|
||||
print( counter, end='\r' )
|
||||
print( "done" )
|
||||
|
||||
# sort by image no and reindex
|
||||
xtal_df = xtal_df.sort_values( by=[ "image_no" ] )
|
||||
xtal_df = xtal_df.reset_index( drop=True )
|
||||
|
||||
return xtal_df
|
||||
|
||||
def plot_res_histogram( res_df, res_median, res_q1, res_q3 ):
|
||||
|
||||
# calculate relative numbers of bins
|
||||
bin_range = 0.25
|
||||
res_min = res_df.min().values[0]
|
||||
res_max = res_df.max().values[0]
|
||||
|
||||
q1_bins = round( ( res_q1 - res_min )/bin_range )
|
||||
q2_bins = round( ( res_median - res_q1 )/bin_range )
|
||||
q3_bins = round( ( res_q3 - res_median )/bin_range )
|
||||
q4_bins = round( ( res_max - res_q3 )/bin_range )
|
||||
|
||||
# cut data by quantile
|
||||
df_q1 = res_df[ res_df.resolution <= res_q1 ]
|
||||
df_q2 = res_df[ ( res_df.resolution > res_q1 ) & ( res_df.resolution <= res_median ) ]
|
||||
df_q3 = res_df[ ( res_df.resolution > res_median ) & ( res_df.resolution <= res_q3 ) ]
|
||||
df_q4 = res_df[ res_df.resolution > res_q3 ]
|
||||
|
||||
# plot histogram of resolution
|
||||
fig, axs = plt.subplots()
|
||||
|
||||
axs.hist( df_q1, bins=q1_bins, rwidth=1, stacked=True, color="blue", label="x<q1" )
|
||||
axs.hist( df_q2, bins=q2_bins, rwidth=1, stacked=True, color="red", label="q1<x<q2")
|
||||
axs.hist( df_q3, bins=q3_bins, rwidth=1, stacked=True, color="green", label="q2<x<q3" )
|
||||
axs.hist( df_q4, bins=q4_bins, rwidth=1, stacked=True, color="purple", label="q3<x<q4" )
|
||||
|
||||
axs.axvline( x=res_median, color="black", linestyle="dashed", label="median = {0}".format( res_median ) )
|
||||
axs.set_xlabel( "resolution" )
|
||||
axs.set_ylabel( "frequency" )
|
||||
axs.legend()
|
||||
bin_size = round( ( len( df_q1 ) + len( df_q2 ) + len( df_q3 ) + len( df_q4 ) )/4 )
|
||||
axs.text( 0.98, 0.7,
|
||||
"images per quartile = {0}".format( bin_size ),
|
||||
ha="right",
|
||||
va="top",
|
||||
transform=axs.transAxes
|
||||
)
|
||||
|
||||
fig.tight_layout()
|
||||
plt.show()
|
||||
|
||||
def main( input_file, output, plotter, resolution ):
|
||||
|
||||
# get geom and cell file headers
|
||||
print( "getting header info from .stream file" )
|
||||
geom = get_header( "geom", input_file )
|
||||
cell = get_header( "cell", input_file )
|
||||
print( "done" )
|
||||
|
||||
# extract chunks
|
||||
print( "finding chucks" )
|
||||
chunk_df = extract_chunks( input_file )
|
||||
# display no. of chunks
|
||||
print( "found {0} chunks".format( len(chunk_df) ) )
|
||||
# remove rows without xtals
|
||||
chunk_df = chunk_df.loc[chunk_df.hit, :]
|
||||
print( "found {0} hits (not including multiples)".format( len(chunk_df) ) )
|
||||
print( "done" )
|
||||
|
||||
print( "sorting xtals from chunks" )
|
||||
xtal_df = sort_xtals( chunk_df )
|
||||
print( "done" )
|
||||
|
||||
print( "calculate stats" )
|
||||
res_df = xtal_df[ [ "resolution" ] ]
|
||||
res_median = res_df.median().values[0]
|
||||
res_q1 = res_df.quantile( 0.25 ).values[0]
|
||||
res_q3 = res_df.quantile( 0.75 ).values[0]
|
||||
print( "median resolution and range = {0} ({1}-{2})".format( res_median, res_q1, res_q3 ) )
|
||||
print( "done" )
|
||||
|
||||
if plotter == True:
|
||||
print( "plot image resolution histogram" )
|
||||
plot_res_histogram( res_df, res_median, res_q1, res_q3 )
|
||||
|
||||
if resolution:
|
||||
|
||||
print( "finding images with resolution greater than {0}".format( resolution ) )
|
||||
select_df = xtal_df[ xtal_df[ "resolution" ] <= resolution ]
|
||||
print( "done" )
|
||||
|
||||
print( "writing {0} to output file".format( len( select_df ) ) )
|
||||
crystals = select_df.xtals.to_list()
|
||||
chunk_header = select_df.header.to_list()
|
||||
output_file = "{0}.stream".format( output )
|
||||
write_to_file( geom, cell, chunk_header, crystals, output_file )
|
||||
print( "done" )
|
||||
|
||||
else:
|
||||
print( "no output. please use either the plot-histogram or resolution functions of the script" )
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--stream",
|
||||
help="input stream file",
|
||||
required=True,
|
||||
type=os.path.abspath
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output",
|
||||
help="output stream file name. '.stream will be added'",
|
||||
type=str,
|
||||
default="selected"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--plot_histogram",
|
||||
help="plots a histogram of the crystfel calculated resolutions for inspection",
|
||||
type=bool,
|
||||
default=False
|
||||
)
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--resolution",
|
||||
help="upper resolution limit. Will take all images below this.",
|
||||
type=float
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# run main
|
||||
main( args.stream, args.output, args.plot_histogram, args.resolution )
|
||||
Reference in New Issue
Block a user