#include #include #include #include #include #include #include #include #include #include #include #include #include namespace { using namespace DB; Field executeFunctionOnField( const Field & field, const std::string & name, const ExpressionActionsPtr & sharding_expr, const DataTypePtr & type, const std::string & sharding_key_column_name) { ColumnWithTypeAndName column; column.column = type->createColumnConst(1, field); column.name = name; column.type = type; Block block{column}; size_t num_rows = 1; sharding_expr->execute(block, num_rows); ColumnWithTypeAndName & ret = block.getByName(sharding_key_column_name); return (*ret.column)[0]; } /// @param column_value - one of values from IN /// @param sharding_column_name - name of that column /// @return true if shard may contain such value (or it is unknown), otherwise false. bool shardContains( Field column_value, const std::string & sharding_column_name, const OptimizeShardingKeyRewriteInMatcher::Data & data) { /// Type of column in storage (used for implicit conversion from i.e. String to Int) const DataTypePtr & column_type = data.sharding_key_expr->getSampleBlock().getByName(sharding_column_name).type; /// Implicit conversion. column_value = convertFieldToType(column_value, *column_type); /// NULL is not allowed in sharding key, /// so it should be safe to assume that shard cannot contain it. if (column_value.isNull()) return false; Field sharding_value = executeFunctionOnField( column_value, sharding_column_name, data.sharding_key_expr, column_type, data.sharding_key_column_name); /// The value from IN can be non-numeric, /// but in this case it should be convertible to numeric type, let's try. /// /// NOTE: that conversion should not be done for signed types, /// since it uses accurate cast, that will return Null, /// but we need static_cast<> (as createBlockSelector()). if (!isInt64OrUInt64FieldType(sharding_value.getType())) sharding_value = convertFieldToType(sharding_value, DataTypeUInt64()); /// In case of conversion is not possible (NULL), shard cannot contain the value anyway. if (sharding_value.isNull()) return false; UInt64 value = sharding_value.get(); const auto shard_num = data.slots[value % data.slots.size()] + 1; return data.shard_info.shard_num == shard_num; } } namespace DB { bool OptimizeShardingKeyRewriteInMatcher::needChildVisit(ASTPtr & /*node*/, const ASTPtr & /*child*/) { return true; } void OptimizeShardingKeyRewriteInMatcher::visit(ASTPtr & node, Data & data) { if (auto * function = node->as()) visit(*function, data); } void OptimizeShardingKeyRewriteInMatcher::visit(ASTFunction & function, Data & data) { if (function.name != "in") return; auto * left = function.arguments->children.front().get(); auto * right = function.arguments->children.back().get(); auto * identifier = left->as(); if (!identifier) return; auto name = identifier->shortName(); if (!data.sharding_key_expr->getRequiredColumnsWithTypes().contains(name)) return; if (auto * tuple_func = right->as(); tuple_func && tuple_func->name == "tuple") { auto * tuple_elements = tuple_func->children.front()->as(); std::erase_if(tuple_elements->children, [&](auto & child) { auto * literal = child->template as(); return tuple_elements->children.size() > 1 && literal && !shardContains(literal->value, name, data); }); } else if (auto * tuple_literal = right->as(); tuple_literal && tuple_literal->value.getType() == Field::Types::Tuple) { auto & tuple = tuple_literal->value.get(); std::erase_if(tuple, [&](auto & child) { return tuple.size() > 1 && !shardContains(child, name, data); }); } } class OptimizeShardingKeyRewriteIn : public InDepthQueryTreeVisitorWithContext { public: using Base = InDepthQueryTreeVisitorWithContext; OptimizeShardingKeyRewriteIn(OptimizeShardingKeyRewriteInVisitor::Data data_, ContextPtr context) : Base(std::move(context)) , data(std::move(data_)) {} void enterImpl(QueryTreeNodePtr & node) { auto * function_node = node->as(); if (!function_node || function_node->getFunctionName() != "in") return; auto & arguments = function_node->getArguments().getNodes(); auto * column = arguments[0]->as(); if (!column) return; auto name = column->getColumnName(); if (!data.sharding_key_expr->getRequiredColumnsWithTypes().contains(column->getColumnName())) return; if (auto * constant = arguments[1]->as()) { if (isTuple(constant->getResultType())) { const auto & tuple = constant->getValue().get(); Tuple new_tuple; new_tuple.reserve(tuple.size()); for (const auto & child : tuple) { if (shardContains(child, name, data)) new_tuple.push_back(child); } if (new_tuple.empty()) new_tuple.push_back(tuple.back()); if (new_tuple.size() == tuple.size()) return; arguments[1] = std::make_shared(new_tuple); rerunFunctionResolve(function_node, getContext()); } } } OptimizeShardingKeyRewriteInVisitor::Data data; }; void optimizeShardingKeyRewriteIn(QueryTreeNodePtr & node, OptimizeShardingKeyRewriteInVisitor::Data data, ContextPtr context) { OptimizeShardingKeyRewriteIn visitor(std::move(data), std::move(context)); visitor.visit(node); } }