fixed number of parameters

This commit is contained in:
froejdh_e 2025-04-25 12:00:29 +02:00
parent 78406dc881
commit a871a7e7f6
3 changed files with 13 additions and 13 deletions

View File

@ -1 +1 @@
from ._aare import gaus, pol1 from ._aare import gaus, pol1, scurve, scurve2

View File

@ -321,10 +321,10 @@ n_threads : int, optional
py::array_t<double, py::array::c_style | py::array::forcecast> y_err, py::array_t<double, py::array::c_style | py::array::forcecast> y_err,
int n_threads) { int n_threads) {
if (y.ndim() == 3) { if (y.ndim() == 3) {
auto par = new NDArray<double, 3>({y.shape(0), y.shape(1), 2}); auto par = new NDArray<double, 3>({y.shape(0), y.shape(1), 6});
auto par_err = auto par_err =
new NDArray<double, 3>({y.shape(0), y.shape(1), 2}); new NDArray<double, 3>({y.shape(0), y.shape(1), 6});
auto y_view = make_view_3d(y); auto y_view = make_view_3d(y);
auto y_view_err = make_view_3d(y_err); auto y_view_err = make_view_3d(y_err);
@ -408,10 +408,10 @@ n_threads : int, optional
py::array_t<double, py::array::c_style | py::array::forcecast> y_err, py::array_t<double, py::array::c_style | py::array::forcecast> y_err,
int n_threads) { int n_threads) {
if (y.ndim() == 3) { if (y.ndim() == 3) {
auto par = new NDArray<double, 3>({y.shape(0), y.shape(1), 2}); auto par = new NDArray<double, 3>({y.shape(0), y.shape(1), 6});
auto par_err = auto par_err =
new NDArray<double, 3>({y.shape(0), y.shape(1), 2}); new NDArray<double, 3>({y.shape(0), y.shape(1), 6});
auto y_view = make_view_3d(y); auto y_view = make_view_3d(y);
auto y_view_err = make_view_3d(y_err); auto y_view_err = make_view_3d(y_err);
@ -428,8 +428,8 @@ n_threads : int, optional
} else if (y.ndim() == 1) { } else if (y.ndim() == 1) {
auto par = new NDArray<double, 1>({2}); auto par = new NDArray<double, 1>({6});
auto par_err = new NDArray<double, 1>({2}); auto par_err = new NDArray<double, 1>({6});
auto y_view = make_view_1d(y); auto y_view = make_view_1d(y);
auto y_view_err = make_view_1d(y_err); auto y_view_err = make_view_1d(y_err);

View File

@ -40,7 +40,7 @@ double scurve(const double x, const double * par) {
NDArray<double, 1> scurve(NDView<double, 1> x, NDView<double, 1> par) { NDArray<double, 1> scurve(NDView<double, 1> x, NDView<double, 1> par) {
NDArray<double, 1> y({x.shape()}, 0); NDArray<double, 1> y({x.shape()}, 0);
for (size_t i = 0; i < x.size(); i++) { for (ssize_t i = 0; i < x.size(); i++) {
y(i) = scurve(x(i), par.data()); y(i) = scurve(x(i), par.data());
} }
return y; return y;
@ -52,7 +52,7 @@ double scurve2(const double x, const double * par) {
NDArray<double, 1> scurve2(NDView<double, 1> x, NDView<double, 1> par) { NDArray<double, 1> scurve2(NDView<double, 1> x, NDView<double, 1> par) {
NDArray<double, 1> y({x.shape()}, 0); NDArray<double, 1> y({x.shape()}, 0);
for (size_t i = 0; i < x.size(); i++) { for (ssize_t i = 0; i < x.size(); i++) {
y(i) = scurve2(x(i), par.data()); y(i) = scurve2(x(i), par.data());
} }
return y; return y;
@ -309,7 +309,7 @@ std::array<double, 6> scurve_init_par(const NDView<double, 1> x, const NDView<do
start_par[4] = *ymin + (*ymax - *ymin) / 2; start_par[4] = *ymin + (*ymax - *ymin) / 2;
// Find the first x where the corresponding y value is above the threshold (start_par[4]) // Find the first x where the corresponding y value is above the threshold (start_par[4])
for (size_t i = 0; i < y.size(); ++i) { for (ssize_t i = 0; i < y.size(); ++i) {
if (y[i] >= start_par[4]) { if (y[i] >= start_par[4]) {
start_par[2] = x[i]; start_par[2] = x[i];
break; // Exit the loop after finding the first valid x break; // Exit the loop after finding the first valid x
@ -379,7 +379,7 @@ void fit_scurve(NDView<double, 1> x, NDView<double, 1> y, NDView<double, 1> y_er
// Calculate chi2 // Calculate chi2
chi2 = 0; chi2 = 0;
for (size_t i = 0; i < y.size(); i++) { for (ssize_t i = 0; i < y.size(); i++) {
chi2 += std::pow((y(i) - func::pol1(x(i), par_out.data())) / y_err(i), 2); chi2 += std::pow((y(i) - func::pol1(x(i), par_out.data())) / y_err(i), 2);
} }
} }
@ -421,7 +421,7 @@ std::array<double, 6> scurve2_init_par(const NDView<double, 1> x, const NDView<d
start_par[4] = *ymin + (*ymax - *ymin) / 2; start_par[4] = *ymin + (*ymax - *ymin) / 2;
// Find the first x where the corresponding y value is above the threshold (start_par[4]) // Find the first x where the corresponding y value is above the threshold (start_par[4])
for (size_t i = 0; i < y.size(); ++i) { for (ssize_t i = 0; i < y.size(); ++i) {
if (y[i] <= start_par[4]) { if (y[i] <= start_par[4]) {
start_par[2] = x[i]; start_par[2] = x[i];
break; // Exit the loop after finding the first valid x break; // Exit the loop after finding the first valid x
@ -491,7 +491,7 @@ void fit_scurve2(NDView<double, 1> x, NDView<double, 1> y, NDView<double, 1> y_e
// Calculate chi2 // Calculate chi2
chi2 = 0; chi2 = 0;
for (size_t i = 0; i < y.size(); i++) { for (ssize_t i = 0; i < y.size(); i++) {
chi2 += std::pow((y(i) - func::pol1(x(i), par_out.data())) / y_err(i), 2); chi2 += std::pow((y(i) - func::pol1(x(i), par_out.data())) / y_err(i), 2);
} }
} }