642 lines
26 KiB
Python
642 lines
26 KiB
Python
from ijutils import *
|
|
import java.lang.reflect
|
|
import flanagan.complex.ComplexMatrix as ComplexMatrix
|
|
import flanagan.math.Matrix as Matrix
|
|
import flanagan.complex.Complex as Complex
|
|
import org.jtransforms.fft.DoubleFFT_2D as DoubleFFT_2D
|
|
import math
|
|
from startup import ScriptUtils
|
|
import ij.plugin.filter.PlugInFilterRunner as PlugInFilterRunner
|
|
import ij.plugin.filter.ExtendedPlugInFilter as ExtendedPlugInFilter
|
|
import ij.plugin.filter.ExtendedPlugInFilter as ExtendedPlugInFilter
|
|
import java.lang.Thread as Thread
|
|
|
|
|
|
def new_array(type, *dimensions):
    """Allocate a Java array through reflection.

    ``type`` is handed to ``ScriptUtils.getPrimitiveType`` to obtain the
    component class (this file always passes 'd'); ``dimensions`` gives the
    length of each array dimension, outermost first.
    """
    component_class = ScriptUtils.getPrimitiveType(type)
    return java.lang.reflect.Array.newInstance(component_class, *dimensions)
|
|
|
|
def load_stack(title, file_list, show=False):
    """Open every file of ``file_list`` and bundle the images into one stack.

    Each entry is passed through ``expand_path`` before being opened.  The
    stack is created with the given ``title`` and displayed when ``show`` is
    true.  Returns the created stack.
    """
    images = [open_image(expand_path(name)) for name in file_list]
    result = create_stack(images, title=title)
    if show:
        result.show()
    return result
|
|
|
|
def load_test_stack(title="Test", show=False, size=9):
    """Load ``size`` consecutive test images, starting at image number 40.

    The files are "{images}/TestObjAligner/i210517_0<NN>#001.tif" with
    <NN> running from 40 to 40 + size - 1.
    """
    first_index = 40
    names = ["{images}/TestObjAligner/i210517_0" + str(number) + "#001.tif"
             for number in range(first_index, first_index + size)]
    return load_stack(title, names, show)
|
|
|
|
def load_corr_stack(title="Corr", show=False):
    """Load the nine images numbered 40..48 of the "TestObjAligner_corr" set."""
    names = ["{images}/TestObjAligner_corr/i210517_0" + str(number) + "#001.tif"
             for number in range(40, 49)]
    return load_stack(title, names, show)
|
|
|
|
def complex_edge_filtering(imp, complex=True, g_sigma=3.0, g_resolution=1e-4, show=False, java_code=False):
    """Gaussian-blur every slice of ``imp`` and extract its edges.

    With ``complex`` true each blurred slice is convolved with two 3x3 Sobel
    kernels; one result is stored in the "real" stack, the other in the
    "imaginary" stack.  With ``complex`` false only the "real" stack is
    built, using ImageJ's FIND_EDGES filter.  ``g_sigma`` is used as the blur
    sigma in both directions and ``g_resolution`` is the last argument of
    ``GaussianBlur.blurGaussian``.

    With ``java_code`` true the work is delegated to the
    "Align_ComplexEdgeFiltering" plugin and its ``output`` attribute is
    returned instead.

    Returns ``[imp_r, imp_i]``; ``imp_i`` is ``None`` when ``complex`` is false.
    """
    if java_code:
        get_context().getPluginManager().loadInitializePlugin("Align_ComplexEdgeFiltering.java")
        edge_plugin = get_context().getClassByName("Align_ComplexEdgeFiltering").newInstance()
        # Setup argument: "<blur sigma>,<complex flag>,<show flag>"
        setup_arg = str(g_sigma) + "," + str(complex) + "," + str(show)
        edge_plugin.setup(setup_arg, imp)
        edge_plugin.run(imp.getProcessor())
        return edge_plugin.output

    blur = GaussianBlur()
    kernel_r = [1, 0, -1, 2, 0, -2, 1, 0, -1]
    kernel_i = [1, 2, 1, 0, 0, 0, -1, -2, -1]
    n_slices = imp.getImageStackSize()

    def blurred_slice(index):
        # Float copy of slice `index`, blurred in place.
        processor = imp.getStack().getProcessor(index).duplicate().convertToFloat()
        blur.blurGaussian(processor, g_sigma, g_sigma, g_resolution)
        return processor

    imp_r = imp.createImagePlus()
    stack_r = ImageStack(imp.getWidth(), imp.getHeight())

    if not complex:
        imp_i = None
        for index in range(1, n_slices + 1):
            edges = blurred_slice(index)
            edges.filter(ImageProcessor.FIND_EDGES)
            stack_r.addSlice(imp.getStack().getSliceLabel(index), edges)
            IJ.showProgress(index, n_slices)
    else:
        imp_i = imp.createImagePlus()
        stack_i = ImageStack(imp.getWidth(), imp.getHeight())
        for index in range(1, n_slices + 1):
            edges_r = blurred_slice(index)
            edges_i = edges_r.duplicate()
            edges_r.convolve3x3(kernel_r)
            edges_i.convolve3x3(kernel_i)
            label = imp.getStack().getSliceLabel(index)
            stack_r.addSlice(label, edges_r)
            stack_i.addSlice(label, edges_i)
            IJ.showProgress(index, n_slices)

        # Imaginary part
        imp_i.setStack("EdgeImag_" + imp.getTitle(), stack_i)
        imp_i.resetDisplayRange()
        # NOTE(review): the listing this was restored from had lost its
        # indentation; updateAndDraw() is assumed to belong to the
        # `if show` branch -- confirm against the repository copy.
        if show:
            imp_i.show()
            imp_i.updateAndDraw()

    # Real part
    imp_r.setStack("EdgeReal_" + imp.getTitle(), stack_r)
    imp_r.resetDisplayRange()
    if show:
        imp_r.show()
        imp_r.updateAndDraw()
    return [imp_r, imp_i]
|
|
|
|
class TranslationFilter(ExtendedPlugInFilter):
    """ImageJ stack filter that translates every slice by a given shift.

    The caller assigns ``shifts`` and ``imp`` and then hands the instance to
    a ``PlugInFilterRunner``.  Each row of ``shifts`` is read with index 2 as
    the vertical offset and index 3 as the horizontal offset (the layout
    returned by ``calculate_shifts`` in this file).  The input image is not
    modified; the translated slices are collected in ``output``, whose stack
    is titled "REG_" + the input title.
    """

    def __init__(self):
        self.shifts = None      # shift table, assigned by the caller
        self.flags = (self.DOES_ALL - self.DOES_RGB) | self.DOES_STACKS | self.NO_CHANGES | self.FINAL_PROCESSING
        self.imp = None         # input image
        self.output = None      # ImagePlus receiving the translated stack
        self.translated = None  # ImageStack being filled by run()
        self.pifr = None        # PlugInFilterRunner, used to query the slice number
        self.nbslices = 0       # number of run() calls announced by ImageJ
        self.processed = 0      # number of slices handled so far

    def setup(self, arg, imp):
        """Return the capability flags, or finish the output on the "final" call."""
        if "final" == arg:
            self.output.setStack("REG_" + self.imp.getTitle(), self.translated)
            return self.DONE
        if self.imp is None:
            self.imp = imp
        return self.flags

    def showDialog(self, imp, command, pfr):
        """Remember the runner (no dialog is shown) and return the flags."""
        self.pifr = pfr
        # Was `return flags`, an undefined name that raised NameError.
        return self.flags

    def setNPasses(self, nPasses):
        """Called by ImageJ with the number of run(ip) calls that make up 100% progress.

        Also allocates the output image and a stack pre-sized with
        ``nPasses`` placeholder slices.
        """
        self.nbslices = nPasses
        self.output = self.imp.createImagePlus()
        self.translated = ImageStack(self.imp.getWidth(), self.imp.getHeight(), self.nbslices)

    def run(self, ip):
        """Translate one slice; called by ImageJ for each stack slice.

        The slice is duplicated and converted to float here, so ``ip`` itself
        stays unchanged.
        """
        if Thread.currentThread().isInterrupted():
            return
        thisone = self.pifr.getSliceNumber()

        nip = ip.duplicate().convertToFloat()
        nip.setInterpolationMethod(ImageProcessor.BICUBIC)
        if len(self.shifts) != self.nbslices:
            # Translate all the frames by the same shift (second row of the table).
            row = self.shifts[1]
        else:
            row = self.shifts[thisone - 1]
        xoff, yoff = row[3], row[2]
        nip.translate(xoff, yoff)

        lbl = self.imp.getStack().getSliceLabel(thisone)
        if lbl is None:
            # Was `"" + thisone`, which raises TypeError (str + int).
            lbl = str(thisone)
        # Insert the result after slice thisone-1, then drop the placeholder
        # that the insertion pushed down to position thisone+1.
        self.translated.addSlice(lbl, nip, thisone - 1)
        self.translated.deleteSlice(thisone + 1)

        self.processed += 1
        IJ.showProgress(self.processed, self.nbslices)
|
|
|
|
def translate(stack, shifts, show=False, java_code=False):
    """Apply ``shifts`` to ``stack`` and return the translated image.

    ``stack`` is made the temporary current image, then a translation filter
    is run on it: the Java plugin "Align_TranslationFilter" when ``java_code``
    is true, otherwise the ``TranslationFilter`` class of this file.  The
    filter's ``output`` is returned, and displayed first when ``show`` is true.
    """
    WindowManager.setTempCurrentImage(stack)
    if java_code:
        get_context().getPluginManager().loadInitializePlugin("Align_TranslationFilter.java")
        translation_filter = get_context().getClassByName("Align_TranslationFilter").newInstance()
    else:
        translation_filter = TranslationFilter()
    # The Java branch used to assign the undefined name `imp` here (NameError).
    translation_filter.imp = stack
    translation_filter.shifts = shifts
    # The result is read right after construction, i.e. the runner is
    # expected to process the stack from within its constructor.
    PlugInFilterRunner(translation_filter, "", "")
    ret = translation_filter.output
    # NOTE(review): the listing this was restored from had lost its
    # indentation; updateAndDraw() is assumed to belong to the `if show`
    # branch -- confirm against the repository copy.
    if show:
        ret.show()
        ret.updateAndDraw()
    return ret
|
|
|
|
|
|
def load_shifts(filename):
    """Read a shift table from ``filename`` through the "ShiftsIO" plugin.

    The path is expanded with ``expand_path`` and the entry named
    "directshifts" is loaded.
    """
    get_context().getPluginManager().loadInitializePlugin("ShiftsIO.java")
    shifts_io = get_context().getClassByName("ShiftsIO").newInstance()
    path = expand_path(filename)
    return shifts_io.load(path, "directshifts")
|
|
|
|
def save_shifts(filename, shifts):
    """Write ``shifts`` to ``filename`` through the "ShiftsIO" plugin.

    The path is expanded with ``expand_path`` and the table is stored under
    the name "directshifts".
    """
    get_context().getPluginManager().loadInitializePlugin("ShiftsIO.java")
    shifts_io = get_context().getClassByName("ShiftsIO").newInstance()
    path = expand_path(filename)
    shifts_io.save(path, shifts, "directshifts")
|
|
|
|
|
|
def ip_to_fft_array_2d(ip):
    """Copy the pixels of ``ip`` into a new double[h][w] array.

    The flat pixel array is read row by row: element (row, col) comes from
    ``pixels[row * w + col]``.
    """
    width = ip.getWidth()
    height = ip.getHeight()
    flat = ip.getPixels()
    data = new_array('d', height, width)
    for row in range(height):
        start = row * width
        for col in range(width):
            data[row][col] = flat[start + col]
    return data
|
|
|
|
|
|
def ip_to_fft_complex_array_2d(ip_r, ip_i):
    """Interleave two images into a double[h][2*w] array of complex values.

    For each pixel the value from ``ip_r`` is stored at the even position
    and the value from ``ip_i`` at the following odd position.  Width and
    height are taken from ``ip_r``.
    """
    width = ip_r.getWidth()
    height = ip_r.getHeight()
    flat_r = ip_r.getPixels()
    flat_i = ip_i.getPixels()
    data = new_array('d', height, 2 * width)
    for row in range(height):
        start = row * width
        for col in range(width):
            data[row][2 * col] = flat_r[start + col]
            data[row][2 * col + 1] = flat_i[start + col]
    return data
|
|
|
|
def fft_array_2d_to_complex_matrix(data, h, w):
    """Unpack a packed real-FFT result into a full ``h`` x ``w`` ComplexMatrix.

    ``data`` is the h x w array filled in place by
    ``DoubleFFT_2D.realForward`` (see ``fft2``).  Interior columns
    0 < i < w/2 hold (real, imag) pairs at ``data[j][2*i]`` and
    ``data[j][2*i+1]``; columns 0 and w/2 are packed into the first two
    entries of each row.  The mirrored half of the matrix is filled with the
    complex conjugates.

    NOTE(review): the index arithmetic presumably relies on even ``h`` and
    ``w`` -- confirm against the JTransforms layout documentation.
    """
    m = ComplexMatrix(h, w)
    half_h = h // 2
    half_w = w // 2
    for j in range(h):
        # The column index must reach w/2 inclusive (the Java original loops
        # with `i <= w/2`); with range(w/2) the `i == w/2` cases below never
        # ran and elements (0, w/2) and (h/2, w/2) were left unset.
        for i in range(half_w + 1):
            if (j > 0) and (i > 0) and (i < half_w):
                m.setElement(j, i, Complex(data[j][2 * i], data[j][2 * i + 1]))
                m.setElement(h - j, w - i, Complex(data[j][2 * i], -data[j][2 * i + 1]))
            if (j == 0) and (i > 0) and (i < half_w):
                m.setElement(0, i, Complex(data[0][2 * i], data[0][2 * i + 1]))
                m.setElement(0, w - i, Complex(data[0][2 * i], -data[0][2 * i + 1]))
            if (i == 0) and (j > 0) and (j < half_h):
                m.setElement(j, 0, Complex(data[j][0], data[j][1]))
                m.setElement(h - j, 0, Complex(data[j][0], -data[j][1]))
                m.setElement(j, half_w, Complex(data[h - j][1], -data[h - j][0]))
                m.setElement(h - j, half_w, Complex(data[h - j][1], data[h - j][0]))
            if (j == 0) and (i == 0):
                m.setElement(0, 0, Complex(data[0][0], 0))
            if (j == 0) and (i == half_w):
                m.setElement(0, half_w, Complex(data[0][1], 0))
            if (j == half_h) and (i == 0):
                m.setElement(half_h, 0, Complex(data[half_h][0], 0))
            if (j == half_h) and (i == half_w):
                m.setElement(half_h, half_w, Complex(data[half_h][1], 0))
    return m
|
|
|
|
|
|
def fft_complex_array_2d_to_complex_matrix(data, h, w):
    """Build an ``h`` x ``w`` ComplexMatrix from interleaved (real, imag) rows.

    Element (row, col) is ``Complex(data[row][2*col], data[row][2*col+1])``.
    """
    matrix = ComplexMatrix(h, w)
    for row in range(h):
        for col in range(w):
            real_part = data[row][2 * col]
            imag_part = data[row][2 * col + 1]
            matrix.setElement(row, col, Complex(real_part, imag_part))
    return matrix
|
|
|
|
def complex_matrix_to_fft_array_2d(m):
    """Pack ComplexMatrix ``m`` into the h x w array read by ``ifft2``.

    This is the inverse of ``fft_array_2d_to_complex_matrix``: interior
    columns 0 < i < w/2 are written as (real, imag) pairs at
    ``data[j][2*i]`` and ``data[j][2*i+1]``, and columns 0 and w/2 are
    packed into the first two entries of each row.  Columns above w/2 match
    none of the cases and are skipped.
    """
    w = m.getNcol()
    h = m.getNrow()
    half_w = w // 2
    half_h = h // 2
    data = new_array('d', h, w)  # new double[h][w]
    for j in range(h):
        for i in range(w):
            if (j > 0) and (i > 0) and (i < half_w):
                data[j][2 * i] = m.getElementReference(j, i).getReal()
                data[j][2 * i + 1] = m.getElementReference(j, i).getImag()
            if (j == 0) and (i > 0) and (i < half_w):
                data[0][2 * i] = m.getElementReference(0, i).getReal()
                # Was `m.getEementReference`, a typo raising AttributeError.
                data[0][2 * i + 1] = m.getElementReference(0, i).getImag()
            if (i == 0) and (j > 0) and (j < half_h):
                data[j][0] = m.getElementReference(j, 0).getReal()
                data[j][1] = m.getElementReference(j, 0).getImag()
                data[h - j][1] = m.getElementReference(j, half_w).getReal()
                data[h - j][0] = m.getElementReference(h - j, half_w).getImag()
            if (j == 0) and (i == 0):
                data[0][0] = m.getElementReference(0, 0).getReal()
            if (j == 0) and (i == half_w):
                data[0][1] = m.getElementReference(0, half_w).getReal()
            if (j == half_h) and (i == 0):
                data[half_h][0] = m.getElementReference(half_h, 0).getReal()
            if (j == half_h) and (i == half_w):
                data[half_h][1] = m.getElementReference(half_h, half_w).getReal()
    return data
|
|
|
|
|
|
def complex_matrix_to_real_array_2d(m):
    """Split ComplexMatrix ``m`` into a double[2][h][w] array.

    ``data[0]`` receives the real parts and ``data[1]`` the imaginary parts.
    """
    cols = m.getNcol()
    rows = m.getNrow()
    data = new_array('d', 2, rows, cols)
    for row in range(rows):
        for col in range(cols):
            element = m.getElementReference(row, col)
            data[0][row][col] = element.getReal()
            data[1][row][col] = element.getImag()
    return data
|
|
|
|
|
|
def compute_fft(imp_r, imp_i, roi):
    """Compute the 2D FFT of the ``roi`` region of every slice.

    When ``imp_i`` is ``None`` a real FFT (``fft2``) of ``imp_r`` is taken;
    otherwise the slices of ``imp_r`` and ``imp_i`` are combined as real and
    imaginary parts and transformed with ``cfft2``.  Returns a Java
    ``ComplexMatrix[]`` with one entry per slice.
    """
    def cropped_float(imp, index):
        # ROI of slice `index`, cropped and converted to float.
        processor = imp.getStack().getProcessor(index)
        processor.setRoi(roi)
        return processor.crop().convertToFloat()

    slices = imp_r.getStackSize()
    ffts = java.lang.reflect.Array.newInstance(ComplexMatrix, slices)
    for index in range(1, slices + 1):
        if imp_i is not None:
            part_r = cropped_float(imp_r, index)
            part_i = cropped_float(imp_i, index)
            ffts[index - 1] = cfft2(ip_to_fft_complex_array_2d(part_r, part_i))
        else:
            ffts[index - 1] = fft2(ip_to_fft_array_2d(cropped_float(imp_r, index)))
        IJ.showProgress(index, slices)
    return ffts
|
|
|
|
|
|
def element_product(a, b):
    """Return the element-wise product of ComplexMatrix ``a`` and ``b``.

    The result has the dimensions of ``a``.
    """
    rows = a.getNrow()
    cols = a.getNcol()
    product = ComplexMatrix(rows, cols)
    for row in range(rows):
        for col in range(cols):
            left = a.getElementReference(row, col)
            right = b.getElementReference(row, col)
            product.setElement(row, col, left.times(right))
    return product
|
|
|
|
|
|
def fft_shift(complex_matrix):
    """Return a copy of ``complex_matrix`` cyclically shifted along both axes.

    Output element (j, i) is taken from input element
    ((j + ceil(nr/2)) mod nr, (i + ceil(nc/2)) mod nc).
    """
    rows = complex_matrix.getNrow()
    cols = complex_matrix.getNcol()
    shifted = ComplexMatrix(rows, cols)
    row_step = int(math.ceil(rows / 2.0))
    col_step = int(math.ceil(cols / 2.0))
    for j in range(rows):
        source_row = (j + row_step) % rows
        for i in range(cols):
            source_col = (i + col_step) % cols
            shifted.setElement(j, i, complex_matrix.getElementReference(source_row, source_col))
    return shifted
|
|
|
|
|
|
def ifft_shift(complex_matrix):
    """Return a copy of ``complex_matrix`` cyclically shifted along both axes.

    Output element (j, i) is taken from input element
    ((j + floor(nr/2)) mod nr, (i + floor(nc/2)) mod nc), which undoes
    ``fft_shift``.
    """
    rows = complex_matrix.getNrow()
    cols = complex_matrix.getNcol()
    shifted = ComplexMatrix(rows, cols)
    row_step = int(math.floor(rows / 2.0))
    col_step = int(math.floor(cols / 2.0))
    for j in range(rows):
        source_row = (j + row_step) % rows
        for i in range(cols):
            source_col = (i + col_step) % cols
            shifted.setElement(j, i, complex_matrix.getElementReference(source_row, source_col))
    return shifted
|
|
|
|
|
|
def ifft_shift_real(matrix):
    """Real-valued counterpart of ``ifft_shift`` for a ``Matrix``.

    Output element (j, i) is taken from input element
    ((j + floor(nr/2)) mod nr, (i + floor(nc/2)) mod nc).
    """
    rows = matrix.getNrow()
    cols = matrix.getNcol()
    shifted = Matrix(rows, cols)
    row_step = int(math.floor(rows / 2.0))
    col_step = int(math.floor(cols / 2.0))
    for j in range(rows):
        source_row = (j + row_step) % rows
        for i in range(cols):
            source_col = (i + col_step) % cols
            shifted.setElement(j, i, matrix.getElement(source_row, source_col))
    return shifted
|
|
|
|
|
|
|
|
def fft2(data):
    """Compute the 2D FFT of the real array ``data`` (double[h][w]).

    ``data`` is transformed in place by ``DoubleFFT_2D.realForward`` and the
    packed result is unpacked into a ComplexMatrix, which is returned.
    """
    rows = len(data)
    cols = len(data[0])
    transform = DoubleFFT_2D(rows, cols)
    transform.realForward(data)
    return fft_array_2d_to_complex_matrix(data, rows, cols)
|
|
|
|
def cfft2(data):
    """Compute the 2D FFT of interleaved complex data (double[h][2*w]).

    ``data`` is transformed in place by ``DoubleFFT_2D.complexForward`` and
    then converted to an h x w ComplexMatrix, which is returned.
    """
    rows = len(data)
    # Each complex value occupies two consecutive doubles.
    complex_cols = len(data[0]) // 2
    transform = DoubleFFT_2D(rows, complex_cols)
    transform.complexForward(data)
    return fft_complex_array_2d_to_complex_matrix(data, rows, complex_cols)
|
|
def ifft2(m):
    """Compute the scaled inverse real 2D FFT of ComplexMatrix ``m``.

    ``m`` is packed with ``complex_matrix_to_fft_array_2d`` and transformed
    in place by ``DoubleFFT_2D.realInverse`` with scaling enabled.  Returns
    the resulting double[h][w] array.
    """
    cols = m.getNcol()
    rows = m.getNrow()
    transform = DoubleFFT_2D(rows, cols)
    packed = complex_matrix_to_fft_array_2d(m)
    transform.realInverse(packed, True)
    return packed
|
|
|
|
def cifft2(m):
    """Compute the scaled inverse complex 2D FFT of ComplexMatrix ``m``.

    The matrix is copied into an interleaved double[h][2*w] buffer,
    transformed in place by ``DoubleFFT_2D.complexInverse`` with scaling
    enabled, and copied back into a new ComplexMatrix, which is returned.
    """
    cols = m.getNcol()
    rows = m.getNrow()
    transform = DoubleFFT_2D(rows, cols)

    buf = new_array('d', rows, 2 * cols)
    for row in range(rows):
        for col in range(cols):
            element = m.getElementReference(row, col)
            buf[row][2 * col] = element.getReal()
            buf[row][2 * col + 1] = element.getImag()

    transform.complexInverse(buf, True)

    result = ComplexMatrix(rows, cols)
    for row in range(rows):
        for col in range(cols):
            result.setElement(row, col, buf[row][2 * col], buf[row][2 * col + 1])
    return result
|
|
|
|
|
|
def c_find_peak(m):
    """Locate the element of ComplexMatrix ``m`` with the largest modulus.

    Only moduli strictly greater than 0.0 are considered, and the first
    occurrence wins on ties.  Returns a double[5] holding
    [modulus, row, column, real part, imaginary part]; all entries are zero
    when no element qualifies.
    """
    best_abs = 0.0
    best_real = 0.0
    best_imag = 0.0
    best_row = 0
    best_col = 0
    for row in range(m.getNrow()):
        for col in range(m.getNcol()):
            element = m.getElementReference(row, col)
            modulus = element.abs()
            if modulus > best_abs:
                best_abs = modulus
                best_real = element.getReal()
                best_imag = element.getImag()
                best_row = row
                best_col = col

    peak = new_array("d", 5)
    # The modulus is recomputed from the stored real and imaginary parts.
    peak[0] = math.sqrt(best_real * best_real + best_imag * best_imag)
    peak[1] = best_row
    peak[2] = best_col
    peak[3] = best_real
    peak[4] = best_imag
    return peak
|
|
|
|
|
|
def sum_square_abs(m):
    """Return the sum of ``squareAbs()`` over all elements of ComplexMatrix ``m``."""
    total = 0.0
    for row in range(m.getNrow()):
        for col in range(m.getNcol()):
            total = total + m.getElementReference(row, col).squareAbs()
    return total
|
|
|
|
|
|
def dftups(complex_matrix, nor, noc, roff, coff, usfac):
    """Upsampled DFT by matrix multiplies, computed in just a small region.

    Port of the MATLAB routine ``out = dftups(in, nor, noc, usfac, roff, coff)``
    (note the different argument order here).

    Arguments:
      complex_matrix  input array; receives DC in the upper left corner,
                      image center must be in (1,1) (MATLAB indexing)
      nor, noc        number of pixels in the output upsampled DFT, in units
                      of upsampled pixels (converted with int())
      roff, coff      row and column offsets, allow to shift the output
                      array to a region of interest of the DFT
      usfac           upsampling factor

    History: Loic Le Guyader, Jun 11, 2011 (Java version for ImageJ plugin);
    Manuel Guizar, Dec 13, 2007; modified from dftus by J.R. Fienup, 7/31/06.

    This code is intended to provide the same result as if the following
    operations were performed:
      - Embed the input in an array that is usfac times larger in each
        dimension; ifftshift to bring the center of the image to (1,1).
      - Take the FFT of the larger array.
      - Extract an [nor, noc] region of the result, starting with the
        [roff+1, coff+1] element.

    It achieves this by computing the DFT in the output array without the
    need to zero-pad.  Much faster and more memory efficient than the
    zero-padded FFT approach if [nor, noc] are much smaller than
    [nr*usfac, nc*usfac].
    """
    def centered_ramp(length):
        # Values i - floor(length / 2) for i = 0 .. length - 1.
        half = math.floor(length / 2.0)
        return [i - half for i in range(length)]

    def phase_kernel(phases, scale, rows, cols):
        # Complex matrix of unit-modulus entries exp(1j * scale * phases[j][i]).
        kernel = ComplexMatrix(rows, cols)
        for j in range(rows):
            for i in range(cols):
                entry = Complex()
                entry.polar(1.0, scale * phases.getElement(j, i))
                kernel.setElement(j, i, entry)
        return kernel

    nr = complex_matrix.getNrow()
    nc = complex_matrix.getNcol()
    nor, noc = int(nor), int(noc)

    # Column kernel (nc x noc)
    col_scale = -2.0 * math.pi / (nc * usfac)
    col_freq = Matrix(nc, 1)
    for index, value in enumerate(centered_ramp(nc)):
        col_freq.setElement(index, 0, value)
    col_freq = ifft_shift_real(col_freq)

    col_pos = Matrix(1, noc)
    for index in range(noc):
        col_pos.setElement(0, index, index - coff)

    kernc = phase_kernel(col_freq.times(col_pos), col_scale, nc, noc)

    # Row kernel (nor x nr)
    row_scale = -2.0 * math.pi / (nr * usfac)
    row_pos = Matrix(nor, 1)
    for index in range(nor):
        row_pos.setElement(index, 0, index - roff)

    row_freq = Matrix(1, nr)
    for index, value in enumerate(centered_ramp(nr)):
        row_freq.setElement(0, index, value)
    row_freq = ifft_shift_real(row_freq)

    kernr = phase_kernel(row_pos.times(row_freq), row_scale, nor, nr)

    # DFT obtained as a product of the two kernels with the input
    return kernr.times(complex_matrix.times(kernc))
|
|
|
|
def dft_registration(ref, drifted, usfac):
    """Estimate the translation between two images from their Fourier transforms.

    ``ref`` and ``drifted`` are ComplexMatrix FFTs of the reference and of
    the shifted image.  A first estimate is taken from the peak of the
    cross-correlation upsampled by 2; when ``usfac`` > 2 it is refined with
    an upsampled DFT (``dftups``) around that estimate.

    Returns a double[4]: [error, diffphase, delta row, delta col].
    """
    m = ref.getNrow()
    n = ref.getNcol()
    output = new_array('d', 4)

    # First upsample by a factor of 2 to obtain the initial estimate:
    # embed the Fourier data in a 2x larger array.
    mlarge = 2 * m
    nlarge = 2 * n
    padded = ComplexMatrix(mlarge, nlarge)
    cross_power = fft_shift(element_product(ref, drifted.conjugate()))

    row_offset = int(m - math.floor(m / 2.0))
    col_offset = int(n - math.floor(n / 2.0))
    for j in range(m):
        for i in range(n):
            padded.setElement(j + row_offset, i + col_offset, cross_power.getElementReference(j, i))

    # Compute the cross-correlation and locate its peak.
    # peak = [modulus, row, column, real part, imaginary part]
    peak = c_find_peak(cifft2(ifft_shift(padded)))

    # Obtain the shift in the original pixel grid from the peak position.
    if peak[1] > m:
        peak[1] -= mlarge
    if peak[2] > n:
        peak[2] -= nlarge

    if usfac <= 2:
        # Upsampling of 2 only: no additional pixel shift refinement.
        rg00 = sum_square_abs(ref) / (mlarge * nlarge)
        rf00 = sum_square_abs(drifted) / (mlarge * nlarge)
        output[0] = math.sqrt(abs(1.0 - peak[0] * peak[0] / (rg00 * rf00)))  # error
        output[1] = math.atan2(peak[4], peak[3])  # diffphase
        output[2] = peak[1] / 2.0  # delta row
        output[3] = peak[2] / 2.0  # delta col
        return output

    # Refine the estimate with a matrix-multiply DFT.
    # Initial shift estimate in the upsampled grid
    row_shift = round(peak[1] / 2.0 * usfac) / usfac
    col_shift = round(peak[2] / 2.0 * usfac) / usfac
    region = math.ceil(usfac * 1.5)
    dftshift = math.floor(region / 2)  # center of the output array at dftshift+1
    scale = m * n * usfac * usfac

    # Matrix multiply DFT around the current shift estimate
    swapped = element_product(drifted, ref.conjugate())
    refined = dftups(swapped, region, region,
                     dftshift - row_shift * usfac, dftshift - col_shift * usfac, usfac)
    refined = refined.times(1.0 / scale).conjugate()

    # Locate the maximum and map it back to the original pixel grid.
    # npeak = [modulus, row, column, real part, imaginary part]
    npeak = c_find_peak(refined)
    ref_power = dftups(element_product(ref, ref.conjugate()), 1, 1, 0, 0, usfac)
    rg00 = ref_power.getElementReference(0, 0).abs() / scale
    drifted_power = dftups(element_product(drifted, drifted.conjugate()), 1, 1, 0, 0, usfac)
    rf00 = drifted_power.getElementReference(0, 0).abs() / scale
    npeak[1] -= dftshift
    npeak[2] -= dftshift

    output[0] = math.sqrt(abs(1.0 - npeak[0] * npeak[0] / (rg00 * rf00)))  # error
    output[1] = math.atan2(npeak[4], npeak[3])  # diffphase
    output[2] = row_shift + npeak[1] / usfac  # delta row
    output[3] = col_shift + npeak[2] / usfac  # delta col
    return output
|
|
|
|
|
|
def calculate_shifts(imp_r, imp_i, roi, upscale_factor=100, reference_slide=1, java_code=False):
    """Compute the shift of every slice relative to a reference slice.

    ``imp_r`` and ``imp_i`` are the real and imaginary edge stacks
    (``imp_i`` may be ``None``), ``roi`` is the region used for the
    registration, ``upscale_factor`` is the upsampling factor handed to
    ``dft_registration`` and ``reference_slide`` is the 1-based index of the
    reference slice.  With ``java_code`` true the "Align_ComputeShifts2"
    plugin does the work and its ``shifts`` attribute is returned.

    Returns a double[n][6] with rows
    [ref, drifted, dr, dc, error, diffphase].

    Raises Exception when ``roi`` is ``None`` or fails the bounds check.
    """
    # NOTE(review): `>=` also rejects a ROI whose far edge coincides with the
    # image border -- confirm that this is intended.
    if roi is None or roi.bounds.minX <0 or roi.bounds.minY<0 or roi.bounds.maxX>=imp_r.width or roi.bounds.maxY>=imp_r.height:
        raise Exception("Invalid roi: " + str(roi))
    if java_code:
        get_context().getPluginManager().loadInitializePlugin("Align_ComputeShifts2.java")
        compute_shifts_filter = get_context().getClassByName("Align_ComputeShifts2").newInstance()
        # NOTE(review): the literal 1 is passed where `reference_slide` may
        # have been meant -- check the plugin's setup() signature.
        compute_shifts_filter.setup(upscale_factor, False, imp_r, imp_i, 1, roi)
        compute_shifts_filter.run(None)
        return compute_shifts_filter.shifts

    IJ.showStatus("1/2 Perform FFT of each slice")
    ffts = compute_fft(imp_r, imp_i, roi)

    IJ.showStatus("2/2 Calculate shifts between slices")
    count = len(ffts)
    reference_fft = ffts[reference_slide - 1]
    shifts = new_array('d', count, 6)  # new double[ffts.length][6]
    for i in range(count):
        # temp = [error, diffphase, delta row, delta col]
        temp = dft_registration(reference_fft, ffts[i], upscale_factor)
        shifts[i][0] = reference_slide
        shifts[i][1] = i + 1
        shifts[i][2] = temp[2]
        shifts[i][3] = temp[3]
        shifts[i][4] = temp[0]
        shifts[i][5] = temp[1]
        # A stray trailing backslash used to follow this call and joined it
        # with the `return` statement below (syntax error).
        IJ.showProgress(i + 1, count)
    return shifts  # [ref, drifted, dr, dc, error, diffphase]
|
|
|
|
def to_ip(obj):
    """Convert ``obj`` into an image object where a conversion is known.

    A string is opened with ``open_image``.  A ``Data`` instance is first
    turned into a ``BufferedImage``, and a ``BufferedImage`` is loaded with
    ``load_image``.  Anything else is returned unchanged.
    """
    if is_string(obj):
        return open_image(obj)
    if type(obj) == Data:
        obj = obj.toBufferedImage(False)
    if type(obj) == BufferedImage:
        obj = load_image(obj)
    return obj
|
|
|
|
def calculate_shift(ref,img, roi, g_sigma=3.0, upscale_factor=100):
    """Compute the shift of ``img`` relative to ``ref`` inside ``roi``.

    Both inputs are converted with ``to_ip``, stacked, edge-filtered and
    registered through the Java implementation of ``calculate_shifts``.

    Returns ``(xoff, yoff, error, diffphase)`` taken from the second row of
    the shift table (columns 3, 2, 4 and 5).
    """
    pair = create_stack([to_ip(ref), to_ip(img)])
    ipr, ipi = complex_edge_filtering(pair, g_sigma=g_sigma, show=False)
    shifts = calculate_shifts(ipr, ipi, roi, upscale_factor=upscale_factor, java_code=True)
    drifted_row = shifts[1]
    xoff, yoff = drifted_row[3], drifted_row[2]
    error, diffphase = drifted_row[4], drifted_row[5]
    return xoff, yoff, error, diffphase
|
|
|
|
|
|
# Demo run, executed when the script is loaded: register the nine-image test
# stack inside a fixed region and display the translated result.
roi = Roi(256, 0, 128, 128)
stack = load_test_stack(show=False, size=9)
ipr, ipi = complex_edge_filtering(stack, show=False)
shifts = calculate_shifts(ipr, ipi, roi, java_code=True)
# Alternative kept for reference: reuse stored shifts on a freshly shown stack.
#shifts= load_shifts("{images}/TestObjAligner/shifts.mat")
#stack = load_test_stack(show=True)
r = translate(stack, shifts, show=True)