Files
DKS/src/AutoTuning/DKSSearchStates.cpp
2016-10-10 14:49:32 +02:00

234 lines
5.5 KiB
C++

#include "DKSSearchStates.h"

#include <algorithm>
#include <cmath>
#include <utility>
/** Set the current state so that the number of parameters and the parameter bounds are known.
 *
 * Copies value/min/max/step of every parameter into current_state_m, sizes the
 * best-state buffer to match, resets the best time and neighbour index, and
 * seeds the C RNG (rand()) used by initCurrentState() and getRandomNeighbour().
 *
 * The neighbour buffer starts as a copy of the current state rather than as
 * value-initialised elements. That way the States handed out by
 * getNextNeighbour()/getRandomNeighbour() carry the real min/max/step of each
 * parameter; those functions only overwrite .value.
 *
 * @param params tunable parameters (value and bounds) defining the search space
 */
DKSSearchStates::DKSSearchStates(Parameters params) {
  for (auto&& p : params) {
    State s;
    s.value = p.getValue();
    s.min = p.min;
    s.max = p.max;
    s.step = p.step;
    current_state_m.push_back(s);
  }
  neighbour_state_m = current_state_m;
  best_state_m.resize(current_state_m.size());
  best_time_m = std::numeric_limits<double>::max();
  // -1: no neighbour list has been generated yet (getNeighbours() sets it to 0)
  next_neighbour_m = -1;
  srand(time(NULL));
}
/** Destructor.
 *
 * Nothing needs to be released by hand. The state containers are standard
 * containers (they expose clear()/resize()/push_back()) and free their own
 * storage when the object is destroyed, so clearing them first was redundant.
 */
DKSSearchStates::~DKSSearchStates() {
}
/** Get all the possible neighbours of the current state.
 *
 * For every parameter the candidate values are value - d*step and
 * value + d*step for d = 1..dist (each kept only if inside [min, max]) plus the
 * current value itself. neighbours_m becomes the Cartesian product of these
 * per-parameter candidate lists with the current state removed, and the
 * iteration index next_neighbour_m is reset to 0.
 *
 * @param dist maximum number of steps a single parameter may move
 */
void DKSSearchStates::getNeighbours(int dist) {
  // candidate values for each parameter, in parameter order
  std::vector< std::vector<double> > values;
  values.reserve(current_state_m.size());
  for (const auto& state : current_state_m) {
    std::vector<double> s;
    // Below the current value, farthest first. The pushed value must use the
    // same d*step offset as the bounds check; pushing a single step here
    // produced duplicate neighbours (and missed far ones) for dist > 1.
    for (int d = dist; d > 0; d--) {
      const auto candidate = state.value - d * state.step;
      if (candidate >= state.min)
        s.push_back(candidate);
    }
    s.push_back(state.value);
    // above the current value, nearest first
    for (int d = 1; d <= dist; d++) {
      const auto candidate = state.value + d * state.step;
      if (candidate <= state.max)
        s.push_back(candidate);
    }
    values.push_back(s);
  }

  // Cartesian product of the per-parameter candidate lists
  std::vector< std::vector<double> > combos {{}};
  for (const auto& u : values) {
    std::vector< std::vector<double> > r;
    r.reserve(combos.size() * u.size());
    for (const auto& x : combos) {
      for (double y : u) {
        r.push_back(x);
        r.back().push_back(y);
      }
    }
    combos.swap(r);
  }

  // the current state must not be listed as its own neighbour
  std::vector<double> current;
  current.reserve(current_state_m.size());
  for (const auto& state : current_state_m)
    current.push_back(state.value);
  // Use erase-remove with both iterators. erase() with only the iterator
  // returned by remove() is undefined when nothing matched, and it drops only
  // one copy when several matched (e.g. step == 0 repeats the current value).
  combos.erase(std::remove(combos.begin(), combos.end(), current), combos.end());

  neighbours_m = std::move(combos);
  next_neighbour_m = 0;
}
void DKSSearchStates::setCurrentState(std::vector<Parameter> current_state) {
current_state_m.clear();
for (auto& p : current_state) {
State s;
s.value = p.getValue();
s.min = p.min;
s.max = p.max;
s.step = p.step;
current_state_m.push_back(s);
}
}
void DKSSearchStates::setCurrentState(std::vector<State> current_state) {
current_state_m.clear();
for (auto& p : current_state) {
State s;
s.value = p.value;
s.min = p.min;
s.max = p.max;
s.step = p.step;
current_state_m.push_back(s);
}
}
/** Pick a random starting point on the parameter grid and build its neighbours.
 *
 * For every parameter a value min + r*step is chosen, with r drawn via rand()
 * from the number of grid points in [min, max]. The neighbour list is then
 * regenerated by calling getNeighbours() with its default argument.
 */
void DKSSearchStates::initCurrentState() {
  //go through parameters in current state and generate a new random value
  for (auto& s : current_state_m) {
    // Number of grid points min, min+step, ... that stay <= max. The small
    // tolerance keeps floating-point division from truncating away the top
    // point (e.g. min 0, max 0.3, step 0.1: 0.3 / 0.1 == 2.9999999999999996).
    // A non-positive step or an empty range would make the count zero or
    // negative and rand() % values undefined, so fall back to the single
    // point min.
    int values = 1;
    if (s.step > 0 && s.max >= s.min)
      values = static_cast<int>(std::floor((s.max - s.min) / s.step + 1e-9)) + 1;
    int r = rand() % values;
    s.value = s.min + r * s.step;
  }
  getNeighbours();
}
/** @return a copy of the current state (value and bounds of every parameter) */
States DKSSearchStates::getCurrentState() {
  States snapshot(current_state_m);
  return snapshot;
}
/** Advance to the next neighbour in neighbours_m and return it.
 *
 * If next_neighbour_m indexes a valid entry, its values are copied into
 * neighbour_state_m. Otherwise neighbour_state_m is left unchanged. The index
 * is incremented in both cases, so after a successful call it equals the
 * 1-based position of the returned neighbour.
 *
 * @return the (possibly unchanged) neighbour state
 */
States DKSSearchStates::getNextNeighbour() {
  // Check whether there are any neighbours left to move on to. The lower bound
  // rejects the initial index -1, which would otherwise pass "< size()" when
  // neighbours_m is still empty and read neighbours_m[-1].
  if (next_neighbour_m >= 0 && next_neighbour_m < static_cast<int>(neighbours_m.size())) {
    //values of each parameter in the selected neighbour cell
    const auto& neighbour_values = neighbours_m[next_neighbour_m];
    for (unsigned int n = 0; n < neighbour_state_m.size() && n < neighbour_values.size(); n++)
      neighbour_state_m[n].value = neighbour_values[n];
  }
  next_neighbour_m++;
  return neighbour_state_m;
}
/** Select a pseudo-random neighbour and return it.
 *
 * The index is drawn as rand() % neighbours_m.size(), so it is slightly biased
 * for large lists. The chosen neighbour's values are copied into
 * neighbour_state_m, and next_neighbour_m is set to the 1-based position of
 * that neighbour.
 *
 * If neighbours_m is empty (getNeighbours() not called yet, or e.g. every
 * parameter has min == max so no neighbour exists), the current state's
 * values are copied instead and next_neighbour_m is left untouched. This
 * guard avoids evaluating rand() % 0 (division by zero).
 *
 * @return the selected neighbour state
 */
States DKSSearchStates::getRandomNeighbour() {
  if (neighbours_m.empty()) {
    for (unsigned int n = 0; n < neighbour_state_m.size() && n < current_state_m.size(); n++)
      neighbour_state_m[n].value = current_state_m[n].value;
    return neighbour_state_m;
  }
  int rand_neighbour = rand() % static_cast<int>(neighbours_m.size());
  //values of each parameter in the selected neighbour cell
  const auto& neighbour_values = neighbours_m[rand_neighbour];
  for (unsigned int n = 0; n < neighbour_state_m.size() && n < neighbour_values.size(); n++)
    neighbour_state_m[n].value = neighbour_values[n];
  next_neighbour_m = rand_neighbour + 1;
  return neighbour_state_m;
}
bool DKSSearchStates::nextNeighbourExists() {
bool neighbourExists = false;
if (next_neighbour_m < (int)neighbours_m.size())
neighbourExists = true;
return neighbourExists;
}
/** Make the last selected neighbour the current state.
 *
 * Only parameter values are copied; the bounds of current_state_m are kept.
 */
void DKSSearchStates::moveToNeighbour() {
  auto count = current_state_m.size();
  for (decltype(count) idx = 0; idx < count; ++idx) {
    current_state_m[idx].value = neighbour_state_m[idx].value;
  }
  // The neighbour list is not regenerated here (the call was left disabled):
  //getNeighbours();
}
/** Record the current state as the best one if it ran strictly faster.
 *
 * @param current_time measured time of the current state; it replaces the
 *        stored best only if it is strictly smaller than best_time_m
 */
void DKSSearchStates::saveCurrentState(double current_time) {
  if (!(current_time < best_time_m))
    return;

  for (unsigned int idx = 0; idx < current_state_m.size(); ++idx) {
    const auto& src = current_state_m[idx];
    auto& dst = best_state_m[idx];
    dst.value = src.value;
    dst.min = src.min;
    dst.max = src.max;
    dst.step = src.step;
  }
  best_time_m = current_time;
}
/** Print the current parameter values (tab separated) followed by a time.
 *
 * @param time value printed at the end of the line
 */
void DKSSearchStates::printCurrentState(double time) {
  std::ostream& out = std::cout;
  out << "Current state: ";
  for (const auto& entry : current_state_m) {
    out << entry.value << "\t";
  }
  out << time << std::endl;
}
/** Print the current state and the currently selected neighbour.
 *
 * The neighbour's values are printed only when next_neighbour_m > 0, i.e.
 * after getNextNeighbour()/getRandomNeighbour() has advanced the index.
 */
void DKSSearchStates::printInfo() {
  std::ostream& out = std::cout;

  out << "Current state: ";
  for (const auto& entry : current_state_m) {
    out << entry.value << "\t";
  }
  out << std::endl;

  out << "Current neighbour (" << next_neighbour_m << " of " << neighbours_m.size() << "): ";
  const bool neighbourSelected = next_neighbour_m > 0;
  if (neighbourSelected) {
    for (const auto& entry : neighbour_state_m) {
      out << entry.value << "\t";
    }
  }
  out << std::endl;
}
/** Print the neighbour index/count, the neighbour's values and a time.
 *
 * The values are printed only when next_neighbour_m > 0.
 *
 * @param time value printed at the end of the line
 */
void DKSSearchStates::printNeighbour(double time) {
  std::ostream& out = std::cout;
  out << "Current neighbour (" << next_neighbour_m << " of " << neighbours_m.size() << "): ";
  const bool neighbourSelected = next_neighbour_m > 0;
  if (neighbourSelected) {
    for (const auto& entry : neighbour_state_m) {
      out << entry.value << "\t";
    }
  }
  out << time << std::endl;
}
/** Print the best recorded time and, if it is positive, the best state's values.
 *
 * NOTE(review): best_time_m starts at std::numeric_limits<double>::max(), so
 * before any saveCurrentState() this prints the default-initialised
 * best_state_m values. Confirm whether the guard was meant to test
 * "a best state has been recorded" instead of "> 0".
 */
void DKSSearchStates::printBest() {
  std::ostream& out = std::cout;
  out << "Best state (" << best_time_m << "): ";
  const bool showValues = best_time_m > 0;
  if (showValues) {
    for (const auto& entry : best_state_m) {
      out << entry.value << "\t";
    }
  }
  out << std::endl;
}