fix bug with CumSum and fix style

This commit is contained in:
VadimPE 2018-09-03 13:52:44 +03:00
parent f8f71f7327
commit 3b38879b11
4 changed files with 36 additions and 89 deletions

View File

@ -17,7 +17,7 @@ void registerFunctionsHigherOrder(FunctionFactory & factory)
factory.registerFunction<FunctionArraySort>();
factory.registerFunction<FunctionArrayReverseSort>();
factory.registerFunction<FunctionArrayCumSum>();
factory.registerFunction<FunctionArrayCumSumLimited>();
factory.registerFunction<FunctionArrayCumSumNonNegative>();
factory.registerFunction<FunctionArrayDifference>();
}

View File

@ -28,7 +28,7 @@ namespace ErrorCodes
* arrayCount(x1,...,xn -> expression, array1,...,arrayn) - for how many elements of the array the expression is true.
* arrayExists(x1,...,xn -> expression, array1,...,arrayn) - is the expression true for at least one array element.
* arrayAll(x1,...,xn -> expression, array1,...,arrayn) - is the expression true for all elements of the array.
* arrayCumSumLimited() - returns an array with cumulative sums of the original. (If value < 0 -> 0).
* arrayCumSumNonNegative() - returns an array with cumulative sums of the original. (If value < 0 -> 0).
* arrayDifference() - returns an array with the difference between all pairs of neighboring elements.
*
* For functions arrayCount, arrayExists, arrayAll, an overload of the form f(array) is available, which works in the same way as f(x -> x, array).
@ -709,6 +709,7 @@ struct ArrayCumSumImpl
struct ArrayDifferenceImpl
{
static bool useDefaultImplementationForConstants() { return true; }
static bool needBoolean() { return false; }
static bool needExpression() { return false; }
static bool needOneArray() { return false; }
@ -716,11 +717,15 @@ struct ArrayDifferenceImpl
static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypePtr & /*array_element*/)
{
if (checkDataType<DataTypeUInt8>(&*expression_return) ||
checkDataType<DataTypeUInt16>(&*expression_return) ||
checkDataType<DataTypeUInt32>(&*expression_return) ||
checkDataType<DataTypeInt8>(&*expression_return))
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeInt16>());
if (checkDataType<DataTypeUInt16>(&*expression_return) ||
checkDataType<DataTypeInt16>(&*expression_return))
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeInt32>());
if (checkDataType<DataTypeUInt32>(&*expression_return) ||
checkDataType<DataTypeUInt64>(&*expression_return) ||
checkDataType<DataTypeInt8>(&*expression_return) ||
checkDataType<DataTypeInt16>(&*expression_return) ||
checkDataType<DataTypeInt32>(&*expression_return) ||
checkDataType<DataTypeInt64>(&*expression_return))
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeInt64>());
@ -739,35 +744,7 @@ struct ArrayDifferenceImpl
const ColumnVector<Element> * column = checkAndGetColumn<ColumnVector<Element>>(&*mapped);
if (!column)
{
const ColumnConst * column_const = checkAndGetColumnConst<ColumnVector<Element>>(&*mapped);
if (!column_const)
return false;
const IColumn::Offsets & offsets = array.getOffsets();
auto res_nested = ColumnVector<Result>::create();
typename ColumnVector<Result>::Container & res_values = res_nested->getData();
res_values.resize(column_const->size());
size_t pos = 0;
for (size_t i = 0; i < offsets.size(); ++i)
{
// skip empty arrays
if (pos < offsets[i])
{
res_values[pos++] = 0;
for (; pos < offsets[i]; ++pos)
{
res_values[pos] = 0;
}
}
}
res_ptr = ColumnArray::create(std::move(res_nested), array.getOffsetsPtr());
return true;
}
return false;
const IColumn::Offsets & offsets = array.getOffsets();
const typename ColumnVector<Element>::Container & data = column->getData();
@ -798,12 +775,12 @@ struct ArrayDifferenceImpl
{
ColumnPtr res;
if (executeType< UInt8 , Int64>(mapped, array, res) ||
executeType< UInt16, Int64>(mapped, array, res) ||
if (executeType< UInt8 , Int16>(mapped, array, res) ||
executeType< UInt16, Int32>(mapped, array, res) ||
executeType< UInt32, Int64>(mapped, array, res) ||
executeType< UInt64, Int64>(mapped, array, res) ||
executeType< Int8 , Int64>(mapped, array, res) ||
executeType< Int16, Int64>(mapped, array, res) ||
executeType< Int8 , Int16>(mapped, array, res) ||
executeType< Int16, Int32>(mapped, array, res) ||
executeType< Int32, Int64>(mapped, array, res) ||
executeType< Int64, Int64>(mapped, array, res) ||
executeType<Float32,Float64>(mapped, array, res) ||
@ -816,8 +793,9 @@ struct ArrayDifferenceImpl
};
struct ArrayCumSumLimitedImpl
struct ArrayCumSumNonNegativeImpl
{
static bool useDefaultImplementationForConstants() { return true; }
static bool needBoolean() { return false; }
static bool needExpression() { return false; }
static bool needOneArray() { return false; }
@ -840,7 +818,7 @@ struct ArrayCumSumLimitedImpl
checkDataType<DataTypeFloat64>(&*expression_return))
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeFloat64>());
throw Exception("arrayCumSumLimited cannot add values of type " + expression_return->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
throw Exception("arrayCumSumNonNegativeImpl cannot add values of type " + expression_return->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
@ -850,39 +828,7 @@ struct ArrayCumSumLimitedImpl
const ColumnVector<Element> * column = checkAndGetColumn<ColumnVector<Element>>(&*mapped);
if (!column)
{
const ColumnConst * column_const = checkAndGetColumnConst<ColumnVector<Element>>(&*mapped);
if (!column_const)
return false;
const Element x = column_const->template getValue<Element>();
const IColumn::Offsets & offsets = array.getOffsets();
auto res_nested = ColumnVector<Result>::create();
typename ColumnVector<Result>::Container & res_values = res_nested->getData();
res_values.resize(column_const->size());
size_t pos = 0;
for (size_t i = 0; i < offsets.size(); ++i)
{
// skip empty arrays
if (pos < offsets[i])
{
res_values[pos++] = x;
for (; pos < offsets[i]; ++pos)
{
res_values[pos] = res_values[pos - 1] + x;
if(res_values[pos] < 0){
res_values[pos] = 0;
}
}
}
}
res_ptr = ColumnArray::create(std::move(res_nested), array.getOffsetsPtr());
return true;
}
return false;
const IColumn::Offsets & offsets = array.getOffsets();
const typename ColumnVector<Element>::Container & data = column->getData();
@ -892,18 +838,19 @@ struct ArrayCumSumLimitedImpl
res_values.resize(data.size());
size_t pos = 0;
Result accum_sum = 0;
for (size_t i = 0; i < offsets.size(); ++i)
{
// skip empty arrays
if (pos < offsets[i])
{
res_values[pos] = data[pos];
accum_sum = data[pos];
res_values[pos] = accum_sum > 0 ? accum_sum : 0;
for (++pos; pos < offsets[i]; ++pos)
{
res_values[pos] = res_values[pos - 1] + data[pos];
if(res_values[pos] < 0){
res_values[pos] = 0;
}
accum_sum = accum_sum + data[pos];
res_values[pos] = accum_sum > 0 ? accum_sum : 0;
}
}
}
@ -928,7 +875,7 @@ struct ArrayCumSumLimitedImpl
executeType<Float64,Float64>(mapped, array, res))
return res;
else
throw Exception("Unexpected column for arrayCumSumLimited: " + mapped->getName());
throw Exception("Unexpected column for arrayCumSumNonNegativeImpl: " + mapped->getName());
}
};
@ -1186,7 +1133,7 @@ struct NameArrayFirstIndex { static constexpr auto name = "arrayFirstIndex"; };
struct NameArraySort { static constexpr auto name = "arraySort"; };
struct NameArrayReverseSort { static constexpr auto name = "arrayReverseSort"; };
struct NameArrayCumSum { static constexpr auto name = "arrayCumSum"; };
struct NameArrayCumSumLimited { static constexpr auto name = "arrayCumSumLimited"; };
struct NameArrayCumSumNonNegative { static constexpr auto name = "arrayCumSumNonNegative"; };
struct NameArrayDifference { static constexpr auto name = "arrayDifference"; };
using FunctionArrayMap = FunctionArrayMapped<ArrayMapImpl, NameArrayMap>;
@ -1200,7 +1147,7 @@ using FunctionArrayFirstIndex = FunctionArrayMapped<ArrayFirstIndexImpl, NameArr
using FunctionArraySort = FunctionArrayMapped<ArraySortImpl<true>, NameArraySort>;
using FunctionArrayReverseSort = FunctionArrayMapped<ArraySortImpl<false>, NameArrayReverseSort>;
using FunctionArrayCumSum = FunctionArrayMapped<ArrayCumSumImpl, NameArrayCumSum>;
using FunctionArrayCumSumLimited = FunctionArrayMapped<ArrayCumSumLimitedImpl, NameArrayCumSumLimited>;
using FunctionArrayCumSumNonNegative = FunctionArrayMapped<ArrayCumSumNonNegativeImpl, NameArrayCumSumNonNegative>;
using FunctionArrayDifference = FunctionArrayMapped<ArrayDifferenceImpl, NameArrayDifference>;
}

View File

@ -1,8 +1,8 @@
[1,3,6,10]
[1,0,5,3]
[1,0,1,0]
[0,1,1,1]
[0,6,93,-95]
[1,0,0,1]
[1,0,0,0]
[1,1.4,1.2999999999999998]
[1, 4, 5]
[0,-4,3,1]

View File

@ -1,8 +1,8 @@
DROP TABLE IF EXISTS test.test;
SELECT arrayCumSumLimited([1, 2, 3, 4]);
SELECT arrayCumSumNonNegative([1, 2, 3, 4]);
SELECT arrayCumSumLimited([1, -5, 5, -2]);
SELECT arrayCumSumNonNegative([1, -5, 5, -2]);
SELECT arrayDifference([1, 2, 3, 4]);
@ -12,11 +12,11 @@ CREATE TABLE test.test(a Array(Int64), b Array(Float64), c Array(UInt64)) ENGINE
INSERT INTO test.test VALUES ([1, -3, 0, 1], [1.0, 0.4, -0.1], [1, 3, 1]);
SELECT arrayCumSumLimited(a) FROM test.test;
SELECT arrayCumSumNonNegative(a) FROM test.test;
SELECT arrayCumSumLimited(b) FROM test.test;
SELECT arrayCumSumNonNegative(b) FROM test.test;
SELECT arrayCumSumLimited(c) FROM test.test;
SELECT arrayCumSumNonNegative(c) FROM test.test;
SELECT arrayDifference(a) FROM test.test;