From e93909cf7e63fa135c136df4cfa0c2d818045c95 Mon Sep 17 00:00:00 2001 From: Michael Davidsaver Date: Sat, 11 Feb 2023 12:25:44 -0800 Subject: [PATCH] fix shared_array::convertTo() --- src/sharedarray.cpp | 69 +++++++++++++-------------- test/testshared.cpp | 111 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 145 insertions(+), 35 deletions(-) diff --git a/src/sharedarray.cpp b/src/sharedarray.cpp index 76c1453..6e591c5 100644 --- a/src/sharedarray.cpp +++ b/src/sharedarray.cpp @@ -263,13 +263,14 @@ void convertArr(ArrayType dtype, void *dbase, case ArrayType::Bool: convertCast(sbase, dbase, count); return; case ArrayType::Int8: case ArrayType::UInt8: memcpy(dbase, sbase, count*sizeof(int8_t)); return; - // cast sint -> *int always sign extends - case ArrayType::Int16: - case ArrayType::UInt16: convertCast(sbase, dbase, count); return; - case ArrayType::Int32: - case ArrayType::UInt32: convertCast(sbase, dbase, count); return; - case ArrayType::Int64: - case ArrayType::UInt64: convertCast(sbase, dbase, count); return; + // cast sint -> sint extends sign + case ArrayType::Int16: convertCast(sbase, dbase, count); return; + // cast sint -> uint does not extend sign + case ArrayType::UInt16: convertCast(sbase, dbase, count); return; + case ArrayType::Int32: convertCast(sbase, dbase, count); return; + case ArrayType::UInt32: convertCast(sbase, dbase, count); return; + case ArrayType::Int64: convertCast(sbase, dbase, count); return; + case ArrayType::UInt64: convertCast(sbase, dbase, count); return; case ArrayType::Float32:convertCast(sbase, dbase, count); return; case ArrayType::Float64:convertCast(sbase, dbase, count); return; case ArrayType::String: convertToStr(sbase, dbase, count); return; @@ -281,13 +282,13 @@ void convertArr(ArrayType dtype, void *dbase, switch(dtype) { case ArrayType::Bool: convertCast(sbase, dbase, count); return; case ArrayType::Int8: - case ArrayType::UInt8: convertCast(sbase, dbase, count); return; + case ArrayType::UInt8: convertCast(sbase, dbase, count); return; case ArrayType::Int16: case ArrayType::UInt16: memcpy(dbase, sbase, count*sizeof(int16_t)); return; - case ArrayType::Int32: - case ArrayType::UInt32: convertCast(sbase, dbase, count); return; - case ArrayType::Int64: - case ArrayType::UInt64: convertCast(sbase, dbase, count); return; + case ArrayType::Int32: convertCast(sbase, dbase, count); return; + case ArrayType::UInt32: convertCast(sbase, dbase, count); return; + case ArrayType::Int64: convertCast(sbase, dbase, count); return; + case ArrayType::UInt64: convertCast(sbase, dbase, count); return; case ArrayType::Float32:convertCast(sbase, dbase, count); return; case ArrayType::Float64:convertCast(sbase, dbase, count); return; case ArrayType::String: convertToStr(sbase, dbase, count); return; @@ -299,13 +300,13 @@ void convertArr(ArrayType dtype, void *dbase, switch(dtype) { case ArrayType::Bool: convertCast(sbase, dbase, count); return; case ArrayType::Int8: - case ArrayType::UInt8: convertCast(sbase, dbase, count); return; + case ArrayType::UInt8: convertCast(sbase, dbase, count); return; case ArrayType::Int16: - case ArrayType::UInt16: convertCast(sbase, dbase, count); return; + case ArrayType::UInt16: convertCast(sbase, dbase, count); return; case ArrayType::Int32: case ArrayType::UInt32: memcpy(dbase, sbase, count*sizeof(int32_t)); return; - case ArrayType::Int64: - case ArrayType::UInt64: convertCast(sbase, dbase, count); return; + case ArrayType::Int64: convertCast(sbase, dbase, count); return; + case ArrayType::UInt64: convertCast(sbase, dbase, count); return; case ArrayType::Float32:convertCast(sbase, dbase, count); return; case ArrayType::Float64:convertCast(sbase, dbase, count); return; case ArrayType::String: convertToStr(sbase, dbase, count); return; @@ -407,14 +408,14 @@ void convertArr(ArrayType dtype, void *dbase, case ArrayType::Float32: switch(dtype) { case ArrayType::Bool: convertCast(sbase, dbase, count); return; - case ArrayType::Int8: - case ArrayType::UInt8: convertCast(sbase, dbase, count); return; - case ArrayType::Int16: - case ArrayType::UInt16: convertCast(sbase, dbase, count); return; - case ArrayType::Int32: - case ArrayType::UInt32: convertCast(sbase, dbase, count); return; - case ArrayType::Int64: - case ArrayType::UInt64: convertCast(sbase, dbase, count); return; + case ArrayType::Int8: convertCast(sbase, dbase, count); return; + case ArrayType::UInt8: convertCast(sbase, dbase, count); return; + case ArrayType::Int16: convertCast(sbase, dbase, count); return; + case ArrayType::UInt16: convertCast(sbase, dbase, count); return; + case ArrayType::Int32: convertCast(sbase, dbase, count); return; + case ArrayType::UInt32: convertCast(sbase, dbase, count); return; + case ArrayType::Int64: convertCast(sbase, dbase, count); return; + case ArrayType::UInt64: convertCast(sbase, dbase, count); return; case ArrayType::Float32:memcpy(dbase, sbase, count*sizeof(float)); return; case ArrayType::Float64:convertCast(sbase, dbase, count); return; case ArrayType::String: convertToStr(sbase, dbase, count); return; @@ -425,16 +426,16 @@ void convertArr(ArrayType dtype, void *dbase, case ArrayType::Float64: switch(dtype) { case ArrayType::Bool: convertCast(sbase, dbase, count); return; - case ArrayType::Int8: - case ArrayType::UInt8: convertCast(sbase, dbase, count); return; - case ArrayType::Int16: - case ArrayType::UInt16: convertCast(sbase, dbase, count); return; - case ArrayType::Int32: - case ArrayType::UInt32: convertCast(sbase, dbase, count); return; - case ArrayType::Int64: - case ArrayType::UInt64: convertCast(sbase, dbase, count); return; - case ArrayType::Float32:memcpy(dbase, sbase, count*sizeof(double)); return; - case ArrayType::Float64:convertCast(sbase, dbase, count); return; + case ArrayType::Int8: convertCast(sbase, dbase, count); return; + case ArrayType::UInt8: convertCast(sbase, dbase, count); return; + case ArrayType::Int16: convertCast(sbase, dbase, count); return; + case ArrayType::UInt16: convertCast(sbase, dbase, count); return; + case ArrayType::Int32: convertCast(sbase, dbase, count); return; + case ArrayType::UInt32: convertCast(sbase, dbase, count); return; + case ArrayType::Int64: convertCast(sbase, dbase, count); return; + case ArrayType::UInt64: convertCast(sbase, dbase, count); return; + case ArrayType::Float32:convertCast(sbase, dbase, count); return; + case ArrayType::Float64:memcpy(dbase, sbase, count*sizeof(double)); return; case ArrayType::String: convertToStr(sbase, dbase, count); return; case ArrayType::Value: case ArrayType::Null: break; // no convert diff --git a/test/testshared.cpp b/test/testshared.cpp index 7127425..5591863 100644 --- a/test/testshared.cpp +++ b/test/testshared.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -324,12 +325,120 @@ void testElemAlloc() testEq(varr.original_type(), ArrayType::UInt32); } +// round trip conversion when TO can exactly represent all possible values of FROM +template +void testConvertExact() +{ + shared_array inp({ + FROM(0), + FROM(1), + FROM(-1), + std::numeric_limits::min(), + std::numeric_limits::max(), + }); + shared_array expect({ + (TO)FROM(0), + (TO)FROM(1), + (TO)FROM(-1), + (TO)std::numeric_limits::min(), + (TO)std::numeric_limits::max(), + }); + auto conv(inp.template convertTo()); + testShow()<<"Input "< "<(), inp)<<" "<<__func__<<"("< +void testConvertTrunc() +{ + shared_array inp({ + FROM(0), + FROM(1), + FROM(-1), + std::numeric_limits::min(), + std::numeric_limits::max(), + }); + shared_array expect({ + (TO)FROM(0), + (TO)FROM(1), + (TO)FROM(-1), + std::numeric_limits::min(), + std::numeric_limits::max(), + }); + auto conv(inp.template convertTo()); + testShow()<<"Input "< "<::code!=detail::CaptureCode::code, ""); + testDiag("reversible conversions"); + testConvertExact(); + testConvertExact(); + testConvertExact(); + testConvertExact(); + testConvertExact(); + testConvertExact(); + testConvertExact(); + testConvertExact(); + testConvertExact(); + + testConvertExact(); + testConvertExact(); + testConvertExact(); + testConvertExact(); + testConvertExact(); + testConvertExact(); + testConvertExact(); + testConvertExact(); + testConvertExact(); + + testConvertExact(); + testConvertExact(); + testConvertExact(); + testConvertExact(); + testConvertExact(); + testConvertExact(); + testConvertExact(); + + testConvertExact(); + testConvertExact(); + testConvertExact(); + testConvertExact(); + testConvertExact(); + testConvertExact(); + testConvertExact(); + + testConvertExact(); + testConvertExact(); + testConvertExact(); + testConvertExact(); + + testConvertExact(); + testConvertExact(); + testConvertExact(); + testConvertExact(); + + testConvertExact(); + + testConvertExact(); + + testConvertExact(); + + testDiag("integer truncation"); + testConvertTrunc(); + testConvertTrunc(); + testConvertTrunc(); + testConvertTrunc(); + testConvertTrunc(); + testConvertTrunc(); + testArrEq(shared_array({1u, 2u, 0xffffffffu}).convertTo(), shared_array({1u, 2u, 0xffffffffu})); @@ -356,7 +465,7 @@ void testConvert() MAIN(testshared) { - testPlan(155); + testPlan(247); testSetup(); testEmpty(); testEmpty();