snapshot of svn
This commit is contained in:
233
src/AutoTuning/DKSSearchStates.cpp
Normal file
233
src/AutoTuning/DKSSearchStates.cpp
Normal file
@@ -0,0 +1,233 @@
|
||||
#include "DKSSearchStates.h"
|
||||
|
||||
/** set the current state so that number of parameters and parameter bounds are known */
|
||||
DKSSearchStates::DKSSearchStates(Parameters params) {
|
||||
|
||||
for (auto p : params) {
|
||||
State s;
|
||||
s.value = p.getValue();
|
||||
s.min = p.min;
|
||||
s.max = p.max;
|
||||
s.step = p.step;
|
||||
current_state_m.push_back(s);
|
||||
}
|
||||
|
||||
neighbour_state_m.resize(current_state_m.size());
|
||||
best_state_m.resize(current_state_m.size());
|
||||
|
||||
best_time_m = std::numeric_limits<double>::max();
|
||||
|
||||
next_neighbour_m = -1;
|
||||
|
||||
srand(time(NULL));
|
||||
|
||||
}
|
||||
|
||||
DKSSearchStates::~DKSSearchStates() {
|
||||
current_state_m.clear();
|
||||
neighbour_state_m.clear();
|
||||
best_state_m.clear();
|
||||
neighbours_m.clear();
|
||||
}
|
||||
|
||||
/** Get all the possible neighbours of the current state */
|
||||
void DKSSearchStates::getNeighbours(int dist) {
|
||||
|
||||
std::vector< std::vector<double> > values;
|
||||
|
||||
for (auto state : current_state_m) {
|
||||
std::vector<double> s;
|
||||
|
||||
for (int d = dist; d > 0; d--) {
|
||||
if (state.value - d*state.step >= state.min)
|
||||
s.push_back(state.value - state.step);
|
||||
}
|
||||
|
||||
s.push_back(state.value);
|
||||
|
||||
for (int d = 1; d < dist + 1; d++) {
|
||||
if (state.value + d*state.step <= state.max)
|
||||
s.push_back(state.value + state.step);
|
||||
}
|
||||
|
||||
values.push_back(s);
|
||||
}
|
||||
|
||||
|
||||
std::vector< std::vector<double> > s {{}};
|
||||
for (auto& u : values) {
|
||||
std::vector< std::vector<double> > r;
|
||||
for(auto& x : s) {
|
||||
for( auto y : u) {
|
||||
r.push_back(x);
|
||||
r.back().push_back(y);
|
||||
}
|
||||
}
|
||||
s.swap(r);
|
||||
}
|
||||
|
||||
//get current state values
|
||||
std::vector<double> current;
|
||||
for (auto state : current_state_m)
|
||||
current.push_back(state.value);
|
||||
s.erase(std::remove(s.begin(), s.end(), current));
|
||||
|
||||
neighbours_m.clear();
|
||||
neighbours_m = s;
|
||||
next_neighbour_m = 0;
|
||||
}
|
||||
|
||||
void DKSSearchStates::setCurrentState(std::vector<Parameter> current_state) {
|
||||
|
||||
current_state_m.clear();
|
||||
for (auto& p : current_state) {
|
||||
State s;
|
||||
s.value = p.getValue();
|
||||
s.min = p.min;
|
||||
s.max = p.max;
|
||||
s.step = p.step;
|
||||
current_state_m.push_back(s);
|
||||
}
|
||||
}
|
||||
|
||||
void DKSSearchStates::setCurrentState(std::vector<State> current_state) {
|
||||
|
||||
current_state_m.clear();
|
||||
for (auto& p : current_state) {
|
||||
State s;
|
||||
s.value = p.value;
|
||||
s.min = p.min;
|
||||
s.max = p.max;
|
||||
s.step = p.step;
|
||||
current_state_m.push_back(s);
|
||||
}
|
||||
}
|
||||
|
||||
void DKSSearchStates::initCurrentState() {
|
||||
|
||||
//go trough parameters in current state and generate a new random value
|
||||
for (auto& s : current_state_m) {
|
||||
//get number of total values
|
||||
int values = (s.max - s.min) / s.step + 1;
|
||||
|
||||
int r = rand() % values;
|
||||
|
||||
s.value = s.min + r * s.step;
|
||||
}
|
||||
|
||||
getNeighbours();
|
||||
}
|
||||
|
||||
States DKSSearchStates::getCurrentState() {
|
||||
return current_state_m;
|
||||
}
|
||||
|
||||
States DKSSearchStates::getNextNeighbour() {
|
||||
|
||||
//check if there are ant neighbours to move on
|
||||
if (next_neighbour_m < (int)neighbours_m.size()) {
|
||||
|
||||
//get the vector of values for each parameters in the neighbour cell
|
||||
std::vector<double> neighbour_values = neighbours_m[next_neighbour_m];
|
||||
|
||||
//set the values to neighbour_state_m
|
||||
for (unsigned int n = 0; n < neighbour_state_m.size(); n++)
|
||||
neighbour_state_m[n].value = neighbour_values[n];
|
||||
|
||||
}
|
||||
|
||||
next_neighbour_m++;
|
||||
return neighbour_state_m;
|
||||
|
||||
}
|
||||
|
||||
States DKSSearchStates::getRandomNeighbour() {
|
||||
|
||||
int rand_neighbour = rand() % (int)neighbours_m.size();
|
||||
|
||||
//get the vector of values for each parameters in the neighbour cell
|
||||
std::vector<double> neighbour_values = neighbours_m[rand_neighbour];
|
||||
|
||||
//set the values to neighbour_state_m
|
||||
for (unsigned int n = 0; n < neighbour_state_m.size(); n++)
|
||||
neighbour_state_m[n].value = neighbour_values[n];
|
||||
|
||||
next_neighbour_m = rand_neighbour + 1;
|
||||
return neighbour_state_m;
|
||||
|
||||
}
|
||||
|
||||
bool DKSSearchStates::nextNeighbourExists() {
|
||||
bool neighbourExists = false;
|
||||
if (next_neighbour_m < (int)neighbours_m.size())
|
||||
neighbourExists = true;
|
||||
|
||||
return neighbourExists;
|
||||
}
|
||||
|
||||
void DKSSearchStates::moveToNeighbour() {
|
||||
|
||||
for (unsigned int i = 0; i < current_state_m.size(); i++)
|
||||
current_state_m[i].value = neighbour_state_m[i].value;
|
||||
|
||||
//getNeighbours();
|
||||
|
||||
}
|
||||
|
||||
void DKSSearchStates::saveCurrentState(double current_time) {
|
||||
|
||||
if (current_time < best_time_m) {
|
||||
for (unsigned int i = 0; i < current_state_m.size(); i++) {
|
||||
best_state_m[i].value = current_state_m[i].value;
|
||||
best_state_m[i].min = current_state_m[i].min;
|
||||
best_state_m[i].max = current_state_m[i].max;
|
||||
best_state_m[i].step = current_state_m[i].step;
|
||||
}
|
||||
|
||||
best_time_m = current_time;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
void DKSSearchStates::printCurrentState(double time) {
|
||||
std::cout << "Current state: ";
|
||||
for (auto s : current_state_m)
|
||||
std::cout << s.value << "\t";
|
||||
std::cout << time << std::endl;
|
||||
|
||||
}
|
||||
|
||||
void DKSSearchStates::printInfo() {
|
||||
|
||||
std::cout << "Current state: ";
|
||||
for (auto s : current_state_m)
|
||||
std::cout << s.value << "\t";
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "Current neighbour (" << next_neighbour_m << " of " << neighbours_m.size() << "): ";
|
||||
if (next_neighbour_m > 0) {
|
||||
for (auto s : neighbour_state_m)
|
||||
std::cout << s.value << "\t";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
|
||||
}
|
||||
|
||||
void DKSSearchStates::printNeighbour(double time) {
|
||||
std::cout << "Current neighbour (" << next_neighbour_m << " of " << neighbours_m.size() << "): ";
|
||||
if (next_neighbour_m > 0) {
|
||||
for (auto s : neighbour_state_m)
|
||||
std::cout << s.value << "\t";
|
||||
}
|
||||
std::cout << time << std::endl;
|
||||
}
|
||||
|
||||
void DKSSearchStates::printBest() {
|
||||
std::cout << "Best state (" << best_time_m << "): ";
|
||||
if (best_time_m > 0) {
|
||||
for (auto s : best_state_m)
|
||||
std::cout << s.value << "\t";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
Reference in New Issue
Block a user