diff --git a/pyzebra/anatric.py b/pyzebra/anatric.py index 60c684d..ddd4e34 100644 --- a/pyzebra/anatric.py +++ b/pyzebra/anatric.py @@ -28,10 +28,16 @@ class AnatricConfig: self.load_from_file(filename) def load_from_file(self, filename): - tree = ET.parse(filename) - self._tree = tree + self._tree = ET.parse(filename) + self._root = self._tree.getroot() - alg_elem = tree.find("Algorithm") + self._alg_elems = dict() + for alg in ALGORITHMS: + self._alg_elems[alg] = ET.Element("Algorithm", attrib={"implementation": alg}) + + self._alg_elems[self.algorithm] = self._tree.find("Algorithm") + + alg_elem = self._tree.find("Algorithm") if self.algorithm == "adaptivemaxcog": self.threshold = float(alg_elem.find("threshold").attrib["value"]) self.shell = float(alg_elem.find("shell").attrib["value"]) @@ -213,7 +219,8 @@ class AnatricConfig: if value not in ALGORITHMS: raise ValueError("Unknown algorithm.") - self._tree.find("Algorithm").attrib["implementation"] = value + self._root.remove(self._tree.find("Algorithm")) + self._root.append(self._alg_elems[value]) def save_as(self, filename): self._tree.write(filename)