From 58c18d9b8ca62af99e9d20dfe3644d796b5bc3d8 Mon Sep 17 00:00:00 2001 From: Sven Augustin Date: Mon, 5 Aug 2024 18:04:26 +0200 Subject: [PATCH] added algos/utils folder and npmemo --- dap/algos/utils/__init__.py | 4 ++ dap/algos/utils/npmemo.py | 77 +++++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+) create mode 100644 dap/algos/utils/__init__.py create mode 100644 dap/algos/utils/npmemo.py diff --git a/dap/algos/utils/__init__.py b/dap/algos/utils/__init__.py new file mode 100644 index 0000000..9efc8a4 --- /dev/null +++ b/dap/algos/utils/__init__.py @@ -0,0 +1,4 @@ + +from .npmemo import npmemo + + diff --git a/dap/algos/utils/npmemo.py b/dap/algos/utils/npmemo.py new file mode 100644 index 0000000..5ea6670 --- /dev/null +++ b/dap/algos/utils/npmemo.py @@ -0,0 +1,77 @@ +import functools +#import hashlib +import numpy as np + + +def npmemo(func): + """ + numpy array aware memoizer + """ + cache = {} + + @functools.wraps(func) + def wrapper(*args): + key = make_key(args) + try: + return cache[key] + except KeyError: + cache[key] = res = func(*args) + return res + +# wrapper.cache = cache + return wrapper + + +def make_key(args): + return tuple(make_key_entry(i) for i in args) + +def make_key_entry(x): + if isinstance(x, np.ndarray): + return np_array_hash(x) + return x + +def np_array_hash(arr): +# return id(arr) # this has been used so far + res = arr.tobytes() +# res = hashlib.sha256(res).hexdigest() # if tobytes was too large, we could hash it +# res = (arr.shape, res) # tobytes does not take shape into account + return res + + + + + +if __name__ == "__main__": + @npmemo + def expensive(arr, offset): + print("recalc", arr, offset) + return np.dot(arr, arr) + offset + + def test(arr, offset): + print("first") + res1 = expensive(arr, offset) + print("second") + res2 = expensive(arr, offset) + print() + assert np.array_equal(res1, res2) + + arrays = ( + [1, 2, 3], + [1, 2, 3, 4], + [1, 2, 3, 4] + ) + + offsets = ( + 2, + 2, + 5 + ) + + for a, o in zip(arrays, offsets): + a = np.array(a) + test(a, o) + +# print(expensive.cache) + + +