This commit is contained in:
Evgeniy Gatov 2015-05-20 20:40:24 +03:00
commit 39a9f312c0
5 changed files with 209 additions and 85 deletions

View File

@ -12,6 +12,27 @@
namespace DB namespace DB
{ {
namespace
{
/// Эта функция возвращает true если оба значения велики и сравнимы.
/// Она употребляется для вычисления среднего значения путём слияния двух источников.
/// Ибо если размеры обоих источников велики и сравнимы, то надо применить особенную
/// формулу гарантирующую больше стабильности.
bool areComparable(UInt64 a, UInt64 b)
{
const Float64 sensitivity = 0.001;
const UInt64 threshold = 10000;
if ((a == 0) || (b == 0))
return false;
auto res = std::minmax(a, b);
return (((1 - static_cast<Float64>(res.first) / res.second) < sensitivity) && (res.first > threshold));
}
}
/** Статистические аггрегатные функции: /** Статистические аггрегатные функции:
* varSamp - выборочная дисперсия * varSamp - выборочная дисперсия
* stddevSamp - среднее выборочное квадратичное отклонение * stddevSamp - среднее выборочное квадратичное отклонение
@ -52,12 +73,8 @@ public:
Float64 factor = static_cast<Float64>(count * source.count) / total_count; Float64 factor = static_cast<Float64>(count * source.count) / total_count;
Float64 delta = mean - source.mean; Float64 delta = mean - source.mean;
auto res = std::minmax(count, source.count); if (areComparable(count, source.count))
if (((1 - static_cast<Float64>(res.first) / res.second) < 0.001) && (res.first > 10000))
{
/// Эта формула более стабильная, когда размеры обоих источников велики и сравнимы.
mean = (source.count * source.mean + count * mean) / total_count; mean = (source.count * source.mean + count * mean) / total_count;
}
else else
mean = source.mean + delta * (static_cast<Float64>(count) / total_count); mean = source.mean + delta * (static_cast<Float64>(count) / total_count);
@ -93,7 +110,9 @@ private:
/** Основной код для реализации функций varSamp, stddevSamp, varPop, stddevPop. /** Основной код для реализации функций varSamp, stddevSamp, varPop, stddevPop.
*/ */
template<typename T, typename Op> template<typename T, typename Op>
class AggregateFunctionVariance final : public IUnaryAggregateFunction<AggregateFunctionVarianceData<T, Op>, AggregateFunctionVariance<T, Op> > class AggregateFunctionVariance final
: public IUnaryAggregateFunction<AggregateFunctionVarianceData<T, Op>,
AggregateFunctionVariance<T, Op> >
{ {
public: public:
String getName() const override { return Op::name; } String getName() const override { return Op::name; }
@ -151,7 +170,7 @@ struct VarSampImpl
static inline Float64 apply(Float64 m2, UInt64 count) static inline Float64 apply(Float64 m2, UInt64 count)
{ {
if (count < 2) if (count < 2)
return 0.0; return std::numeric_limits<Float64>::infinity();
else else
return m2 / (count - 1); return m2 / (count - 1);
} }
@ -177,7 +196,9 @@ struct VarPopImpl
static inline Float64 apply(Float64 m2, UInt64 count) static inline Float64 apply(Float64 m2, UInt64 count)
{ {
if (count < 2) if (count == 0)
return std::numeric_limits<Float64>::infinity();
else if (count == 1)
return 0.0; return 0.0;
else else
return m2 / count; return m2 / count;
@ -198,26 +219,73 @@ struct StdDevPopImpl
} }
/** Если флаг compute_marginal_moments установлен, этот класс предоставялет наследнику
* CovarianceData поддержку маргинальных моментов для вычисления корреляции.
*/
template<bool compute_marginal_moments>
class BaseCovarianceData
{
protected:
void incrementMarginalMoments(Float64 left_incr, Float64 right_incr) {}
void mergeWith(const BaseCovarianceData & source) {}
void serialize(WriteBuffer & buf) const {}
void deserialize(const ReadBuffer & buf) {}
};
template<>
class BaseCovarianceData<true>
{
protected:
void incrementMarginalMoments(Float64 left_incr, Float64 right_incr)
{
left_m2 += left_incr;
right_m2 += right_incr;
}
void mergeWith(const BaseCovarianceData & source)
{
left_m2 += source.left_m2;
right_m2 += source.right_m2;
}
void serialize(WriteBuffer & buf) const
{
writeBinary(left_m2, buf);
writeBinary(right_m2, buf);
}
void deserialize(ReadBuffer & buf)
{
readBinary(left_m2, buf);
readBinary(right_m2, buf);
}
protected:
Float64 left_m2 = 0.0;
Float64 right_m2 = 0.0;
};
/** Параллельный и инкрементальный алгоритм для вычисления ковариации. /** Параллельный и инкрементальный алгоритм для вычисления ковариации.
* Источник: "Numerically Stable, Single-Pass, Parallel Statistics Algorithms" * Источник: "Numerically Stable, Single-Pass, Parallel Statistics Algorithms"
* (J. Bennett et al., Sandia National Laboratories, * (J. Bennett et al., Sandia National Laboratories,
* 2009 IEEE International Conference on Cluster Computing) * 2009 IEEE International Conference on Cluster Computing)
*/ */
template<typename T, typename U, typename Op, bool compute_marginal_moments> template<typename T, typename U, typename Op, bool compute_marginal_moments>
class CovarianceData class CovarianceData : public BaseCovarianceData<compute_marginal_moments>
{ {
public: private:
CovarianceData() = default; using Base = BaseCovarianceData<compute_marginal_moments>;
public:
void update(const IColumn & column_left, const IColumn & column_right, size_t row_num) void update(const IColumn & column_left, const IColumn & column_right, size_t row_num)
{ {
T left_received = static_cast<const ColumnVector<T> &>(column_left).getData()[row_num]; T left_received = static_cast<const ColumnVector<T> &>(column_left).getData()[row_num];
Float64 val_left = static_cast<Float64>(left_received); Float64 left_val = static_cast<Float64>(left_received);
Float64 left_delta = val_left - left_mean; Float64 left_delta = left_val - left_mean;
U right_received = static_cast<const ColumnVector<U> &>(column_right).getData()[row_num]; U right_received = static_cast<const ColumnVector<U> &>(column_right).getData()[row_num];
Float64 val_right = static_cast<Float64>(right_received); Float64 right_val = static_cast<Float64>(right_received);
Float64 right_delta = val_right - right_mean; Float64 right_delta = right_val - right_mean;
Float64 old_right_mean = right_mean; Float64 old_right_mean = right_mean;
@ -225,12 +293,14 @@ public:
left_mean += left_delta / count; left_mean += left_delta / count;
right_mean += right_delta / count; right_mean += right_delta / count;
co_moment += (val_left - left_mean) * (val_right - old_right_mean); co_moment += (left_val - left_mean) * (right_val - old_right_mean);
/// Обновить маргинальные моменты, если они есть.
if (compute_marginal_moments) if (compute_marginal_moments)
{ {
left_m2 += left_delta * (val_left - left_mean); Float64 left_incr = left_delta * (left_val - left_mean);
right_m2 += right_delta * (val_right - right_mean); Float64 right_incr = right_delta * (right_val - right_mean);
Base::incrementMarginalMoments(left_incr, right_incr);
} }
} }
@ -244,15 +314,27 @@ public:
Float64 left_delta = left_mean - source.left_mean; Float64 left_delta = left_mean - source.left_mean;
Float64 right_delta = right_mean - source.right_mean; Float64 right_delta = right_mean - source.right_mean;
if (areComparable(count, source.count))
{
left_mean = (source.count * source.left_mean + count * left_mean) / total_count;
right_mean = (source.count * source.right_mean + count * right_mean) / total_count;
}
else
{
left_mean = source.left_mean + left_delta * (static_cast<Float64>(count) / total_count); left_mean = source.left_mean + left_delta * (static_cast<Float64>(count) / total_count);
right_mean = source.right_mean + right_delta * (static_cast<Float64>(count) / total_count); right_mean = source.right_mean + right_delta * (static_cast<Float64>(count) / total_count);
}
co_moment += source.co_moment + left_delta * right_delta * factor; co_moment += source.co_moment + left_delta * right_delta * factor;
count = total_count; count = total_count;
/// Обновить маргинальные моменты, если они есть.
if (compute_marginal_moments) if (compute_marginal_moments)
{ {
left_m2 += source.left_m2 + left_delta * left_delta * factor; Float64 left_incr = left_delta * left_delta * factor;
right_m2 += source.right_m2 + right_delta * right_delta * factor; Float64 right_incr = right_delta * right_delta * factor;
Base::mergeWith(source);
Base::incrementMarginalMoments(left_incr, right_incr);
} }
} }
@ -262,12 +344,7 @@ public:
writeBinary(left_mean, buf); writeBinary(left_mean, buf);
writeBinary(right_mean, buf); writeBinary(right_mean, buf);
writeBinary(co_moment, buf); writeBinary(co_moment, buf);
Base::serialize(buf);
if (compute_marginal_moments)
{
writeBinary(left_m2, buf);
writeBinary(right_m2, buf);
}
} }
void deserialize(ReadBuffer & buf) void deserialize(ReadBuffer & buf)
@ -276,17 +353,19 @@ public:
readBinary(left_mean, buf); readBinary(left_mean, buf);
readBinary(right_mean, buf); readBinary(right_mean, buf);
readBinary(co_moment, buf); readBinary(co_moment, buf);
Base::deserialize(buf);
if (compute_marginal_moments)
{
readBinary(left_m2, buf);
readBinary(right_m2, buf);
}
} }
void publish(IColumn & to) const template<bool compute = compute_marginal_moments>
void publish(IColumn & to, typename std::enable_if<compute>::type * = nullptr) const
{ {
static_cast<ColumnFloat64 &>(to).getData().push_back(Op::apply(co_moment, left_m2, right_m2, count)); static_cast<ColumnFloat64 &>(to).getData().push_back(Op::apply(co_moment, Base::left_m2, Base::right_m2, count));
}
template<bool compute = compute_marginal_moments>
void publish(IColumn & to, typename std::enable_if<!compute>::type * = nullptr) const
{
static_cast<ColumnFloat64 &>(to).getData().push_back(Op::apply(co_moment, count));
} }
private: private:
@ -294,8 +373,6 @@ private:
Float64 left_mean = 0.0; Float64 left_mean = 0.0;
Float64 right_mean = 0.0; Float64 right_mean = 0.0;
Float64 co_moment = 0.0; Float64 co_moment = 0.0;
Float64 left_m2 = 0.0;
Float64 right_m2 = 0.0;
}; };
template<typename T, typename U, typename Op, bool compute_marginal_moments = false> template<typename T, typename U, typename Op, bool compute_marginal_moments = false>
@ -342,7 +419,6 @@ public:
{ {
CovarianceData<T, U, Op, compute_marginal_moments> source; CovarianceData<T, U, Op, compute_marginal_moments> source;
source.deserialize(buf); source.deserialize(buf);
this->data(place).mergeWith(source); this->data(place).mergeWith(source);
} }
@ -361,10 +437,10 @@ struct CovarSampImpl
{ {
static constexpr auto name = "covarSamp"; static constexpr auto name = "covarSamp";
static inline Float64 apply(Float64 co_moment, Float64 left_m2, Float64 right_m2, UInt64 count) static inline Float64 apply(Float64 co_moment, UInt64 count)
{ {
if (count < 2) if (count < 2)
return 0.0; return std::numeric_limits<Float64>::infinity();
else else
return co_moment / (count - 1); return co_moment / (count - 1);
} }
@ -376,9 +452,11 @@ struct CovarPopImpl
{ {
static constexpr auto name = "covarPop"; static constexpr auto name = "covarPop";
static inline Float64 apply(Float64 co_moment, Float64 left_m2, Float64 right_m2, UInt64 count) static inline Float64 apply(Float64 co_moment, UInt64 count)
{ {
if (count < 2) if (count == 0)
return std::numeric_limits<Float64>::infinity();
else if (count == 1)
return 0.0; return 0.0;
else else
return co_moment / count; return co_moment / count;
@ -394,7 +472,7 @@ struct CorrImpl
static inline Float64 apply(Float64 co_moment, Float64 left_m2, Float64 right_m2, UInt64 count) static inline Float64 apply(Float64 co_moment, Float64 left_m2, Float64 right_m2, UInt64 count)
{ {
if (count < 2) if (count < 2)
return 0.0; return std::numeric_limits<Float64>::infinity();
else else
return co_moment / sqrt(left_m2 * right_m2); return co_moment / sqrt(left_m2 * right_m2);
} }

View File

@ -94,75 +94,103 @@ namespace DB
/// Реализация функций округления на низком уровне. /// Реализация функций округления на низком уровне.
template<typename T, int rounding_mode> template<typename T, int rounding_mode, bool with_scale>
struct RoundingComputation struct RoundingComputation
{ {
}; };
template<int rounding_mode> template<int rounding_mode, bool with_scale>
struct RoundingComputation<Float32, rounding_mode> struct RoundingComputation<Float32, rounding_mode, with_scale>
{ {
using Data = std::array<Float32, 4>; using Data = std::array<Float32, 4>;
using Scale = __m128; using Scale = __m128;
static inline void prepareScale(size_t scale, Scale & mm_scale) template<bool with_scale2 = with_scale>
static inline void prepareScale(size_t scale, Scale & mm_scale,
typename std::enable_if<with_scale2>::type * = nullptr)
{ {
Float32 fscale = static_cast<Float32>(scale); Float32 fscale = static_cast<Float32>(scale);
mm_scale = _mm_load1_ps(&fscale); mm_scale = _mm_load1_ps(&fscale);
} }
static inline void compute(const Data & in, const Scale & mm_scale, Data & out) template<bool with_scale2 = with_scale>
static inline void prepareScale(size_t scale, Scale & mm_scale,
typename std::enable_if<!with_scale2>::type * = nullptr)
{ {
Float32 input[4] __attribute__((aligned(16))) = {in[0], in[1], in[2], in[3]}; }
__m128 mm_value = _mm_load_ps(input);
template<bool with_scale2 = with_scale>
static inline void compute(const Data & in, const Scale & mm_scale, Data & out,
typename std::enable_if<with_scale2>::type * = nullptr)
{
__m128 mm_value = _mm_loadu_ps(reinterpret_cast<const Float32 *>(&in));
mm_value = _mm_mul_ps(mm_value, mm_scale); mm_value = _mm_mul_ps(mm_value, mm_scale);
mm_value = _mm_round_ps(mm_value, rounding_mode); mm_value = _mm_round_ps(mm_value, rounding_mode);
mm_value = _mm_div_ps(mm_value, mm_scale); mm_value = _mm_div_ps(mm_value, mm_scale);
_mm_storeu_ps(reinterpret_cast<Float32 *>(&out), mm_value);
}
Float32 res[4] __attribute__((aligned(16))); template<bool with_scale2 = with_scale>
_mm_store_ps(res, mm_value); static inline void compute(const Data & in, const Scale & mm_scale, Data & out,
out = {res[0], res[1], res[2], res[3]}; typename std::enable_if<!with_scale2>::type * = nullptr)
{
__m128 mm_value = _mm_loadu_ps(reinterpret_cast<const Float32 *>(&in));
mm_value = _mm_round_ps(mm_value, rounding_mode);
_mm_storeu_ps(reinterpret_cast<Float32 *>(&out), mm_value);
} }
}; };
template<int rounding_mode> template<int rounding_mode, bool with_scale>
struct RoundingComputation<Float64, rounding_mode> struct RoundingComputation<Float64, rounding_mode, with_scale>
{ {
using Data = std::array<Float64, 2>; using Data = std::array<Float64, 2>;
using Scale = __m128d; using Scale = __m128d;
static inline void prepareScale(size_t scale, Scale & mm_scale) template<bool with_scale2 = with_scale>
static inline void prepareScale(size_t scale, Scale & mm_scale,
typename std::enable_if<with_scale2>::type * = nullptr)
{ {
Float64 fscale = static_cast<Float64>(scale); Float64 fscale = static_cast<Float64>(scale);
mm_scale = _mm_load1_pd(&fscale); mm_scale = _mm_load1_pd(&fscale);
} }
static inline void compute(const Data & in, const Scale & mm_scale, Data & out) template<bool with_scale2 = with_scale>
static inline void prepareScale(size_t scale, Scale & mm_scale,
typename std::enable_if<!with_scale2>::type * = nullptr)
{ {
Float64 input[2] __attribute__((aligned(16))) = { in[0], in[1] }; }
__m128d mm_value = _mm_load_pd(input);
template<bool with_scale2 = with_scale>
static inline void compute(const Data & in, const Scale & mm_scale, Data & out,
typename std::enable_if<with_scale2>::type * = nullptr)
{
__m128d mm_value = _mm_loadu_pd(reinterpret_cast<const Float64 *>(&in));
mm_value = _mm_mul_pd(mm_value, mm_scale); mm_value = _mm_mul_pd(mm_value, mm_scale);
mm_value = _mm_round_pd(mm_value, rounding_mode); mm_value = _mm_round_pd(mm_value, rounding_mode);
mm_value = _mm_div_pd(mm_value, mm_scale); mm_value = _mm_div_pd(mm_value, mm_scale);
_mm_storeu_pd(reinterpret_cast<Float64 *>(&out), mm_value);
}
Float64 res[2] __attribute__((aligned(16))); template<bool with_scale2 = with_scale>
_mm_store_pd(res, mm_value); static inline void compute(const Data & in, const Scale & mm_scale, Data & out,
out = {res[0], res[1]}; typename std::enable_if<!with_scale2>::type * = nullptr)
{
__m128d mm_value = _mm_loadu_pd(reinterpret_cast<const Float64 *>(&in));
mm_value = _mm_round_pd(mm_value, rounding_mode);
_mm_storeu_pd(reinterpret_cast<Float64 *>(&out), mm_value);
} }
}; };
/// Реализация функций округления на высоком уровне. /// Реализация функций округления на высоком уровне.
template<typename T, int rounding_mode, typename Enable = void> template<typename T, int rounding_mode, bool with_scale, typename Enable = void>
struct FunctionRoundingImpl struct FunctionRoundingImpl
{ {
}; };
/// В случае целочисленных значений не выполяется округления. /// В случае целочисленных значений не выполяется округления.
template<typename T, int rounding_mode> template<typename T, int rounding_mode, bool with_scale>
struct FunctionRoundingImpl<T, rounding_mode, typename std::enable_if<std::is_integral<T>::value>::type> struct FunctionRoundingImpl<T, rounding_mode, with_scale, typename std::enable_if<std::is_integral<T>::value>::type>
{ {
static inline void apply(const PODArray<T> & in, size_t scale, typename ColumnVector<T>::Container_t & out) static inline void apply(const PODArray<T> & in, size_t scale, typename ColumnVector<T>::Container_t & out)
{ {
@ -177,11 +205,11 @@ namespace DB
} }
}; };
template<typename T, int rounding_mode> template<typename T, int rounding_mode, bool with_scale>
struct FunctionRoundingImpl<T, rounding_mode, typename std::enable_if<std::is_floating_point<T>::value>::type> struct FunctionRoundingImpl<T, rounding_mode, with_scale, typename std::enable_if<std::is_floating_point<T>::value>::type>
{ {
private: private:
using Op = RoundingComputation<T, rounding_mode>; using Op = RoundingComputation<T, rounding_mode, with_scale>;
using Data = typename Op::Data; using Data = typename Op::Data;
using Scale = typename Op::Scale; using Scale = typename Op::Scale;
@ -218,7 +246,7 @@ namespace DB
Op::compute(tmp, mm_scale, res); Op::compute(tmp, mm_scale, res);
for (size_t j = 0; (j < data_size) && ((i + j) < size); ++j) for (size_t j = 0; (j < data_size) && ((i + j) < size); ++j)
out[i + j] = in[i + j]; out[i + j] = res[j];
} }
} }
@ -357,21 +385,25 @@ namespace
template<typename T> template<typename T>
bool executeForType(Block & block, const ColumnNumbers & arguments, size_t result) bool executeForType(Block & block, const ColumnNumbers & arguments, size_t result)
{ {
using Op = FunctionRoundingImpl<T, rounding_mode>; using OpWithScale = FunctionRoundingImpl<T, rounding_mode, true>;
using OpWithoutScale = FunctionRoundingImpl<T, rounding_mode, false>;
if (ColumnVector<T> * col = typeid_cast<ColumnVector<T> *>(&*block.getByPosition(arguments[0]).column)) if (ColumnVector<T> * col = typeid_cast<ColumnVector<T> *>(&*block.getByPosition(arguments[0]).column))
{ {
UInt8 precision = 0;
if (arguments.size() == 2)
precision = getPrecision<T>(block.getByPosition(arguments[1]).column);
ColumnVector<T> * col_res = new ColumnVector<T>; ColumnVector<T> * col_res = new ColumnVector<T>;
block.getByPosition(result).column = col_res; block.getByPosition(result).column = col_res;
typename ColumnVector<T>::Container_t & vec_res = col_res->getData(); typename ColumnVector<T>::Container_t & vec_res = col_res->getData();
vec_res.resize(col->getData().size()); vec_res.resize(col->getData().size());
Op::apply(col->getData(), PowersOf10::values[precision], vec_res); UInt8 precision = 0;
if (arguments.size() == 2)
precision = getPrecision<T>(block.getByPosition(arguments[1]).column);
if (precision > 0)
OpWithScale::apply(col->getData(), PowersOf10::values[precision], vec_res);
else
OpWithoutScale::apply(col->getData(), 0, vec_res);
return true; return true;
} }
@ -381,7 +413,11 @@ namespace
if (arguments.size() == 2) if (arguments.size() == 2)
precision = getPrecision<T>(block.getByPosition(arguments[1]).column); precision = getPrecision<T>(block.getByPosition(arguments[1]).column);
T res = Op::apply(col->getData(), PowersOf10::values[precision]); T res;
if (precision > 0)
res = OpWithScale::apply(col->getData(), PowersOf10::values[precision]);
else
res = OpWithoutScale::apply(col->getData(), 0);
ColumnConst<T> * col_res = new ColumnConst<T>(col->size(), res); ColumnConst<T> * col_res = new ColumnConst<T>(col->size(), res);
block.getByPosition(result).column = col_res; block.getByPosition(result).column = col_res;

View File

@ -1,6 +1,5 @@
#include <DB/Storages/MergeTree/MergeTreeDataSelectExecutor.h> #include <DB/Storages/MergeTree/MergeTreeDataSelectExecutor.h>
#include <DB/Storages/MergeTree/MergeTreeBlockInputStream.h> #include <DB/Storages/MergeTree/MergeTreeBlockInputStream.h>
#include <DB/Storages/MergeTree/MergeTreeWhereOptimizer.h>
#include <DB/Interpreters/ExpressionAnalyzer.h> #include <DB/Interpreters/ExpressionAnalyzer.h>
#include <DB/Parsers/ASTIdentifier.h> #include <DB/Parsers/ASTIdentifier.h>
#include <DB/DataStreams/ExpressionBlockInputStream.h> #include <DB/DataStreams/ExpressionBlockInputStream.h>
@ -63,13 +62,6 @@ BlockInputStreams MergeTreeDataSelectExecutor::read(
if (real_column_names.empty()) if (real_column_names.empty())
real_column_names.push_back(ExpressionActions::getSmallestColumn(data.getColumnsList())); real_column_names.push_back(ExpressionActions::getSmallestColumn(data.getColumnsList()));
ASTSelectQuery & select = *typeid_cast<ASTSelectQuery*>(&*query);
/// Try transferring some condition from WHERE to PREWHERE if enabled and viable
if (settings.optimize_move_to_prewhere)
if (select.where_expression && !select.prewhere_expression)
MergeTreeWhereOptimizer{select, data, column_names_to_return, log};
Block virtual_columns_block = getBlockWithVirtualColumns(parts); Block virtual_columns_block = getBlockWithVirtualColumns(parts);
/// Если запрошен хотя бы один виртуальный столбец, пробуем индексировать /// Если запрошен хотя бы один виртуальный столбец, пробуем индексировать
@ -114,6 +106,8 @@ BlockInputStreams MergeTreeDataSelectExecutor::read(
ExpressionActionsPtr filter_expression; ExpressionActionsPtr filter_expression;
double relative_sample_size = 0; double relative_sample_size = 0;
ASTSelectQuery & select = *typeid_cast<ASTSelectQuery*>(&*query);
if (select.sample_size) if (select.sample_size)
{ {
relative_sample_size = apply_visitor(FieldVisitorConvertToNumber<double>(), relative_sample_size = apply_visitor(FieldVisitorConvertToNumber<double>(),

View File

@ -2,6 +2,7 @@
#include <DB/Storages/MergeTree/MergeTreeBlockOutputStream.h> #include <DB/Storages/MergeTree/MergeTreeBlockOutputStream.h>
#include <DB/Storages/MergeTree/DiskSpaceMonitor.h> #include <DB/Storages/MergeTree/DiskSpaceMonitor.h>
#include <DB/Storages/MergeTree/MergeList.h> #include <DB/Storages/MergeTree/MergeList.h>
#include <DB/Storages/MergeTree/MergeTreeWhereOptimizer.h>
#include <DB/Common/escapeForFileName.h> #include <DB/Common/escapeForFileName.h>
#include <DB/Interpreters/InterpreterAlterQuery.h> #include <DB/Interpreters/InterpreterAlterQuery.h>
#include <Poco/DirectoryIterator.h> #include <Poco/DirectoryIterator.h>
@ -98,6 +99,13 @@ BlockInputStreams StorageMergeTree::read(
const size_t max_block_size, const size_t max_block_size,
const unsigned threads) const unsigned threads)
{ {
ASTSelectQuery & select = *typeid_cast<ASTSelectQuery*>(&*query);
/// Try transferring some condition from WHERE to PREWHERE if enabled and viable
if (settings.optimize_move_to_prewhere)
if (select.where_expression && !select.prewhere_expression)
MergeTreeWhereOptimizer{select, data, column_names, log};
return reader.read(column_names, query, context, settings, processed_stage, max_block_size, threads); return reader.read(column_names, query, context, settings, processed_stage, max_block_size, threads);
} }

View File

@ -5,6 +5,7 @@
#include <DB/Storages/MergeTree/ReplicatedMergeTreePartsExchange.h> #include <DB/Storages/MergeTree/ReplicatedMergeTreePartsExchange.h>
#include <DB/Storages/MergeTree/MergeTreePartChecker.h> #include <DB/Storages/MergeTree/MergeTreePartChecker.h>
#include <DB/Storages/MergeTree/MergeList.h> #include <DB/Storages/MergeTree/MergeList.h>
#include <DB/Storages/MergeTree/MergeTreeWhereOptimizer.h>
#include <DB/Parsers/formatAST.h> #include <DB/Parsers/formatAST.h>
#include <DB/IO/WriteBufferFromOStream.h> #include <DB/IO/WriteBufferFromOStream.h>
#include <DB/IO/ReadBufferFromString.h> #include <DB/IO/ReadBufferFromString.h>
@ -1996,6 +1997,13 @@ BlockInputStreams StorageReplicatedMergeTree::read(
else else
real_column_names.push_back(it); real_column_names.push_back(it);
ASTSelectQuery & select = *typeid_cast<ASTSelectQuery*>(&*query);
/// Try transferring some condition from WHERE to PREWHERE if enabled and viable
if (settings.optimize_move_to_prewhere)
if (select.where_expression && !select.prewhere_expression)
MergeTreeWhereOptimizer{select, data, real_column_names, log};
Block virtual_columns_block; Block virtual_columns_block;
ColumnUInt8 * column = new ColumnUInt8(2); ColumnUInt8 * column = new ColumnUInt8(2);
ColumnPtr column_ptr = column; ColumnPtr column_ptr = column;