Change how short circuit optimization works

This commit is contained in:
Raúl Marín 2024-05-16 21:35:58 +02:00
parent 4afb14e234
commit 4680d09e9a
3 changed files with 41 additions and 149 deletions

View File

@ -279,25 +279,32 @@ void maskedExecute(ColumnWithTypeAndName & column, const PaddedPODArray<UInt8> &
if (!column_function)
return;
size_t original_size = column.column->size();
ColumnWithTypeAndName result;
/// If mask contains only zeros, we can just create
/// an empty column with the execution result type.
if (!mask_info.has_ones)
{
/// If mask contains only zeros, we can just create a column with default values as it will be ignored
auto result_type = column_function->getResultType();
auto empty_column = result_type->createColumn();
result = {std::move(empty_column), result_type, ""};
auto default_column = result_type->createColumnConstWithDefaultValue(original_size)->convertToFullColumnIfConst();
column = {std::move(default_column), result_type, ""};
}
/// Filter column only if mask contains zeros.
else if (mask_info.has_zeros)
{
/// If it contains both zeros and ones, we need to execute the function only on the mask values
/// First we filter the column, which creates a new column, then we apply the column, and finally we expand it
/// Expanding is done to keep consistency in function calls (all columns the same size) and it's ok
/// since the values won't be used by `if`
auto filtered = column_function->filter(mask, -1);
result = typeid_cast<const ColumnFunction *>(filtered.get())->reduce();
auto filter_after_execution = typeid_cast<const ColumnFunction *>(filtered.get())->reduce();
auto mut_column = IColumn::mutate(std::move(filter_after_execution.column));
mut_column->expand(mask, false);
column.column = std::move(mut_column);
}
else
result = column_function->reduce();
column = column_function->reduce();
column = std::move(result);
chassert(column.column->size() == original_size);
}
void executeColumnIfNeeded(ColumnWithTypeAndName & column, bool empty)

View File

@ -76,75 +76,17 @@ inline void fillVectorVector(const ArrayCond & cond, const ArrayA & a, const Arr
{
size_t size = cond.size();
bool a_is_short = a.size() < size;
bool b_is_short = b.size() < size;
if (a_is_short && b_is_short)
for (size_t i = 0; i < size; ++i)
{
size_t a_index = 0, b_index = 0;
for (size_t i = 0; i < size; ++i)
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b[i]);
else if constexpr (std::is_floating_point_v<ResultType>)
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[a_index], b[b_index], res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a[a_index]) : static_cast<ResultType>(b[b_index]);
a_index += !!cond[i];
b_index += !cond[i];
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[i], b[i], res[i])
}
}
else if (a_is_short)
{
size_t a_index = 0;
for (size_t i = 0; i < size; ++i)
else
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b[i]);
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[a_index], b[i], res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a[a_index]) : static_cast<ResultType>(b[i]);
a_index += !!cond[i];
}
}
else if (b_is_short)
{
size_t b_index = 0;
for (size_t i = 0; i < size; ++i)
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[i], b[b_index], res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b[b_index]);
b_index += !cond[i];
}
}
else
{
for (size_t i = 0; i < size; ++i)
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b[i]);
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[i], b[i], res[i])
}
else
{
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b[i]);
}
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b[i]);
}
}
}
@ -153,37 +95,16 @@ template <typename ArrayCond, typename ArrayA, typename B, typename ArrayResult,
inline void fillVectorConstant(const ArrayCond & cond, const ArrayA & a, B b, ArrayResult & res)
{
size_t size = cond.size();
bool a_is_short = a.size() < size;
if (a_is_short)
for (size_t i = 0; i < size; ++i)
{
size_t a_index = 0;
for (size_t i = 0; i < size; ++i)
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b);
else if constexpr (std::is_floating_point_v<ResultType>)
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b);
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[a_index], b, res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a[a_index]) : static_cast<ResultType>(b);
a_index += !!cond[i];
}
}
else
{
for (size_t i = 0; i < size; ++i)
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b);
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[i], b, res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b);
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[i], b, res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b);
}
}
@ -191,37 +112,16 @@ template <typename ArrayCond, typename A, typename ArrayB, typename ArrayResult,
inline void fillConstantVector(const ArrayCond & cond, A a, const ArrayB & b, ArrayResult & res)
{
size_t size = cond.size();
bool b_is_short = b.size() < size;
if (b_is_short)
for (size_t i = 0; i < size; ++i)
{
size_t b_index = 0;
for (size_t i = 0; i < size; ++i)
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a) + (!cond[i]) * static_cast<ResultType>(b[i]);
else if constexpr (std::is_floating_point_v<ResultType>)
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a, b[b_index], res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b[b_index]);
b_index += !cond[i];
}
}
else
{
for (size_t i = 0; i < size; ++i)
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a) + (!cond[i]) * static_cast<ResultType>(b[i]);
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a, b[i], res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b[i]);
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a, b[i], res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b[i]);
}
}
@ -879,9 +779,6 @@ private:
bool then_is_const = isColumnConst(*col_then);
bool else_is_const = isColumnConst(*col_else);
bool then_is_short = col_then->size() < cond_col->size();
bool else_is_short = col_else->size() < cond_col->size();
const auto & cond_array = cond_col->getData();
if (then_is_const && else_is_const)
@ -901,37 +798,34 @@ private:
{
const IColumn & then_nested_column = assert_cast<const ColumnConst &>(*col_then).getDataColumn();
size_t else_index = 0;
for (size_t i = 0; i < input_rows_count; ++i)
{
if (cond_array[i])
result_column->insertFrom(then_nested_column, 0);
else
result_column->insertFrom(*col_else, else_is_short ? else_index++ : i);
result_column->insertFrom(*col_else, i);
}
}
else if (else_is_const)
{
const IColumn & else_nested_column = assert_cast<const ColumnConst &>(*col_else).getDataColumn();
size_t then_index = 0;
for (size_t i = 0; i < input_rows_count; ++i)
{
if (cond_array[i])
result_column->insertFrom(*col_then, then_is_short ? then_index++ : i);
result_column->insertFrom(*col_then, i);
else
result_column->insertFrom(else_nested_column, 0);
}
}
else
{
size_t then_index = 0, else_index = 0;
for (size_t i = 0; i < input_rows_count; ++i)
{
if (cond_array[i])
result_column->insertFrom(*col_then, then_is_short ? then_index++ : i);
result_column->insertFrom(*col_then, i);
else
result_column->insertFrom(*col_else, else_is_short ? else_index++ : i);
result_column->insertFrom(*col_else, i);
}
}
@ -1124,9 +1018,6 @@ private:
if (then_is_null && else_is_null)
return result_type->createColumnConstWithDefaultValue(input_rows_count);
bool then_is_short = arg_then.column->size() < arg_cond.column->size();
bool else_is_short = arg_else.column->size() < arg_cond.column->size();
const ColumnUInt8 * cond_col = typeid_cast<const ColumnUInt8 *>(arg_cond.column.get());
const ColumnConst * cond_const_col = checkAndGetColumnConst<ColumnVector<UInt8>>(arg_cond.column.get());
@ -1145,8 +1036,6 @@ private:
{
arg_else_column = arg_else_column->convertToFullColumnIfConst();
auto result_column = IColumn::mutate(std::move(arg_else_column));
if (else_is_short)
result_column->expand(cond_col->getData(), true);
if (isColumnNullable(*result_column))
{
assert_cast<ColumnNullable &>(*result_column).applyNullMap(assert_cast<const ColumnUInt8 &>(*arg_cond.column));
@ -1187,8 +1076,6 @@ private:
{
arg_then_column = arg_then_column->convertToFullColumnIfConst();
auto result_column = IColumn::mutate(std::move(arg_then_column));
if (then_is_short)
result_column->expand(cond_col->getData(), false);
if (isColumnNullable(*result_column))
{

View File

@ -1,6 +1,4 @@
-- Tags: no-parallel, disabled
-- Disabled while I investigate so CI keeps running (but it's broken)
-- Tags: no-parallel
DROP TABLE IF EXISTS dictionary_source_table;
CREATE TABLE dictionary_source_table