from ijutils import *
import java.lang.reflect
import flanagan.complex.ComplexMatrix as ComplexMatrix
import flanagan.math.Matrix as Matrix
import flanagan.complex.Complex as Complex
import org.jtransforms.fft.DoubleFFT_2D as DoubleFFT_2D
import math
from startup import ScriptUtils
import ij.plugin.filter.PlugInFilterRunner as PlugInFilterRunner
import ij.plugin.filter.ExtendedPlugInFilter as ExtendedPlugInFilter
import ij.plugin.filter.ExtendedPlugInFilter as ExtendedPlugInFilter
import java.lang.Thread as Thread


def new_array(type, *dimensions):
    """Create a (possibly multi-dimensional) Java primitive array.

    type: primitive type code resolved by ScriptUtils.getPrimitiveType
          (this module uses 'd' for double).
    dimensions: length of each dimension, e.g. new_array('d', h, w) is
          the equivalent of Java's new double[h][w].
    """
    return java.lang.reflect.Array.newInstance(ScriptUtils.getPrimitiveType(type), *dimensions)


def load_stack(title, file_list, show=False):
    """Open every file of file_list (paths go through expand_path) and
    combine the images into one stack titled `title`; optionally show it."""
    ip_list = [open_image(expand_path(f)) for f in file_list]
    stack = create_stack(ip_list, title=title)
    if show:
        stack.show()
    return stack


def load_test_stack(title="Test", show=False, size=9):
    """Load `size` consecutive TestObjAligner images starting at index 40."""
    file_list = ["{images}/TestObjAligner/i210517_0" + str(index) + "#001.tif"
                 for index in range(40, 40 + size)]
    return load_stack(title, file_list, show)


def load_corr_stack(title="Corr", show=False, size=9):
    """Load `size` consecutive TestObjAligner_corr images starting at index 40.

    size defaults to 9, i.e. the indices 40..48 previously hard-coded here.
    """
    file_list = ["{images}/TestObjAligner_corr/i210517_0" + str(index) + "#001.tif"
                 for index in range(40, 40 + size)]
    return load_stack(title, file_list, show)


def complex_edge_filtering(imp, complex=True, g_sigma=3.0, g_resolution=1e-4, show=False, java_code=False):
    """Gaussian-blur every slice of imp, then edge-filter it.

    Parameters:
        imp: source ImagePlus; every stack slice is processed (on a float copy).
        complex: True -> two stacks, one per Sobel kernel (sobel_r / sobel_i),
            used downstream as real and imaginary parts.
            False -> one stack filtered with ImageProcessor.FIND_EDGES.
        g_sigma: Gaussian sigma, identical in x and y.
        g_resolution: accuracy argument passed to GaussianBlur.blurGaussian.
        show: display the result windows.
        java_code: delegate to the Align_ComplexEdgeFiltering Java plugin and
            return its `output` field instead.

    Returns:
        [imp_r, imp_i] ("EdgeReal_<title>", "EdgeImag_<title>"); imp_i is
        None when complex is False.
    """
    if java_code:
        get_context().getPluginManager().loadInitializePlugin("Align_ComplexEdgeFiltering.java")
        complex_edge_filter = get_context().getClassByName("Align_ComplexEdgeFiltering").newInstance()
        # Argument string: Gaussian blur radius, Complex (True) or Real (False), show dialog
        complex_edge_filter.setup(str(g_sigma) + "," + str(complex) + "," + str(show), imp)
        complex_edge_filter.run(imp.getProcessor())
        return complex_edge_filter.output

    gb = GaussianBlur()
    sobel_r = [1, 0, -1, 2, 0, -2, 1, 0, -1]
    sobel_i = [1, 2, 1, 0, 0, 0, -1, -2, -1]
    source = imp.getStack()
    n_slices = imp.getImageStackSize()

    imp_r = imp.createImagePlus()
    stack_r = ImageStack(imp.getWidth(), imp.getHeight())
    if complex:
        imp_i = imp.createImagePlus()
        stack_i = ImageStack(imp.getWidth(), imp.getHeight())
    else:
        imp_i = None

    for i in range(1, n_slices + 1):
        ip_r = source.getProcessor(i).duplicate().convertToFloat()
        gb.blurGaussian(ip_r, g_sigma, g_sigma, g_resolution)
        label = source.getSliceLabel(i)
        if complex:
            # Both Sobel responses start from the same blurred slice
            ip_i = ip_r.duplicate()
            ip_r.convolve3x3(sobel_r)
            ip_i.convolve3x3(sobel_i)
            stack_i.addSlice(label, ip_i)
        else:
            ip_r.filter(ImageProcessor.FIND_EDGES)
        stack_r.addSlice(label, ip_r)
        IJ.showProgress(i, n_slices)

    if complex:
        imp_i.setStack("EdgeImag_" + imp.getTitle(), stack_i)
        imp_i.resetDisplayRange()
        if show:
            imp_i.show()
            imp_i.updateAndDraw()

    imp_r.setStack("EdgeReal_" + imp.getTitle(), stack_r)
    imp_r.resetDisplayRange()
    if show:
        imp_r.show()
        imp_r.updateAndDraw()
    return [imp_r, imp_i]


class TranslationFilter(ExtendedPlugInFilter):
    """ImageJ ExtendedPlugInFilter that translates each slice by a shift.

    Before running it, set `imp` (source ImagePlus) and `shifts` (rows laid out
    as returned by calculate_shifts: column 2 = row shift, column 3 = column
    shift). When `shifts` has a different length than the number of slices,
    row 1 is applied to every slice. After PlugInFilterRunner completes,
    `output` holds the translated stack titled "REG_<title>".
    """

    def __init__(self):
        self.shifts = None
        # Any non-RGB type, whole stack, leave the source untouched, and get a
        # setup("final", ...) call once all slices are processed.
        self.flags = (self.DOES_ALL - self.DOES_RGB) | self.DOES_STACKS | self.NO_CHANGES | self.FINAL_PROCESSING
        self.imp = None
        self.output = None
        self.translated = None
        self.pifr = None
        self.nbslices = 0
        self.processed = 0

    def setup(self, arg, imp):
        """Return the filter flags; on the "final" call, publish the result stack."""
        if "final" == arg:
            self.output.setStack("REG_" + self.imp.getTitle(), self.translated)
            return self.DONE
        if self.imp is None:
            self.imp = imp
        return self.flags

    def showDialog(self, imp, command, pfr):
        """No dialog: remember the runner (needed for slice numbers) and return the flags."""
        self.pifr = pfr
        return self.flags

    def setNPasses(self, nPasses):
        """Called by ImageJ with the number of run(ip) calls (100% of the progress bar).

        Allocates the output ImagePlus and a placeholder stack of nPasses slices.
        """
        self.nbslices = nPasses
        self.output = self.imp.createImagePlus()
        self.translated = ImageStack(self.imp.getWidth(), self.imp.getHeight(), self.nbslices)

    def run(self, ip):
        """Translate one slice (bicubic) into the output stack.

        Called by ImageJ once per stack slice.
        """
        if Thread.currentThread().isInterrupted():
            return
        thisone = self.pifr.getSliceNumber()
        nip = ip.duplicate().convertToFloat()
        nip.setInterpolationMethod(ImageProcessor.BICUBIC)
        if len(self.shifts) != self.nbslices:
            # translate all the frames by the same shifts
            row = self.shifts[1]
        else:
            row = self.shifts[thisone - 1]
        xoff, yoff = row[3], row[2]
        nip.translate(xoff, yoff)
        lbl = self.imp.getStack().getSliceLabel(thisone)
        if lbl is None:
            lbl = "" + str(thisone)
        # Insert the result at its position, then drop the placeholder it displaced
        self.translated.addSlice(lbl, nip, thisone - 1)
        self.translated.deleteSlice(thisone + 1)
        self.processed += 1
        IJ.showProgress(self.processed, self.nbslices)


def translate(stack, shifts, show=False, java_code=False):
    """Translate every slice of `stack` by `shifts` and return the new ImagePlus.

    java_code=True uses the Align_TranslationFilter Java plugin; otherwise the
    TranslationFilter class above. Both are driven by PlugInFilterRunner.
    """
    WindowManager.setTempCurrentImage(stack)
    if java_code:
        get_context().getPluginManager().loadInitializePlugin("Align_TranslationFilter.java")
        translation_filter = get_context().getClassByName("Align_TranslationFilter").newInstance()
    else:
        translation_filter = TranslationFilter()
    translation_filter.imp = stack
    translation_filter.shifts = shifts
    # The runner executes the filter (setup/showDialog/setNPasses/run) on construction
    PlugInFilterRunner(translation_filter, "", "")
    ret = translation_filter.output
    if show:
        ret.show()
        ret.updateAndDraw()
    return ret


def _new_shifts_io():
    """Load the ShiftsIO Java plugin and return a fresh instance."""
    get_context().getPluginManager().loadInitializePlugin("ShiftsIO.java")
    return get_context().getClassByName("ShiftsIO").newInstance()


def load_shifts(filename):
    """Read shifts from `filename` (path expanded) with ShiftsIO ("directshifts")."""
    return _new_shifts_io().load(expand_path(filename), "directshifts")


def save_shifts(filename, shifts):
    """Write `shifts` to `filename` (path expanded) with ShiftsIO ("directshifts")."""
    _new_shifts_io().save(expand_path(filename), shifts, "directshifts")


def ip_to_fft_array_2d(ip):
    """Copy an ImageProcessor's pixels into a new double[h][w] array (row-major)."""
    pixels = ip.getPixels()
    w = ip.getWidth()
    h = ip.getHeight()
    data = new_array('d', h, w)
    for j in range(h):
        row = data[j]
        base = j * w
        for i in range(w):
            row[i] = pixels[base + i]
    return data


def ip_to_fft_complex_array_2d(ip_r, ip_i):
    """Interleave two same-sized processors into double[h][2*w] as (re, im) pairs,
    the layout expected by DoubleFFT_2D.complexForward."""
    pixels_r = ip_r.getPixels()
    pixels_i = ip_i.getPixels()
    w = ip_r.getWidth()
    h = ip_r.getHeight()
    data = new_array('d', h, 2 * w)
    for j in range(h):
        row = data[j]
        base = j * w
        for i in range(w):
            row[2 * i] = pixels_r[base + i]
            row[2 * i + 1] = pixels_i[base + i]
    return data


def fft_array_2d_to_complex_matrix(data, h, w):
    """Unpack a DoubleFFT_2D.realForward result into a full h x w ComplexMatrix.

    realForward stores half of the (Hermitian-symmetric) spectrum in place;
    the other half is rebuilt by conjugate symmetry.
    """
    m = ComplexMatrix(h, w)
    half_h = h // 2
    half_w = w // 2
    for j in range(h):
        # i must reach w/2 inclusive: the Nyquist entries (0, w/2) and
        # (h/2, w/2) are only written when i == w/2.
        for i in range(half_w + 1):
            if j > 0 and 0 < i < half_w:
                m.setElement(j, i, Complex(data[j][2 * i], data[j][2 * i + 1]))
                m.setElement(h - j, w - i, Complex(data[j][2 * i], -data[j][2 * i + 1]))
            if j == 0 and 0 < i < half_w:
                m.setElement(0, i, Complex(data[0][2 * i], data[0][2 * i + 1]))
                m.setElement(0, w - i, Complex(data[0][2 * i], -data[0][2 * i + 1]))
            if i == 0 and 0 < j < half_h:
                m.setElement(j, 0, Complex(data[j][0], data[j][1]))
                m.setElement(h - j, 0, Complex(data[j][0], -data[j][1]))
                m.setElement(j, half_w, Complex(data[h - j][1], -data[h - j][0]))
                m.setElement(h - j, half_w, Complex(data[h - j][1], data[h - j][0]))
            if j == 0 and i == 0:
                m.setElement(0, 0, Complex(data[0][0], 0))
            if j == 0 and i == half_w:
                m.setElement(0, half_w, Complex(data[0][1], 0))
            if j == half_h and i == 0:
                m.setElement(half_h, 0, Complex(data[half_h][0], 0))
            if j == half_h and i == half_w:
                m.setElement(half_h, half_w, Complex(data[half_h][1], 0))
    return m


def fft_complex_array_2d_to_complex_matrix(data, h, w):
    """Convert an interleaved double[h][2*w] (re, im) array into an h x w ComplexMatrix."""
    m = ComplexMatrix(h, w)
    for j in range(h):
        row = data[j]
        for i in range(w):
            m.setElement(j, i, Complex(row[2 * i], row[2 * i + 1]))
    return m


def complex_matrix_to_fft_array_2d(m):
    """Pack an h x w ComplexMatrix into the DoubleFFT_2D.realInverse input layout.

    Inverse of fft_array_2d_to_complex_matrix; only the non-redundant half of
    the spectrum is read.
    """
    w = m.getNcol()
    h = m.getNrow()
    half_h = h // 2
    half_w = w // 2
    data = new_array('d', h, w)
    for j in range(h):
        for i in range(half_w + 1):
            if j > 0 and 0 < i < half_w:
                data[j][2 * i] = m.getElementReference(j, i).getReal()
                data[j][2 * i + 1] = m.getElementReference(j, i).getImag()
            if j == 0 and 0 < i < half_w:
                data[0][2 * i] = m.getElementReference(0, i).getReal()
                data[0][2 * i + 1] = m.getElementReference(0, i).getImag()
            if i == 0 and 0 < j < half_h:
                data[j][0] = m.getElementReference(j, 0).getReal()
                data[j][1] = m.getElementReference(j, 0).getImag()
                data[h - j][1] = m.getElementReference(j, half_w).getReal()
                data[h - j][0] = m.getElementReference(h - j, half_w).getImag()
            if j == 0 and i == 0:
                data[0][0] = m.getElementReference(0, 0).getReal()
            if j == 0 and i == half_w:
                data[0][1] = m.getElementReference(0, half_w).getReal()
            if j == half_h and i == 0:
                data[half_h][0] = m.getElementReference(half_h, 0).getReal()
            if j == half_h and i == half_w:
                data[half_h][1] = m.getElementReference(half_h, half_w).getReal()
    return data


def complex_matrix_to_real_array_2d(m):
    """Split a ComplexMatrix into data[0][j][i] (real parts) and data[1][j][i] (imaginary parts)."""
    w = m.getNcol()
    h = m.getNrow()
    data = new_array('d', 2, h, w)
    for j in range(h):
        for i in range(w):
            c = m.getElementReference(j, i)
            data[0][j][i] = c.getReal()
            data[1][j][i] = c.getImag()
    return data


def _crop_float(imp, index, roi):
    """Return slice `index` of imp cropped to roi and converted to float.

    Note: sets roi on the stack's processor as a side effect (as before).
    """
    ip = imp.getStack().getProcessor(index)
    ip.setRoi(roi)
    return ip.crop().convertToFloat()


def compute_fft(imp_r, imp_i, roi):
    """FFT the roi of every slice; returns a Java ComplexMatrix[] of length slices.

    With imp_i None a real FFT of imp_r is used; otherwise imp_r/imp_i are
    combined as real/imaginary parts and a complex FFT is used.
    """
    slices = imp_r.getStackSize()
    ffts = java.lang.reflect.Array.newInstance(ComplexMatrix, slices)
    for i in range(1, slices + 1):
        if imp_i is None:
            ffts[i - 1] = fft2(ip_to_fft_array_2d(_crop_float(imp_r, i, roi)))
        else:
            curr_r = _crop_float(imp_r, i, roi)
            curr_i = _crop_float(imp_i, i, roi)
            ffts[i - 1] = cfft2(ip_to_fft_complex_array_2d(curr_r, curr_i))
        IJ.showProgress(i, slices)
    return ffts


def element_product(a, b):
    """Element-wise (Hadamard) product of two equally sized ComplexMatrix objects."""
    nr = a.getNrow()
    nc = a.getNcol()
    res = ComplexMatrix(nr, nc)
    for j in range(nr):
        for i in range(nc):
            res.setElement(j, i, a.getElementReference(j, i).times(b.getElementReference(j, i)))
    return res


def _shift_source(k, mid, off):
    """Source index for output index k of a circular half-swap split at `mid`."""
    return k + off if k < mid else k - mid


def fft_shift(complex_matrix):
    """Move the zero-frequency element to the centre (MATLAB fftshift)."""
    nc = complex_matrix.getNcol()
    nr = complex_matrix.getNrow()
    out = ComplexMatrix(nr, nc)
    midi, offi = int(math.floor(nc / 2.0)), int(math.ceil(nc / 2.0))
    midj, offj = int(math.floor(nr / 2.0)), int(math.ceil(nr / 2.0))
    for j in range(nr):
        sj = _shift_source(j, midj, offj)
        for i in range(nc):
            out.setElement(j, i, complex_matrix.getElementReference(sj, _shift_source(i, midi, offi)))
    return out


def ifft_shift(complex_matrix):
    """Undo fft_shift: move the centre element back to (0, 0) (MATLAB ifftshift)."""
    nc = complex_matrix.getNcol()
    nr = complex_matrix.getNrow()
    out = ComplexMatrix(nr, nc)
    midi, offi = int(math.ceil(nc / 2.0)), int(math.floor(nc / 2.0))
    midj, offj = int(math.ceil(nr / 2.0)), int(math.floor(nr / 2.0))
    for j in range(nr):
        sj = _shift_source(j, midj, offj)
        for i in range(nc):
            out.setElement(j, i, complex_matrix.getElementReference(sj, _shift_source(i, midi, offi)))
    return out


def ifft_shift_real(matrix):
    """ifft_shift for a real flanagan Matrix."""
    nc = matrix.getNcol()
    nr = matrix.getNrow()
    out = Matrix(nr, nc)
    midi, offi = int(math.ceil(nc / 2.0)), int(math.floor(nc / 2.0))
    midj, offj = int(math.ceil(nr / 2.0)), int(math.floor(nr / 2.0))
    for j in range(nr):
        sj = _shift_source(j, midj, offj)
        for i in range(nc):
            out.setElement(j, i, matrix.getElement(sj, _shift_source(i, midi, offi)))
    return out


def fft2(data):
    """Real 2D FFT of a double[h][w] array (overwritten in place) as an h x w ComplexMatrix."""
    h = len(data)
    w = len(data[0])
    fft = DoubleFFT_2D(h, w)
    fft.realForward(data)
    return fft_array_2d_to_complex_matrix(data, h, w)


def cfft2(data):
    """Complex 2D FFT of an interleaved double[h][2*w] array (overwritten in place)."""
    h = len(data)
    w = len(data[0])
    fft = DoubleFFT_2D(h, w // 2)
    fft.complexForward(data)
    return fft_complex_array_2d_to_complex_matrix(data, h, w // 2)


def ifft2(m):
    """Scaled real inverse 2D FFT of a ComplexMatrix; returns the double[h][w] result."""
    w = m.getNcol()
    h = m.getNrow()
    fft = DoubleFFT_2D(h, w)
    data = complex_matrix_to_fft_array_2d(m)
    fft.realInverse(data, True)
    return data


def cifft2(m):
    """Complex inverse 2D FFT of a ComplexMatrix, returned as an h x w ComplexMatrix."""
    w = m.getNcol()
    h = m.getNrow()
    fft = DoubleFFT_2D(h, w)
    data = new_array('d', h, 2 * w)  # new double[h][2*w]
    for j in range(h):
        row = data[j]
        for i in range(w):
            c = m.getElementReference(j, i)
            row[2 * i] = c.getReal()
            row[2 * i + 1] = c.getImag()
    # NOTE(review): scaled inverse (scale=True) assumed to match the Java
    # Align_ComputeShifts2 implementation; confirm.
    fft.complexInverse(data, True)
    return fft_complex_array_2d_to_complex_matrix(data, h, w)


def c_find_peak(m):
    """Locate the element of largest modulus in a ComplexMatrix.

    Returns double[5]: [modulus, row, col, real part, imaginary part] of the
    first element (row-major order) reaching the maximum.
    """
    peak_abs = -1.0
    realmax = 0.0
    imagmax = 0.0
    rmax = 0
    cmax = 0
    for j in range(m.getNrow()):
        for i in range(m.getNcol()):
            c = m.getElementReference(j, i)
            if c.abs() > peak_abs:
                peak_abs = c.abs()
                realmax = c.getReal()
                imagmax = c.getImag()
                rmax = j
                cmax = i
    res = new_array("d", 5)
    res[0] = math.sqrt(realmax * realmax + imagmax * imagmax)
    res[1] = rmax
    res[2] = cmax
    res[3] = realmax
    res[4] = imagmax
    return res


def sum_square_abs(m):
    """Sum of |z|^2 over all elements of a ComplexMatrix."""
    s = 0.0
    for j in range(m.getNrow()):
        for i in range(m.getNcol()):
            s += m.getElementReference(j, i).squareAbs()
    return s


def dftups(complex_matrix, nor, noc, roff, coff, usfac):
    """Upsampled DFT by matrix multiplies; computes an upsampled DFT in just a small region.

    MATLAB origin: out = dftups(in, nor, noc, usfac, roff, coff) (note the
    different argument order here).
        usfac       upsampling factor
        nor, noc    number of pixels in the output upsampled DFT, in units
                    of upsampled pixels
        roff, coff  row and column offsets that shift the output array to a
                    region of interest on the DFT
    Receives DC in the upper left corner; the image centre must be at (1, 1).

    Equivalent to: embed the input in an array usfac times larger in each
    dimension, ifftshift it so the image centre is at (1, 1), take the FFT
    of the larger array, and extract the [nor, noc] region starting at
    element [roff+1, coff+1]. It computes this directly, without zero
    padding, which is much faster and more memory efficient when [nor, noc]
    are much smaller than [nr*usfac, nc*usfac].

    Manuel Guizar - Dec 13, 2007 (modified from dftus, by J.R. Fienup 7/31/06);
    Loic Le Guyader - Jun 11, 2011, Java version for ImageJ plugin.
    """
    nr = complex_matrix.getNrow()
    nc = complex_matrix.getNcol()
    nor, noc = int(nor), int(noc)

    # Column kernel: exp(-2*pi*i * u*v / (nc*usfac))
    amplitude = -2.0 * math.pi / (nc * usfac)
    u = Matrix(nc, 1)
    for i in range(nc):
        u.setElement(i, 0, i - math.floor(nc / 2.0))
    u = ifft_shift_real(u)
    v = Matrix(1, noc)
    for i in range(noc):
        v.setElement(0, i, i - coff)
    phase = u.times(v)
    kernc = ComplexMatrix(nc, noc)
    for j in range(nc):
        for i in range(noc):
            t = Complex()
            t.polar(1.0, amplitude * phase.getElement(j, i))
            kernc.setElement(j, i, t)

    # Row kernel: exp(-2*pi*i * w*x / (nr*usfac))
    amplitude = -2.0 * math.pi / (nr * usfac)
    w = Matrix(nor, 1)
    for i in range(nor):
        w.setElement(i, 0, i - roff)
    x = Matrix(1, nr)
    for i in range(nr):
        x.setElement(0, i, i - math.floor(nr / 2.0))
    x = ifft_shift_real(x)
    nphase = w.times(x)
    kernr = ComplexMatrix(nor, nr)
    for j in range(nor):
        for i in range(nr):
            t = Complex()
            t.polar(1.0, amplitude * nphase.getElement(j, i))
            kernr.setElement(j, i, t)

    return kernr.times(complex_matrix.times(kernc))


def dft_registration(ref, drifted, usfac):
    """Sub-pixel registration of two Fourier transforms by cross-correlation.

    ref, drifted: ComplexMatrix FFTs of the same size.
    usfac: upsampling factor; precision is 1/usfac pixel (usfac <= 2 stops
        after the 2x upsampled estimate).

    Returns double[4]: [error, diffphase, delta row, delta col].
    """
    m = ref.getNrow()
    n = ref.getNcol()
    output = new_array('d', 4)

    # First upsample by a factor of 2 to obtain an initial estimate:
    # embed the centred cross-power spectrum in a 2x larger array.
    mlarge = m * 2
    nlarge = n * 2
    large = ComplexMatrix(mlarge, nlarge)
    c = fft_shift(element_product(ref, drifted.conjugate()))
    for j in range(m):
        for i in range(n):
            large.setElement(int(j + m - math.floor(m / 2.0)),
                             int(i + n - math.floor(n / 2.0)),
                             c.getElementReference(j, i))

    # Compute the cross-correlation and locate the peak
    CC = cifft2(ifft_shift(large))
    peak = c_find_peak(CC)  # [abs, row, col, real, imag]

    # Obtain the shift in the original pixel grid from the peak position;
    # indices past the midpoint correspond to negative shifts.
    if peak[1] > m:
        peak[1] = peak[1] - mlarge
    if peak[2] > n:
        peak[2] = peak[2] - nlarge

    if usfac > 2:
        # Refine the estimate with a matrix-multiply DFT.
        # Initial shift estimate in the upsampled grid:
        row_shift = round(peak[1] / 2.0 * usfac) / usfac
        col_shift = round(peak[2] / 2.0 * usfac) / usfac
        # Centre of the output array at dftshift+1
        dftshift = math.floor(math.ceil(usfac * 1.5) / 2)
        norm = 1.0 / (m * n * usfac * usfac)

        # Matrix multiply DFT around the current shift estimate
        cm = element_product(drifted, ref.conjugate())
        nCC = dftups(cm, math.ceil(usfac * 1.5), math.ceil(usfac * 1.5),
                     dftshift - row_shift * usfac, dftshift - col_shift * usfac, usfac)
        nCC = nCC.times(norm).conjugate()

        # Locate the maximum and map it back to the original pixel grid
        npeak = c_find_peak(nCC)
        mrg00 = dftups(element_product(ref, ref.conjugate()), 1, 1, 0, 0, usfac)
        rg00 = mrg00.getElementReference(0, 0).abs() / (m * n * usfac * usfac)
        mrf00 = dftups(element_product(drifted, drifted.conjugate()), 1, 1, 0, 0, usfac)
        rf00 = mrf00.getElementReference(0, 0).abs() / (m * n * usfac * usfac)
        npeak[1] = npeak[1] - dftshift
        npeak[2] = npeak[2] - dftshift

        output[0] = math.sqrt(abs(1.0 - npeak[0] * npeak[0] / (rg00 * rf00)))  # error
        output[1] = math.atan2(npeak[4], npeak[3])  # diffphase
        output[2] = row_shift + npeak[1] / usfac  # delta row
        output[3] = col_shift + npeak[2] / usfac  # delta col
    else:
        # Upsampling = 2: no additional pixel shift refinement
        rg00 = sum_square_abs(ref) / (mlarge * nlarge)
        rf00 = sum_square_abs(drifted) / (mlarge * nlarge)
        output[0] = math.sqrt(abs(1.0 - peak[0] * peak[0] / (rg00 * rf00)))  # error
        output[1] = math.atan2(peak[4], peak[3])  # diffphase
        output[2] = peak[1] / 2.0  # delta row
        output[3] = peak[2] / 2.0  # delta col
    return output


def _roi_fits(roi, imp):
    """True when roi is set and lies entirely inside imp.

    Rectangle.maxX/maxY are x+width / y+height (exclusive), so a ROI that
    ends exactly on the image border is still valid.
    """
    if roi is None:
        return False
    bounds = roi.bounds
    return (bounds.minX >= 0 and bounds.minY >= 0
            and bounds.maxX <= imp.width and bounds.maxY <= imp.height)


def calculate_shifts(imp_r, imp_i, roi, upscale_factor=100, reference_slide=1, java_code=False):
    """Compute the shift of every slice relative to `reference_slide` (1-based).

    imp_r, imp_i: real / imaginary edge stacks (imp_i may be None for a
        real FFT); roi: region used for registration.
    java_code=True delegates to the Align_ComputeShifts2 Java plugin.

    Returns double[slices][6] rows [ref, drifted, dr, dc, error, diffphase].
    Raises Exception if roi is None or not inside imp_r.
    """
    if not _roi_fits(roi, imp_r):
        raise Exception("Invalid roi: " + str(roi))
    if java_code:
        get_context().getPluginManager().loadInitializePlugin("Align_ComputeShifts2.java")
        compute_shifts_filter = get_context().getClassByName("Align_ComputeShifts2").newInstance()
        compute_shifts_filter.setup(upscale_factor, False, imp_r, imp_i, 1, roi)
        compute_shifts_filter.run(None)
        return compute_shifts_filter.shifts

    IJ.showStatus("1/2 Perform FFT of each slice")
    ffts = compute_fft(imp_r, imp_i, roi)

    IJ.showStatus("2/2 Calculate shifts between slices")
    n = len(ffts)
    shifts = new_array('d', n, 6)
    reference = ffts[reference_slide - 1]
    for i in range(n):
        temp = dft_registration(reference, ffts[i], upscale_factor)
        row = shifts[i]
        row[0] = reference_slide
        row[1] = i + 1
        row[2] = temp[2]  # delta row
        row[3] = temp[3]  # delta col
        row[4] = temp[0]  # error
        row[5] = temp[1]  # diffphase
        IJ.showProgress(i + 1, n)
    return shifts


def to_ip(obj):
    """Coerce a file path, Data or BufferedImage into an image via the ijutils helpers.

    Other objects are returned unchanged.
    """
    if is_string(obj):
        return open_image(obj)
    if isinstance(obj, Data):
        obj = obj.toBufferedImage(False)
    if isinstance(obj, BufferedImage):
        obj = load_image(obj)
    return obj


def calculate_shift(ref, img, roi, g_sigma=3.0, upscale_factor=100):
    """Register `img` against `ref` inside `roi` (Java implementation).

    Returns (xoff, yoff, error, diffphase), with xoff the column shift and
    yoff the row shift.
    """
    stack = create_stack([to_ip(ref), to_ip(img)])
    ipr, ipi = complex_edge_filtering(stack, g_sigma=g_sigma, show=False)
    shifts = calculate_shifts(ipr, ipi, roi, upscale_factor=upscale_factor, java_code=True)
    drifted = shifts[1]
    return drifted[3], drifted[2], drifted[4], drifted[5]


# Demo: register the test stack and display the translated result
roi = Roi(256, 0, 128, 128)
stack = load_test_stack(show=False, size=9)
ipr, ipi = complex_edge_filtering(stack, show=False)
shifts = calculate_shifts(ipr, ipi, roi, java_code=True)
#shifts= load_shifts("{images}/TestObjAligner/shifts.mat")
#stack = load_test_stack(show=True)
r = translate(stack, shifts, show=True)