diff --git a/dbms/include/DB/Core/ErrorCodes.h b/dbms/include/DB/Core/ErrorCodes.h index c57cef242fe..11bf7f6576a 100644 --- a/dbms/include/DB/Core/ErrorCodes.h +++ b/dbms/include/DB/Core/ErrorCodes.h @@ -65,6 +65,7 @@ namespace ErrorCodes UNKNOWN_STORAGE, TABLE_ALREADY_EXISTS, TABLE_METADATA_ALREADY_EXISTS, + ILLEGAL_TYPE_OF_COLUMN_FOR_FILTER, }; } diff --git a/dbms/src/DataStreams/FilterBlockInputStream.cpp b/dbms/src/DataStreams/FilterBlockInputStream.cpp index 5ed7f1396e5..f2de12ea784 100644 --- a/dbms/src/DataStreams/FilterBlockInputStream.cpp +++ b/dbms/src/DataStreams/FilterBlockInputStream.cpp @@ -18,7 +18,20 @@ Block FilterBlockInputStream::read() return res; size_t columns = res.columns(); - IColumn::Filter & filter = dynamic_cast(*res.getByPosition(filter_column).column).getData(); + + ColumnConstUInt8 * column_const = dynamic_cast(&*res.getByPosition(filter_column).column); + if (column_const) + { + return column_const->getData() + ? res + : Block(); + } + + ColumnUInt8 * column = dynamic_cast(&*res.getByPosition(filter_column).column); + if (!column) + throw Exception("Illegal type of column for filter. Must be ColumnUInt8 or ColumnConstUInt8.", ErrorCodes::ILLEGAL_TYPE_OF_COLUMN_FOR_FILTER); + + IColumn::Filter & filter = column->getData(); for (size_t i = 0; i < columns; ++i) if (i != filter_column) diff --git a/dbms/src/DataStreams/tests/filter_stream.cpp b/dbms/src/DataStreams/tests/filter_stream.cpp new file mode 100644 index 00000000000..28487223ce4 --- /dev/null +++ b/dbms/src/DataStreams/tests/filter_stream.cpp @@ -0,0 +1,97 @@ +#include +#include + +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include + + +using Poco::SharedPtr; + + +int main(int argc, char ** argv) +{ + try + { + size_t n = argc == 2 ? Poco::NumberParser::parseUnsigned64(argv[1]) : 10ULL; + + DB::ParserSelectQuery parser; + DB::ASTPtr ast; + std::string input = "SELECT number, number % 3 == 1"; + std::string expected; + + const char * begin = input.data(); + const char * end = begin + input.size(); + const char * pos = begin; + + if (!parser.parse(pos, end, ast, expected)) + { + std::cout << "Failed at position " << (pos - begin) << ": " + << mysqlxx::quote << input.substr(pos - begin, 10) + << ", expected " << expected << "." << std::endl; + } + + DB::formatAST(*ast, std::cerr); + std::cerr << std::endl; + std::cerr << ast->getTreeID() << std::endl; + + DB::Context context; + context.columns["number"] = new DB::DataTypeUInt64; + (*context.functions)["modulo"] = new DB::FunctionModulo; + (*context.functions)["equals"] = new DB::FunctionEquals; + (*context.functions)["notEquals"] = new DB::FunctionNotEquals; + + Poco::SharedPtr expression = new DB::Expression(ast, context); + + DB::StorageSystemNumbers table("Numbers"); + + DB::Names column_names; + column_names.push_back("number"); + + Poco::SharedPtr in = table.read(column_names, 0); + in = new DB::ExpressionBlockInputStream(in, expression); + in = new DB::FilterBlockInputStream(in, 1); + in = new DB::LimitBlockInputStream(in, 10, std::max(static_cast(0), static_cast(n) - 10)); + + DB::WriteBufferFromOStream ob(std::cout); + DB::TabSeparatedRowOutputStream out(ob, new DB::DataTypes(expression->getReturnTypes())); + + { + Poco::Stopwatch stopwatch; + stopwatch.start(); + + DB::copyData(*in, out); + + stopwatch.stop(); + std::cout << std::fixed << std::setprecision(2) + << "Elapsed " << stopwatch.elapsed() / 1000000.0 << " sec." + << ", " << n * 1000000 / stopwatch.elapsed() << " rows/sec." + << std::endl; + } + } + catch (const DB::Exception & e) + { + std::cerr << e.what() << ", " << e.message() << std::endl; + return 1; + } + + return 0; +}