Merge branch 'master' of github.com:yandex/ClickHouse

This commit is contained in:
Alexey Milovidov 2018-09-10 15:40:15 +03:00
commit 85e8c7920f
241 changed files with 11461 additions and 9388 deletions

View File

@ -29,7 +29,7 @@ if (ENABLE_CAPNP)
find_library (CAPNP capnp PATHS ${CAPNP_PATHS})
find_library (CAPNPC capnpc PATHS ${CAPNP_PATHS})
find_library (KJ kj PATHS ${CAPNP_PATHS})
set (CAPNP_LIBRARY ${CAPNP} ${CAPNPC} ${KJ})
set (CAPNP_LIBRARY ${CAPNPC} ${CAPNP} ${KJ})
find_path (CAPNP_INCLUDE_DIR NAMES capnp/schema-parser.h PATHS ${CAPNP_INCLUDE_PATHS})
endif ()

2
contrib/poco vendored

@ -1 +1 @@
Subproject commit 3df947389e6d9654919002797bdd86ed190b3963
Subproject commit d7a4383c4d85b51938b62ed5812bc0935245edb3

View File

@ -2,10 +2,10 @@
set(VERSION_REVISION 54407 CACHE STRING "")
set(VERSION_MAJOR 18 CACHE STRING "")
set(VERSION_MINOR 12 CACHE STRING "")
set(VERSION_PATCH 8 CACHE STRING "")
set(VERSION_GITHASH 199d8734f98fa7d04ebf2119431c5f56a7ed4e5a CACHE STRING "")
set(VERSION_DESCRIBE v18.12.8-testing CACHE STRING "")
set(VERSION_STRING 18.12.8 CACHE STRING "")
set(VERSION_PATCH 11 CACHE STRING "")
set(VERSION_GITHASH 1d28a9c510120b07f0719b2f33ccbc21be1e339d CACHE STRING "")
set(VERSION_DESCRIBE v18.12.11-testing CACHE STRING "")
set(VERSION_STRING 18.12.11 CACHE STRING "")
# end of autochange
set(VERSION_EXTRA "" CACHE STRING "")

View File

@ -37,7 +37,6 @@
#include <Interpreters/Context.h>
#include <Interpreters/Cluster.h>
#include <Interpreters/InterpreterFactory.h>
#include <Interpreters/InterpreterInsertQuery.h>
#include <Interpreters/InterpreterExistsQuery.h>
#include <Interpreters/InterpreterShowCreateQuery.h>
#include <Interpreters/InterpreterDropQuery.h>

View File

@ -862,9 +862,9 @@ class ModelFactory
public:
ModelPtr get(const IDataType & data_type, UInt64 seed, MarkovModelParameters markov_model_params) const
{
if (data_type.isInteger())
if (isInteger(data_type))
{
if (data_type.isUnsignedInteger())
if (isUnsignedInteger(data_type))
return std::make_unique<UnsignedIntegerModel>(seed);
else
return std::make_unique<SignedIntegerModel>(seed);

View File

@ -213,9 +213,7 @@ void HTTPHandler::processQuery(
Context context = server.context();
context.setGlobalContext(server.context());
/// It will forcibly detach query even if unexpected error ocurred and detachQuery() was not called
/// Normal detaching is happen in BlockIO callbacks
CurrentThread::QueryScope query_scope_holder(context);
CurrentThread::QueryScope query_scope(context);
LOG_TRACE(log, "Request URI: " << request.getURI());

View File

@ -130,6 +130,9 @@ void TCPHandler::runImpl()
Stopwatch watch;
state.reset();
/// Initialized later.
std::optional<CurrentThread::QueryScope> query_scope;
/** An exception during the execution of request (it must be sent over the network to the client).
* The client will be able to accept it, if it did not happen while sending another packet and the client has not disconnected yet.
*/
@ -152,7 +155,7 @@ void TCPHandler::runImpl()
if (!receivePacket())
continue;
CurrentThread::initializeQuery();
query_scope.emplace(query_context);
send_exception_with_stack_trace = query_context.getSettingsRef().calculate_text_stack_trace;
@ -197,6 +200,8 @@ void TCPHandler::runImpl()
sendLogs();
sendEndOfStream();
query_scope.reset();
state.reset();
}
catch (const Exception & e)
@ -265,9 +270,7 @@ void TCPHandler::runImpl()
try
{
/// It will forcibly detach query even if unexpected error ocсurred and detachQuery() was not called
CurrentThread::detachQueryIfNotDetached();
query_scope.reset();
state.reset();
}
catch (...)

View File

@ -79,7 +79,7 @@ public:
if (arguments.size() != 2)
throw Exception("Aggregate function " + getName() + " requires two arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (!arguments[1]->isUnsignedInteger())
if (!isUnsignedInteger(arguments[1]))
throw Exception("Second argument of aggregate function " + getName() + " must be integer.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
type = arguments.front();

View File

@ -61,10 +61,10 @@ public:
AggregateFunctionIntersectionsMax(AggregateFunctionIntersectionsKind kind_, const DataTypes & arguments)
: kind(kind_)
{
if (!arguments[0]->isNumber())
if (!isNumber(arguments[0]))
throw Exception{getName() + ": first argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!arguments[1]->isNumber())
if (!isNumber(arguments[1]))
throw Exception{getName() + ": second argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!arguments[0]->equals(*arguments[1]))

View File

@ -33,8 +33,6 @@ void ColumnAggregateFunction::addArena(ArenaPtr arena_)
MutableColumnPtr ColumnAggregateFunction::convertToValues() const
{
const IAggregateFunction * function = func.get();
/** If the aggregate function returns an unfinalized/unfinished state,
* then you just need to copy pointers to it and also shared ownership of data.
*
@ -65,33 +63,73 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
* AggregateFunction(quantileTiming(0.5), UInt64)
* into UInt16 - already finished result of `quantileTiming`.
*/
if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(function))
if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
{
auto res = createView();
res->set(function_state->getNestedFunction());
res->getData().assign(getData().begin(), getData().end());
res->data.assign(data.begin(), data.end());
return res;
}
MutableColumnPtr res = function->getReturnType()->createColumn();
res->reserve(getData().size());
MutableColumnPtr res = func->getReturnType()->createColumn();
res->reserve(data.size());
for (auto val : getData())
function->insertResultInto(val, *res);
for (auto val : data)
func->insertResultInto(val, *res);
return res;
}
void ColumnAggregateFunction::ensureOwnership()
{
if (src)
{
/// We must copy all data from src and take ownership.
size_t size = data.size();
Arena & arena = createOrGetArena();
size_t size_of_state = func->sizeOfData();
size_t align_of_state = func->alignOfData();
size_t rollback_pos = 0;
try
{
for (size_t i = 0; i < size; ++i)
{
ConstAggregateDataPtr old_place = data[i];
data[i] = arena.alignedAlloc(size_of_state, align_of_state);
func->create(data[i]);
++rollback_pos;
func->merge(data[i], old_place, &arena);
}
}
catch (...)
{
/// If we failed to take ownership, destroy all temporary data.
if (!func->hasTrivialDestructor())
for (size_t i = 0; i < rollback_pos; ++i)
func->destroy(data[i]);
throw;
}
/// Now we own all data.
src.reset();
}
}
void ColumnAggregateFunction::insertRangeFrom(const IColumn & from, size_t start, size_t length)
{
const ColumnAggregateFunction & from_concrete = static_cast<const ColumnAggregateFunction &>(from);
if (start + length > from_concrete.getData().size())
if (start + length > from_concrete.data.size())
throw Exception("Parameters start = " + toString(start) + ", length = " + toString(length)
+ " are out of bound in ColumnAggregateFunction::insertRangeFrom method"
" (data.size() = "
+ toString(from_concrete.getData().size())
+ toString(from_concrete.data.size())
+ ").",
ErrorCodes::PARAMETER_OUT_OF_BOUND);
@ -112,14 +150,14 @@ void ColumnAggregateFunction::insertRangeFrom(const IColumn & from, size_t start
size_t old_size = data.size();
data.resize(old_size + length);
memcpy(&data[old_size], &from_concrete.getData()[start], length * sizeof(data[0]));
memcpy(&data[old_size], &from_concrete.data[start], length * sizeof(data[0]));
}
}
ColumnPtr ColumnAggregateFunction::filter(const Filter & filter, ssize_t result_size_hint) const
{
size_t size = getData().size();
size_t size = data.size();
if (size != filter.size())
throw Exception("Size of filter doesn't match size of column.", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);
@ -127,14 +165,14 @@ ColumnPtr ColumnAggregateFunction::filter(const Filter & filter, ssize_t result_
return cloneEmpty();
auto res = createView();
auto & res_data = res->getData();
auto & res_data = res->data;
if (result_size_hint)
res_data.reserve(result_size_hint > 0 ? result_size_hint : size);
for (size_t i = 0; i < size; ++i)
if (filter[i])
res_data.push_back(getData()[i]);
res_data.push_back(data[i]);
/// To save RAM in case of too strong filtering.
if (res_data.size() * 2 < res_data.capacity())
@ -146,7 +184,7 @@ ColumnPtr ColumnAggregateFunction::filter(const Filter & filter, ssize_t result_
ColumnPtr ColumnAggregateFunction::permute(const Permutation & perm, size_t limit) const
{
size_t size = getData().size();
size_t size = data.size();
if (limit == 0)
limit = size;
@ -158,9 +196,9 @@ ColumnPtr ColumnAggregateFunction::permute(const Permutation & perm, size_t limi
auto res = createView();
res->getData().resize(limit);
res->data.resize(limit);
for (size_t i = 0; i < limit; ++i)
res->getData()[i] = getData()[perm[i]];
res->data[i] = data[perm[i]];
return res;
}
@ -175,9 +213,9 @@ ColumnPtr ColumnAggregateFunction::indexImpl(const PaddedPODArray<Type> & indexe
{
auto res = createView();
res->getData().resize(limit);
res->data.resize(limit);
for (size_t i = 0; i < limit; ++i)
res->getData()[i] = getData()[indexes[i]];
res->data[i] = data[indexes[i]];
return res;
}
@ -188,14 +226,14 @@ INSTANTIATE_INDEX_IMPL(ColumnAggregateFunction)
void ColumnAggregateFunction::updateHashWithValue(size_t n, SipHash & hash) const
{
WriteBufferFromOwnString wbuf;
func->serialize(getData()[n], wbuf);
func->serialize(data[n], wbuf);
hash.update(wbuf.str().c_str(), wbuf.str().size());
}
/// NOTE: Highly overestimates size of a column if it was produced in AggregatingBlockInputStream (it contains size of other columns)
size_t ColumnAggregateFunction::byteSize() const
{
size_t res = getData().size() * sizeof(getData()[0]);
size_t res = data.size() * sizeof(data[0]);
for (const auto & arena : arenas)
res += arena->size();
@ -207,7 +245,7 @@ size_t ColumnAggregateFunction::byteSize() const
/// Like byteSize(), highly overestimates size
size_t ColumnAggregateFunction::allocatedBytes() const
{
size_t res = getData().allocated_bytes();
size_t res = data.allocated_bytes();
for (const auto & arena : arenas)
res += arena->size();
@ -225,7 +263,7 @@ Field ColumnAggregateFunction::operator[](size_t n) const
Field field = String();
{
WriteBufferFromString buffer(field.get<String &>());
func->serialize(getData()[n], buffer);
func->serialize(data[n], buffer);
}
return field;
}
@ -235,18 +273,19 @@ void ColumnAggregateFunction::get(size_t n, Field & res) const
res = String();
{
WriteBufferFromString buffer(res.get<String &>());
func->serialize(getData()[n], buffer);
func->serialize(data[n], buffer);
}
}
StringRef ColumnAggregateFunction::getDataAt(size_t n) const
{
return StringRef(reinterpret_cast<const char *>(&getData()[n]), sizeof(getData()[n]));
return StringRef(reinterpret_cast<const char *>(&data[n]), sizeof(data[n]));
}
void ColumnAggregateFunction::insertData(const char * pos, size_t /*length*/)
{
getData().push_back(*reinterpret_cast<const AggregateDataPtr *>(pos));
ensureOwnership();
data.push_back(*reinterpret_cast<const AggregateDataPtr *>(pos));
}
void ColumnAggregateFunction::insertFrom(const IColumn & from, size_t n)
@ -254,24 +293,26 @@ void ColumnAggregateFunction::insertFrom(const IColumn & from, size_t n)
/// Must create new state of aggregate function and take ownership of it,
/// because ownership of states of aggregate function cannot be shared for individual rows,
/// (only as a whole, see comment above).
ensureOwnership();
insertDefault();
insertMergeFrom(from, n);
}
void ColumnAggregateFunction::insertFrom(ConstAggregateDataPtr place)
{
ensureOwnership();
insertDefault();
insertMergeFrom(place);
}
void ColumnAggregateFunction::insertMergeFrom(ConstAggregateDataPtr place)
{
func->merge(getData().back(), place, &createOrGetArena());
func->merge(data.back(), place, &createOrGetArena());
}
void ColumnAggregateFunction::insertMergeFrom(const IColumn & from, size_t n)
{
insertMergeFrom(static_cast<const ColumnAggregateFunction &>(from).getData()[n]);
insertMergeFrom(static_cast<const ColumnAggregateFunction &>(from).data[n]);
}
Arena & ColumnAggregateFunction::createOrGetArena()
@ -281,47 +322,54 @@ Arena & ColumnAggregateFunction::createOrGetArena()
return *arenas.back().get();
}
static void pushBackAndCreateState(ColumnAggregateFunction::Container & data, Arena & arena, IAggregateFunction * func)
{
data.push_back(arena.alignedAlloc(func->sizeOfData(), func->alignOfData()));
try
{
func->create(data.back());
}
catch (...)
{
data.pop_back();
throw;
}
}
void ColumnAggregateFunction::insert(const Field & x)
{
IAggregateFunction * function = func.get();
ensureOwnership();
Arena & arena = createOrGetArena();
getData().push_back(arena.alignedAlloc(function->sizeOfData(), function->alignOfData()));
function->create(getData().back());
pushBackAndCreateState(data, arena, func.get());
ReadBufferFromString read_buffer(x.get<const String &>());
function->deserialize(getData().back(), read_buffer, &arena);
func->deserialize(data.back(), read_buffer, &arena);
}
void ColumnAggregateFunction::insertDefault()
{
IAggregateFunction * function = func.get();
ensureOwnership();
Arena & arena = createOrGetArena();
getData().push_back(arena.alignedAlloc(function->sizeOfData(), function->alignOfData()));
function->create(getData().back());
pushBackAndCreateState(data, arena, func.get());
}
StringRef ColumnAggregateFunction::serializeValueIntoArena(size_t n, Arena & dst, const char *& begin) const
{
IAggregateFunction * function = func.get();
WriteBufferFromArena out(dst, begin);
function->serialize(getData()[n], out);
func->serialize(data[n], out);
return out.finish();
}
const char * ColumnAggregateFunction::deserializeAndInsertFromArena(const char * src_arena)
{
IAggregateFunction * function = func.get();
ensureOwnership();
/** Parameter "src_arena" points to Arena, from which we will deserialize the state.
* And "dst_arena" is another Arena, that aggregate function state will use to store its data.
*/
Arena & dst_arena = createOrGetArena();
getData().push_back(dst_arena.alignedAlloc(function->sizeOfData(), function->alignOfData()));
function->create(getData().back());
pushBackAndCreateState(data, dst_arena, func.get());
/** We will read from src_arena.
* There is no limit for reading - it is assumed, that we can read all that we need after src_arena pointer.
@ -331,7 +379,7 @@ const char * ColumnAggregateFunction::deserializeAndInsertFromArena(const char *
* Probably this will not work under UBSan.
*/
ReadBufferFromMemory read_buffer(src_arena, std::numeric_limits<char *>::max() - src_arena);
function->deserialize(getData().back(), read_buffer, &dst_arena);
func->deserialize(data.back(), read_buffer, &dst_arena);
return read_buffer.position();
}
@ -358,7 +406,7 @@ ColumnPtr ColumnAggregateFunction::replicate(const IColumn::Offsets & offsets) c
return cloneEmpty();
auto res = createView();
auto & res_data = res->getData();
auto & res_data = res->data;
res_data.reserve(offsets.back());
IColumn::Offset prev_offset = 0;
@ -399,7 +447,7 @@ MutableColumns ColumnAggregateFunction::scatter(IColumn::ColumnIndex num_columns
void ColumnAggregateFunction::getPermutation(bool /*reverse*/, size_t /*limit*/, int /*nan_direction_hint*/, IColumn::Permutation & res) const
{
size_t s = getData().size();
size_t s = data.size();
res.resize(s);
for (size_t i = 0; i < s; ++i)
res[i] = i;

View File

@ -74,6 +74,11 @@ private:
return res;
}
/// If we have another column as a source (owner of data), copy all data to ourself and reset source.
/// This is needed before inserting new elements, because we must own these elements (to destroy them in destructor),
/// but ownership of different elements cannot be mixed by different columns.
void ensureOwnership();
ColumnAggregateFunction(const AggregateFunctionPtr & func_)
: func(func_)
{

View File

@ -393,6 +393,7 @@ namespace ErrorCodes
extern const int REPLICA_STATUS_CHANGED = 416;
extern const int EXPECTED_ALL_OR_ANY = 417;
extern const int UNKNOWN_JOIN_STRICTNESS = 418;
extern const int CANNOT_ADD_DIFFERENT_AGGREGATE_STATES = 419;
extern const int KEEPER_EXCEPTION = 999;
extern const int POCO_EXCEPTION = 1000;

View File

@ -317,7 +317,7 @@ Columns Block::getColumns() const
}
MutableColumns Block::mutateColumns() const
MutableColumns Block::mutateColumns()
{
size_t num_columns = data.size();
MutableColumns columns(num_columns);

View File

@ -115,7 +115,7 @@ public:
MutableColumns cloneEmptyColumns() const;
/** Get columns from block for mutation. Columns in block will be nullptr. */
MutableColumns mutateColumns() const;
MutableColumns mutateColumns();
/** Replace columns in a block */
void setColumns(MutableColumns && columns);

View File

@ -20,7 +20,7 @@ namespace ErrorCodes
}
///
inline bool allowDecimalComparison(const IDataType * left_type, const IDataType * right_type)
inline bool allowDecimalComparison(const DataTypePtr & left_type, const DataTypePtr & right_type)
{
if (isDecimal(left_type))
{

View File

@ -59,7 +59,7 @@ template <> struct TypeName<String> { static const char * get() { return "Strin
enum class TypeIndex
{
None = 0,
Nothing = 0,
UInt8,
UInt16,
UInt32,
@ -84,6 +84,12 @@ enum class TypeIndex
UUID,
Array,
Tuple,
Set,
Interval,
Nullable,
Function,
AggregateFunction,
LowCardinality,
};
template <typename T> struct TypeId;

View File

@ -1,4 +1,5 @@
#include <DataStreams/RollupBlockInputStream.h>
#include <DataStreams/finalizeBlock.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Columns/FilterDescription.h>
@ -7,22 +8,6 @@
namespace DB
{
static void finalize(Block & block)
{
for (size_t i = 0; i < block.columns(); ++i)
{
ColumnWithTypeAndName & current = block.getByPosition(i);
const DataTypeAggregateFunction * unfinalized_type = typeid_cast<const DataTypeAggregateFunction *>(current.type.get());
if (unfinalized_type)
{
current.type = unfinalized_type->getReturnType();
if (current.column)
current.column = typeid_cast<const ColumnAggregateFunction &>(*current.column).convertToValues();
}
}
}
RollupBlockInputStream::RollupBlockInputStream(
const BlockInputStreamPtr & input_, const Aggregator::Params & params_) : aggregator(params_),
keys(params_.keys)
@ -36,7 +21,7 @@ RollupBlockInputStream::RollupBlockInputStream(
Block RollupBlockInputStream::getHeader() const
{
Block res = children.at(0)->getHeader();
finalize(res);
finalizeBlock(res);
return res;
}
@ -58,7 +43,7 @@ Block RollupBlockInputStream::readImpl()
rollup_block = aggregator.mergeBlocks(rollup_blocks, false);
Block finalized = rollup_block;
finalize(finalized);
finalizeBlock(finalized);
return finalized;
}
@ -66,7 +51,7 @@ Block RollupBlockInputStream::readImpl()
current_key = keys.size() - 1;
rollup_block = block;
finalize(block);
finalizeBlock(block);
return block;
}

View File

@ -4,9 +4,9 @@
namespace DB
{
SquashingBlockInputStream::SquashingBlockInputStream(const BlockInputStreamPtr & src,
size_t min_block_size_rows, size_t min_block_size_bytes)
: transform(min_block_size_rows, min_block_size_bytes)
SquashingBlockInputStream::SquashingBlockInputStream(
const BlockInputStreamPtr & src, size_t min_block_size_rows, size_t min_block_size_bytes)
: header(src->getHeader()), transform(min_block_size_rows, min_block_size_bytes)
{
children.emplace_back(src);
}
@ -23,9 +23,13 @@ Block SquashingBlockInputStream::readImpl()
if (!block)
all_read = true;
SquashingTransform::Result result = transform.add(std::move(block));
SquashingTransform::Result result = transform.add(block.mutateColumns());
if (result.ready)
return result.block;
{
if (result.columns.empty())
return {};
return header.cloneWithColumns(std::move(result.columns));
}
}
}

View File

@ -16,12 +16,13 @@ public:
String getName() const override { return "Squashing"; }
Block getHeader() const override { return children.at(0)->getHeader(); }
Block getHeader() const override { return header; }
protected:
Block readImpl() override;
private:
Block header;
SquashingTransform transform;
bool all_read = false;
};

View File

@ -5,16 +5,16 @@ namespace DB
{
SquashingBlockOutputStream::SquashingBlockOutputStream(BlockOutputStreamPtr & dst, size_t min_block_size_rows, size_t min_block_size_bytes)
: output(dst), transform(min_block_size_rows, min_block_size_bytes)
: output(dst), header(output->getHeader()), transform(min_block_size_rows, min_block_size_bytes)
{
}
void SquashingBlockOutputStream::write(const Block & block)
{
SquashingTransform::Result result = transform.add(Block(block));
SquashingTransform::Result result = transform.add(Block(block).mutateColumns());
if (result.ready)
output->write(result.block);
output->write(header.cloneWithColumns(std::move(result.columns)));
}
@ -26,8 +26,8 @@ void SquashingBlockOutputStream::finalize()
all_written = true;
SquashingTransform::Result result = transform.add({});
if (result.ready && result.block)
output->write(result.block);
if (result.ready && !result.columns.empty())
output->write(header.cloneWithColumns(std::move(result.columns)));
}

View File

@ -14,7 +14,7 @@ class SquashingBlockOutputStream : public IBlockOutputStream
public:
SquashingBlockOutputStream(BlockOutputStreamPtr & dst, size_t min_block_size_rows, size_t min_block_size_bytes);
Block getHeader() const override { return output->getHeader(); }
Block getHeader() const override { return header; }
void write(const Block & block) override;
void flush() override;
@ -26,6 +26,7 @@ public:
private:
BlockOutputStreamPtr output;
Block header;
SquashingTransform transform;
bool all_written = false;

View File

@ -10,37 +10,38 @@ SquashingTransform::SquashingTransform(size_t min_block_size_rows, size_t min_bl
}
SquashingTransform::Result SquashingTransform::add(Block && block)
SquashingTransform::Result SquashingTransform::add(MutableColumns && columns)
{
if (!block)
return Result(std::move(accumulated_block));
/// End of input stream.
if (columns.empty())
return Result(std::move(accumulated_columns));
/// Just read block is alredy enough.
if (isEnoughSize(block.rows(), block.bytes()))
if (isEnoughSize(columns))
{
/// If no accumulated data, return just read block.
if (!accumulated_block)
return Result(std::move(block));
if (accumulated_columns.empty())
return Result(std::move(columns));
/// Return accumulated data (maybe it has small size) and place new block to accumulated data.
accumulated_block.swap(block);
return Result(std::move(block));
columns.swap(accumulated_columns);
return Result(std::move(columns));
}
/// Accumulated block is already enough.
if (accumulated_block && isEnoughSize(accumulated_block.rows(), accumulated_block.bytes()))
if (!accumulated_columns.empty() && isEnoughSize(accumulated_columns))
{
/// Return accumulated data and place new block to accumulated data.
accumulated_block.swap(block);
return Result(std::move(block));
columns.swap(accumulated_columns);
return Result(std::move(columns));
}
append(std::move(block));
append(std::move(columns));
if (isEnoughSize(accumulated_block.rows(), accumulated_block.bytes()))
if (isEnoughSize(accumulated_columns))
{
Block res;
res.swap(accumulated_block);
MutableColumns res;
res.swap(accumulated_columns);
return Result(std::move(res));
}
@ -49,23 +50,35 @@ SquashingTransform::Result SquashingTransform::add(Block && block)
}
void SquashingTransform::append(Block && block)
void SquashingTransform::append(MutableColumns && columns)
{
if (!accumulated_block)
if (accumulated_columns.empty())
{
accumulated_block = std::move(block);
accumulated_columns = std::move(columns);
return;
}
size_t columns = block.columns();
size_t rows = block.rows();
for (size_t i = 0; i < columns; ++i)
{
MutableColumnPtr mutable_column = (*std::move(accumulated_block.getByPosition(i).column)).mutate();
mutable_column->insertRangeFrom(*block.getByPosition(i).column, 0, rows);
accumulated_block.getByPosition(i).column = std::move(mutable_column);
for (size_t i = 0, size = columns.size(); i < size; ++i)
accumulated_columns[i]->insertRangeFrom(*columns[i], 0, columns[i]->size());
}
bool SquashingTransform::isEnoughSize(const MutableColumns & columns)
{
size_t rows = 0;
size_t bytes = 0;
for (const auto & column : columns)
{
if (!rows)
rows = column->size();
else if (rows != column->size())
throw Exception("Sizes of columns doesn't match", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);
bytes += column->byteSize();
}
return isEnoughSize(rows, bytes);
}

View File

@ -29,25 +29,26 @@ public:
struct Result
{
bool ready = false;
Block block;
MutableColumns columns;
Result(bool ready_) : ready(ready_) {}
Result(Block && block_) : ready(true), block(std::move(block_)) {}
Result(MutableColumns && columns) : ready(true), columns(std::move(columns)) {}
};
/** Add next block and possibly returns squashed block.
* At end, you need to pass empty block. As the result for last (empty) block, you will get last Result with ready = true.
*/
Result add(Block && block);
Result add(MutableColumns && columns);
private:
size_t min_block_size_rows;
size_t min_block_size_bytes;
Block accumulated_block;
MutableColumns accumulated_columns;
void append(Block && block);
void append(MutableColumns && columns);
bool isEnoughSize(const MutableColumns & columns);
bool isEnoughSize(size_t rows, size_t bytes) const;
};

View File

@ -76,7 +76,7 @@ SummingSortedBlockInputStream::SummingSortedBlockInputStream(
}
else
{
bool is_agg_func = checkDataType<DataTypeAggregateFunction>(column.type.get());
bool is_agg_func = WhichDataType(column.type).isAggregateFunction();
if (!column.type->isSummable() && !is_agg_func)
{
column_numbers_not_to_aggregate.push_back(i);
@ -273,7 +273,7 @@ Block SummingSortedBlockInputStream::readImpl()
for (auto & desc : columns_to_aggregate)
{
// Wrap aggregated columns in a tuple to match function signature
if (!desc.is_agg_func_type && checkDataType<DataTypeTuple>(desc.function->getReturnType().get()))
if (!desc.is_agg_func_type && isTuple(desc.function->getReturnType()))
{
size_t tuple_size = desc.column_numbers.size();
MutableColumns tuple_columns(tuple_size);
@ -292,7 +292,7 @@ Block SummingSortedBlockInputStream::readImpl()
/// Place aggregation results into block.
for (auto & desc : columns_to_aggregate)
{
if (!desc.is_agg_func_type && checkDataType<DataTypeTuple>(desc.function->getReturnType().get()))
if (!desc.is_agg_func_type && isTuple(desc.function->getReturnType()))
{
/// Unpack tuple into block.
size_t tuple_size = desc.column_numbers.size();

View File

@ -1,4 +1,5 @@
#include <DataStreams/TotalsHavingBlockInputStream.h>
#include <DataStreams/finalizeBlock.h>
#include <Interpreters/ExpressionActions.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <Columns/ColumnAggregateFunction.h>
@ -53,23 +54,6 @@ TotalsHavingBlockInputStream::TotalsHavingBlockInputStream(
}
static void finalize(Block & block)
{
for (size_t i = 0; i < block.columns(); ++i)
{
ColumnWithTypeAndName & current = block.getByPosition(i);
const DataTypeAggregateFunction * unfinalized_type = typeid_cast<const DataTypeAggregateFunction *>(current.type.get());
if (unfinalized_type)
{
current.type = unfinalized_type->getReturnType();
if (current.column)
current.column = typeid_cast<const ColumnAggregateFunction &>(*current.column).convertToValues();
}
}
}
Block TotalsHavingBlockInputStream::getTotals()
{
if (!totals)
@ -87,7 +71,7 @@ Block TotalsHavingBlockInputStream::getTotals()
}
totals = children.at(0)->getHeader().cloneWithColumns(std::move(current_totals));
finalize(totals);
finalizeBlock(totals);
}
if (totals && expression)
@ -101,7 +85,7 @@ Block TotalsHavingBlockInputStream::getHeader() const
{
Block res = children.at(0)->getHeader();
if (final)
finalize(res);
finalizeBlock(res);
if (expression)
expression->execute(res);
return res;
@ -129,7 +113,7 @@ Block TotalsHavingBlockInputStream::readImpl()
finalized = block;
if (final)
finalize(finalized);
finalizeBlock(finalized);
total_keys += finalized.rows();

View File

@ -0,0 +1,24 @@
#include <DataStreams/finalizeBlock.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Common/typeid_cast.h>
namespace DB
{
void finalizeBlock(Block & block)
{
for (size_t i = 0; i < block.columns(); ++i)
{
ColumnWithTypeAndName & current = block.getByPosition(i);
const DataTypeAggregateFunction * unfinalized_type = typeid_cast<const DataTypeAggregateFunction *>(current.type.get());
if (unfinalized_type)
{
current.type = unfinalized_type->getReturnType();
if (current.column)
current.column = typeid_cast<const ColumnAggregateFunction &>(*current.column).convertToValues();
}
}
}
}

View File

@ -0,0 +1,9 @@
#pragma once
#include <Core/Block.h>
namespace DB
{
/// Converts aggregate function columns with non-finalized states to final values
void finalizeBlock(Block & block);
}

View File

@ -30,8 +30,8 @@ public:
AggregateFunctionPtr getFunction() const { return function; }
std::string getName() const override;
const char * getFamilyName() const override { return "AggregateFunction"; }
TypeIndex getTypeId() const override { return TypeIndex::AggregateFunction; }
bool canBeInsideNullable() const override { return false; }

View File

@ -23,7 +23,6 @@ public:
void deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const override;
bool canBeUsedAsVersion() const override { return true; }
bool isDateOrDateTime() const override { return true; }
bool canBeInsideNullable() const override { return true; }
bool equals(const IDataType & rhs) const override;

View File

@ -48,7 +48,6 @@ public:
void deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const override;
bool canBeUsedAsVersion() const override { return true; }
bool isDateOrDateTime() const override { return true; }
bool canBeInsideNullable() const override { return true; }
bool equals(const IDataType & rhs) const override;

View File

@ -30,7 +30,6 @@ public:
bool isValueUnambiguouslyRepresentedInContiguousMemoryRegion() const override { return true; }
bool haveMaximumSizeOfValue() const override { return true; }
bool isCategorial() const override { return true; }
bool isEnum() const override { return true; }
bool canBeInsideNullable() const override { return true; }
bool isComparable() const override { return true; }
};

View File

@ -77,7 +77,7 @@ public:
bool haveSubtypes() const override { return false; }
bool isComparable() const override { return true; }
bool isValueUnambiguouslyRepresentedInContiguousMemoryRegion() const override { return true; }
bool isFixedString() const override { return true; }
bool isValueUnambiguouslyRepresentedInFixedSizeContiguousMemoryRegion() const override { return true; }
bool haveMaximumSizeOfValue() const override { return true; }
size_t getSizeOfValueInMemory() const override { return n; }
bool isCategorial() const override { return true; }

View File

@ -24,6 +24,7 @@ public:
std::string getName() const override;
const char * getFamilyName() const override { return "Function"; }
TypeIndex getTypeId() const override { return TypeIndex::Function; }
const DataTypes & getArgumentTypes() const
{

View File

@ -55,6 +55,7 @@ public:
std::string getName() const override { return std::string("Interval") + kindToString(); }
const char * getFamilyName() const override { return "Interval"; }
TypeIndex getTypeId() const override { return TypeIndex::Interval; }
bool equals(const IDataType & rhs) const override;

View File

@ -16,6 +16,7 @@ public:
static constexpr bool is_parametric = false;
const char * getFamilyName() const override { return "Nothing"; }
TypeIndex getTypeId() const override { return TypeIndex::Nothing; }
MutableColumnPtr createColumn() const override;

View File

@ -16,6 +16,7 @@ public:
explicit DataTypeNullable(const DataTypePtr & nested_data_type_);
std::string getName() const override { return "Nullable(" + nested_data_type->getName() + ")"; }
const char * getFamilyName() const override { return "Nullable"; }
TypeIndex getTypeId() const override { return TypeIndex::Nullable; }
void enumerateStreams(const StreamCallback & callback, SubstreamPath & path) const override;

View File

@ -14,6 +14,7 @@ class DataTypeSet final : public IDataTypeDummy
public:
static constexpr bool is_parametric = true;
const char * getFamilyName() const override { return "Set"; }
TypeIndex getTypeId() const override { return TypeIndex::Set; }
bool equals(const IDataType & rhs) const override { return typeid(rhs) == typeid(*this); }
bool isParametric() const override { return true; }
};

View File

@ -59,7 +59,6 @@ public:
bool isComparable() const override { return true; }
bool canBeComparedWithCollation() const override { return true; }
bool isValueUnambiguouslyRepresentedInContiguousMemoryRegion() const override { return true; }
bool isString() const override { return true; }
bool isCategorial() const override { return true; }
bool canBeInsideNullable() const override { return true; }
};

View File

@ -42,9 +42,9 @@ DataTypeWithDictionary::DataTypeWithDictionary(DataTypePtr dictionary_type_)
if (dictionary_type->isNullable())
inner_type = static_cast<const DataTypeNullable &>(*dictionary_type).getNestedType();
if (!inner_type->isStringOrFixedString()
&& !inner_type->isDateOrDateTime()
&& !inner_type->isNumber())
if (!isStringOrFixedString(inner_type)
&& !isDateOrDateTime(inner_type)
&& !isNumber(inner_type))
throw Exception("DataTypeWithDictionary is supported only for numbers, strings, Date or DateTime, but got "
+ dictionary_type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
@ -769,15 +769,15 @@ MutableColumnUniquePtr DataTypeWithDictionary::createColumnUniqueImpl(const IDat
if (auto * nullable_type = typeid_cast<const DataTypeNullable *>(&keys_type))
type = nullable_type->getNestedType().get();
if (type->isString())
if (isString(type))
return creator((ColumnString *)(nullptr));
if (type->isFixedString())
if (isFixedString(type))
return creator((ColumnFixedString *)(nullptr));
if (typeid_cast<const DataTypeDate *>(type))
return creator((ColumnVector<UInt16> *)(nullptr));
if (typeid_cast<const DataTypeDateTime *>(type))
return creator((ColumnVector<UInt32> *)(nullptr));
if (type->isNumber())
if (isNumber(type))
{
MutableColumnUniquePtr column;
TypeListNumbers::forEach(CreateColumnVector(column, *type, creator));

View File

@ -20,6 +20,7 @@ public:
return "LowCardinality(" + dictionary_type->getName() + ")";
}
const char * getFamilyName() const override { return "LowCardinality"; }
TypeIndex getTypeId() const override { return TypeIndex::LowCardinality; }
void enumerateStreams(const StreamCallback & callback, SubstreamPath & path) const override;
@ -126,20 +127,13 @@ public:
bool isSummable() const override { return dictionary_type->isSummable(); }
bool canBeUsedInBitOperations() const override { return dictionary_type->canBeUsedInBitOperations(); }
bool canBeUsedInBooleanContext() const override { return dictionary_type->canBeUsedInBooleanContext(); }
bool isNumber() const override { return false; }
bool isInteger() const override { return false; }
bool isUnsignedInteger() const override { return false; }
bool isDateOrDateTime() const override { return false; }
bool isValueRepresentedByNumber() const override { return dictionary_type->isValueRepresentedByNumber(); }
bool isValueRepresentedByInteger() const override { return dictionary_type->isValueRepresentedByInteger(); }
bool isValueUnambiguouslyRepresentedInContiguousMemoryRegion() const override { return true; }
bool isString() const override { return false; }
bool isFixedString() const override { return false; }
bool haveMaximumSizeOfValue() const override { return dictionary_type->haveMaximumSizeOfValue(); }
size_t getMaximumSizeOfValueInMemory() const override { return dictionary_type->getMaximumSizeOfValueInMemory(); }
size_t getSizeOfValueInMemory() const override { return dictionary_type->getSizeOfValueInMemory(); }
bool isCategorial() const override { return false; }
bool isEnum() const override { return false; }
bool isNullable() const override { return false; }
bool onlyNull() const override { return false; }
bool withDictionary() const override { return true; }

View File

@ -15,10 +15,7 @@ class DataTypeNumber final : public DataTypeNumberBase<T>
bool canBeUsedAsVersion() const override { return true; }
bool isSummable() const override { return true; }
bool canBeUsedInBitOperations() const override { return true; }
bool isUnsignedInteger() const override { return isInteger() && std::is_unsigned_v<T>; }
bool canBeUsedInBooleanContext() const override { return true; }
bool isNumber() const override { return true; }
bool isInteger() const override { return std::is_integral_v<T>; }
bool canBeInsideNullable() const override { return true; }
};

View File

@ -46,7 +46,7 @@ public:
virtual const char * getFamilyName() const = 0;
/// Unique type number or zero
virtual TypeIndex getTypeId() const { return TypeIndex::None; }
virtual TypeIndex getTypeId() const = 0;
/** Binary serialization for range of values in column - for writing to disk/network, etc.
*
@ -342,17 +342,6 @@ public:
*/
virtual bool canBeUsedInBooleanContext() const { return false; }
/** Integers, floats, not Nullable. Not Enums. Not Date/DateTime.
*/
virtual bool isNumber() const { return false; }
/** Integers. Not Nullable. Not Enums. Not Date/DateTime.
*/
virtual bool isInteger() const { return false; }
virtual bool isUnsignedInteger() const { return false; }
virtual bool isDateOrDateTime() const { return false; }
/** Numbers, Enums, Date, DateTime. Not nullable.
*/
virtual bool isValueRepresentedByNumber() const { return false; }
@ -376,13 +365,9 @@ public:
virtual bool isValueUnambiguouslyRepresentedInFixedSizeContiguousMemoryRegion() const
{
return isValueRepresentedByNumber() || isFixedString();
return isValueRepresentedByNumber();
}
virtual bool isString() const { return false; }
virtual bool isFixedString() const { return false; }
virtual bool isStringOrFixedString() const { return isString() || isFixedString(); }
/** Example: numbers, Date, DateTime, FixedString, Enum... Nullable and Tuple of such types.
* Counterexamples: String, Array.
* It's Ok to return false for AggregateFunction despite the fact that some of them have fixed size state.
@ -401,8 +386,6 @@ public:
*/
virtual bool isCategorial() const { return false; }
virtual bool isEnum() const { return false; }
virtual bool isNullable() const { return false; }
/** Is this type can represent only NULL value? (It also implies isNullable)
@ -423,11 +406,20 @@ public:
};
struct DataTypeExtractor
/// Some sugar to check data type of IDataType
struct WhichDataType
{
TypeIndex idx;
DataTypeExtractor(const IDataType * data_type)
WhichDataType(const IDataType & data_type)
: idx(data_type.getTypeId())
{}
WhichDataType(const IDataType * data_type)
: idx(data_type->getTypeId())
{}
WhichDataType(const DataTypePtr & data_type)
: idx(data_type->getTypeId())
{}
@ -437,6 +429,7 @@ struct DataTypeExtractor
bool isUInt64() const { return idx == TypeIndex::UInt64; }
bool isUInt128() const { return idx == TypeIndex::UInt128; }
bool isUInt() const { return isUInt8() || isUInt16() || isUInt32() || isUInt64() || isUInt128(); }
bool isNativeUInt() const { return isUInt8() || isUInt16() || isUInt32() || isUInt64(); }
bool isInt8() const { return idx == TypeIndex::Int8; }
bool isInt16() const { return idx == TypeIndex::Int16; }
@ -444,6 +437,7 @@ struct DataTypeExtractor
bool isInt64() const { return idx == TypeIndex::Int64; }
bool isInt128() const { return idx == TypeIndex::Int128; }
bool isInt() const { return isInt8() || isInt16() || isInt32() || isInt64() || isInt128(); }
bool isNativeInt() const { return isInt8() || isInt16() || isInt32() || isInt64(); }
bool isDecimal32() const { return idx == TypeIndex::Decimal32; }
bool isDecimal64() const { return idx == TypeIndex::Decimal64; }
@ -469,27 +463,69 @@ struct DataTypeExtractor
bool isUUID() const { return idx == TypeIndex::UUID; }
bool isArray() const { return idx == TypeIndex::Array; }
bool isTuple() const { return idx == TypeIndex::Tuple; }
bool isSet() const { return idx == TypeIndex::Set; }
bool isInterval() const { return idx == TypeIndex::Interval; }
bool isNothing() const { return idx == TypeIndex::Nothing; }
bool isNullable() const { return idx == TypeIndex::Nullable; }
bool isFunction() const { return idx == TypeIndex::Function; }
bool isAggregateFunction() const { return idx == TypeIndex::AggregateFunction; }
};
/// IDataType helpers (alternative for IDataType virtual methods)
/// IDataType helpers (alternative for IDataType virtual methods with single point of truth)
inline bool isEnum(const IDataType * data_type)
inline bool isDateOrDateTime(const DataTypePtr & data_type) { return WhichDataType(data_type).isDateOrDateTime(); }
inline bool isEnum(const DataTypePtr & data_type) { return WhichDataType(data_type).isEnum(); }
inline bool isDecimal(const DataTypePtr & data_type) { return WhichDataType(data_type).isDecimal(); }
inline bool isTuple(const DataTypePtr & data_type) { return WhichDataType(data_type).isTuple(); }
inline bool isArray(const DataTypePtr & data_type) { return WhichDataType(data_type).isArray(); }
template <typename T>
inline bool isUnsignedInteger(const T & data_type)
{
return DataTypeExtractor(data_type).isEnum();
return WhichDataType(data_type).isUInt();
}
inline bool isDecimal(const IDataType * data_type)
template <typename T>
inline bool isInteger(const T & data_type)
{
return DataTypeExtractor(data_type).isDecimal();
}
inline bool isNotDecimalButComparableToDecimal(const IDataType * data_type)
{
DataTypeExtractor which(data_type);
WhichDataType which(data_type);
return which.isInt() || which.isUInt();
}
inline bool isCompilableType(const IDataType * data_type)
template <typename T>
inline bool isNumber(const T & data_type)
{
WhichDataType which(data_type);
return which.isInt() || which.isUInt() || which.isFloat();
}
template <typename T>
inline bool isString(const T & data_type)
{
return WhichDataType(data_type).isString();
}
template <typename T>
inline bool isFixedString(const T & data_type)
{
return WhichDataType(data_type).isFixedString();
}
template <typename T>
inline bool isStringOrFixedString(const T & data_type)
{
return WhichDataType(data_type).isStringOrFixedString();
}
inline bool isNotDecimalButComparableToDecimal(const DataTypePtr & data_type)
{
WhichDataType which(data_type);
return which.isInt() || which.isUInt();
}
inline bool isCompilableType(const DataTypePtr & data_type)
{
return data_type->isValueRepresentedByNumber() && !isDecimal(data_type);
}

View File

@ -66,6 +66,21 @@ static inline llvm::Type * toNativeType(llvm::IRBuilderBase & builder, const IDa
return nullptr;
}
static inline bool canBeNativeType(const IDataType & type)
{
if (auto * nullable = typeid_cast<const DataTypeNullable *>(&type))
return canBeNativeType(*nullable->getNestedType());
return typeIsEither<DataTypeInt8, DataTypeUInt8>(type)
|| typeIsEither<DataTypeInt16, DataTypeUInt16, DataTypeDate>(type)
|| typeIsEither<DataTypeInt32, DataTypeUInt32, DataTypeDateTime>(type)
|| typeIsEither<DataTypeInt64, DataTypeUInt64, DataTypeInterval>(type)
|| typeIsEither<DataTypeUUID>(type)
|| typeIsEither<DataTypeFloat32>(type)
|| typeIsEither<DataTypeFloat64>(type)
|| typeid_cast<const DataTypeFixedString *>(&type);
}
static inline llvm::Type * toNativeType(llvm::IRBuilderBase & builder, const DataTypePtr & type)
{
return toNativeType(builder, *type);

View File

@ -213,7 +213,7 @@ DataTypePtr getMostSubtype(const DataTypes & types, bool throw_if_result_is_noth
for (const auto & type : types)
{
if (type->isFixedString())
if (isFixedString(type))
{
have_string = true;
if (!fixed_string_type)
@ -221,7 +221,7 @@ DataTypePtr getMostSubtype(const DataTypes & types, bool throw_if_result_is_noth
else if (!type->equals(*fixed_string_type))
return getNothingOrThrow(" because some of them are FixedStrings with different length");
}
else if (type->isString())
else if (isString(type))
have_string = true;
else
all_strings = false;
@ -243,7 +243,7 @@ DataTypePtr getMostSubtype(const DataTypes & types, bool throw_if_result_is_noth
for (const auto & type : types)
{
if (type->isDateOrDateTime())
if (isDateOrDateTime(type))
have_date_or_datetime = true;
else
all_date_or_datetime = false;

View File

@ -289,7 +289,7 @@ void ComplexKeyHashedDictionary::updateData()
auto stream = source_ptr->loadUpdatedAll();
stream->readPrefix();
while (const auto block = stream->read())
while (Block block = stream->read())
{
const auto saved_key_column_ptrs = ext::map<Columns>(ext::range(0, keys_size), [&](const size_t key_idx)
{

View File

@ -315,7 +315,7 @@ void FlatDictionary::updateData()
auto stream = source_ptr->loadUpdatedAll();
stream->readPrefix();
while (const auto block = stream->read())
while (Block block = stream->read())
{
const auto &saved_id_column = *saved_block->safeGetByPosition(0).column;
const auto &update_id_column = *block.safeGetByPosition(0).column;

View File

@ -307,7 +307,7 @@ void HashedDictionary::updateData()
auto stream = source_ptr->loadUpdatedAll();
stream->readPrefix();
while (const auto block = stream->read())
while (Block block = stream->read())
{
const auto &saved_id_column = *saved_block->safeGetByPosition(0).column;
const auto &update_id_column = *block.safeGetByPosition(0).column;

View File

@ -227,7 +227,7 @@ bool CSVRowInputStream::parseRowAndPrintDiagnosticInfo(MutableColumns & columns,
if (curr_position < prev_position)
throw Exception("Logical error: parsing is non-deterministic.", ErrorCodes::LOGICAL_ERROR);
if (data_types[i]->isNumber() || data_types[i]->isDateOrDateTime())
if (isNumber(data_types[i]) || isDateOrDateTime(data_types[i]))
{
/// An empty string instead of a value.
if (curr_position == prev_position)

View File

@ -195,7 +195,7 @@ bool TabSeparatedRowInputStream::parseRowAndPrintDiagnosticInfo(MutableColumns &
if (curr_position < prev_position)
throw Exception("Logical error: parsing is non-deterministic.", ErrorCodes::LOGICAL_ERROR);
if (data_types[i]->isNumber() || data_types[i]->isDateOrDateTime())
if (isNumber(data_types[i]) || isDateOrDateTime(data_types[i]))
{
/// An empty string instead of a value.
if (curr_position == prev_position)

View File

@ -32,46 +32,6 @@ generate_function_register(Arithmetic
FunctionIntExp10
)
generate_function_register(Array
FunctionArray
FunctionArrayElement
FunctionHas
FunctionIndexOf
FunctionCountEqual
FunctionArrayEnumerate
FunctionArrayEnumerateUniq
FunctionArrayEnumerateDense
FunctionArrayUniq
FunctionArrayDistinct
FunctionEmptyArrayUInt8
FunctionEmptyArrayUInt16
FunctionEmptyArrayUInt32
FunctionEmptyArrayUInt64
FunctionEmptyArrayInt8
FunctionEmptyArrayInt16
FunctionEmptyArrayInt32
FunctionEmptyArrayInt64
FunctionEmptyArrayFloat32
FunctionEmptyArrayFloat64
FunctionEmptyArrayDate
FunctionEmptyArrayDateTime
FunctionEmptyArrayString
FunctionEmptyArrayToSingle
FunctionRange
FunctionArrayReduce
FunctionArrayReverse
FunctionArrayConcat
FunctionArraySlice
FunctionArrayPushBack
FunctionArrayPushFront
FunctionArrayPopBack
FunctionArrayPopFront
FunctionArrayHasAll
FunctionArrayHasAny
FunctionArrayIntersect
FunctionArrayResize
)
generate_function_register(Projection
FunctionOneOrZero
FunctionProject

View File

@ -0,0 +1,59 @@
#include <cstring>
#include <Columns/ColumnString.h>
#include <Functions/FunctionFactory.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
template <bool negative = false>
struct EmptyImpl
{
/// If the function will return constant value for FixedString data type.
static constexpr auto is_fixed_to_constant = false;
static void vector(const ColumnString::Chars_t & /*data*/, const ColumnString::Offsets & offsets, PaddedPODArray<UInt8> & res)
{
size_t size = offsets.size();
ColumnString::Offset prev_offset = 1;
for (size_t i = 0; i < size; ++i)
{
res[i] = negative ^ (offsets[i] == prev_offset);
prev_offset = offsets[i] + 1;
}
}
/// Only make sense if is_fixed_to_constant.
static void vector_fixed_to_constant(const ColumnString::Chars_t & /*data*/, size_t /*n*/, UInt8 & /*res*/)
{
throw Exception("Logical error: 'vector_fixed_to_constant method' is called", ErrorCodes::LOGICAL_ERROR);
}
static void vector_fixed_to_vector(const ColumnString::Chars_t & data, size_t n, PaddedPODArray<UInt8> & res)
{
std::vector<char> empty_chars(n);
size_t size = data.size() / n;
for (size_t i = 0; i < size; ++i)
res[i] = negative ^ (0 == memcmp(&data[i * size], empty_chars.data(), n));
}
static void array(const ColumnString::Offsets & offsets, PaddedPODArray<UInt8> & res)
{
size_t size = offsets.size();
ColumnString::Offset prev_offset = 0;
for (size_t i = 0; i < size; ++i)
{
res[i] = negative ^ (offsets[i] == prev_offset);
prev_offset = offsets[i];
}
}
};
}

View File

@ -20,13 +20,6 @@ const Type * checkAndGetDataType(const IDataType * data_type)
return typeid_cast<const Type *>(data_type);
}
template <typename Type>
bool checkDataType(const IDataType * data_type)
{
return checkAndGetDataType<Type>(data_type);
}
template <typename Type>
const Type * checkAndGetColumn(const IColumn * column)
{

View File

@ -0,0 +1,88 @@
#include <Functions/IFunction.h>
#include <Functions/FunctionHelpers.h>
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnsNumber.h>
#include <ext/range.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
template <typename Impl>
class FunctionNumericPredicate : public IFunction
{
public:
static constexpr auto name = Impl::name;
static FunctionPtr create(const Context &)
{
return std::make_shared<FunctionNumericPredicate>();
}
String getName() const override
{
return name;
}
size_t getNumberOfArguments() const override
{
return 1;
}
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!isNumber(arguments.front()))
throw Exception{"Argument for function " + getName() + " must be number", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
return std::make_shared<DataTypeUInt8>();
}
bool useDefaultImplementationForConstants() const override { return true; }
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
{
const auto in = block.getByPosition(arguments.front()).column.get();
if ( !execute<UInt8>(block, in, result)
&& !execute<UInt16>(block, in, result)
&& !execute<UInt32>(block, in, result)
&& !execute<UInt64>(block, in, result)
&& !execute<Int8>(block, in, result)
&& !execute<Int16>(block, in, result)
&& !execute<Int32>(block, in, result)
&& !execute<Int64>(block, in, result)
&& !execute<Float32>(block, in, result)
&& !execute<Float64>(block, in, result))
throw Exception{"Illegal column " + in->getName() + " of first argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN};
}
template <typename T>
bool execute(Block & block, const IColumn * in_untyped, const size_t result)
{
if (const auto in = checkAndGetColumn<ColumnVector<T>>(in_untyped))
{
const auto size = in->size();
auto out = ColumnUInt8::create(size);
const auto & in_data = in->getData();
auto & out_data = out->getData();
for (const auto i : ext::range(0, size))
out_data[i] = Impl::execute(in_data[i]);
block.getByPosition(result).column = std::move(out);
return true;
}
return false;
}
};
}

View File

@ -0,0 +1,139 @@
#include <Functions/IFunction.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/GatherUtils/GatherUtils.h>
#include <Functions/GatherUtils/Sources.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnString.h>
namespace DB
{
using namespace GatherUtils;
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
struct NameStartsWith
{
static constexpr auto name = "startsWith";
};
struct NameEndsWith
{
static constexpr auto name = "endsWith";
};
template <typename Name>
class FunctionStartsEndsWith : public IFunction
{
public:
static constexpr auto name = Name::name;
static FunctionPtr create(const Context &)
{
return std::make_shared<FunctionStartsEndsWith>();
}
String getName() const override
{
return name;
}
size_t getNumberOfArguments() const override
{
return 2;
}
bool useDefaultImplementationForConstants() const override
{
return true;
}
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!isStringOrFixedString(arguments[0]))
throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!isStringOrFixedString(arguments[1]))
throw Exception("Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<DataTypeUInt8>();
}
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
{
const IColumn * haystack_column = block.getByPosition(arguments[0]).column.get();
const IColumn * needle_column = block.getByPosition(arguments[1]).column.get();
auto col_res = ColumnVector<UInt8>::create();
typename ColumnVector<UInt8>::Container & vec_res = col_res->getData();
vec_res.resize(input_rows_count);
if (const ColumnString * haystack = checkAndGetColumn<ColumnString>(haystack_column))
dispatch<StringSource>(StringSource(*haystack), needle_column, vec_res);
else if (const ColumnFixedString * haystack = checkAndGetColumn<ColumnFixedString>(haystack_column))
dispatch<FixedStringSource>(FixedStringSource(*haystack), needle_column, vec_res);
else if (const ColumnConst * haystack = checkAndGetColumnConst<ColumnString>(haystack_column))
dispatch<ConstSource<StringSource>>(ConstSource<StringSource>(*haystack), needle_column, vec_res);
else if (const ColumnConst * haystack = checkAndGetColumnConst<ColumnFixedString>(haystack_column))
dispatch<ConstSource<FixedStringSource>>(ConstSource<FixedStringSource>(*haystack), needle_column, vec_res);
else
throw Exception("Illegal combination of columns as arguments of function " + getName(), ErrorCodes::ILLEGAL_COLUMN);
block.getByPosition(result).column = std::move(col_res);
}
private:
template <typename HaystackSource>
void dispatch(HaystackSource haystack_source, const IColumn * needle_column, PaddedPODArray<UInt8> & res_data) const
{
if (const ColumnString * needle = checkAndGetColumn<ColumnString>(needle_column))
execute<HaystackSource, StringSource>(haystack_source, StringSource(*needle), res_data);
else if (const ColumnFixedString * needle = checkAndGetColumn<ColumnFixedString>(needle_column))
execute<HaystackSource, FixedStringSource>(haystack_source, FixedStringSource(*needle), res_data);
else if (const ColumnConst * needle = checkAndGetColumnConst<ColumnString>(needle_column))
execute<HaystackSource, ConstSource<StringSource>>(haystack_source, ConstSource<StringSource>(*needle), res_data);
else if (const ColumnConst * needle = checkAndGetColumnConst<ColumnFixedString>(needle_column))
execute<HaystackSource, ConstSource<FixedStringSource>>(haystack_source, ConstSource<FixedStringSource>(*needle), res_data);
else
throw Exception("Illegal combination of columns as arguments of function " + getName(), ErrorCodes::ILLEGAL_COLUMN);
}
template <typename HaystackSource, typename NeedleSource>
static void execute(HaystackSource haystack_source, NeedleSource needle_source, PaddedPODArray<UInt8> & res_data)
{
size_t row_num = 0;
while (!haystack_source.isEnd())
{
auto haystack = haystack_source.getWhole();
auto needle = needle_source.getWhole();
if (needle.size > haystack.size)
{
res_data[row_num] = false;
}
else
{
if constexpr (std::is_same_v<Name, NameStartsWith>)
{
res_data[row_num] = StringRef(haystack.data, needle.size) == StringRef(needle.data, needle.size);
}
else /// endsWith
{
res_data[row_num] = StringRef(haystack.data + haystack.size - needle.size, needle.size) == StringRef(needle.data, needle.size);
}
}
haystack_source.next();
needle_source.next();
++row_num;
}
}
};
}

View File

@ -0,0 +1,101 @@
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/IFunction.h>
#include <Functions/FunctionHelpers.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnArray.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
template <typename Impl, typename Name, typename ResultType>
class FunctionStringOrArrayToT : public IFunction
{
public:
static constexpr auto name = Name::name;
static FunctionPtr create(const Context &)
{
return std::make_shared<FunctionStringOrArrayToT>();
}
String getName() const override
{
return name;
}
size_t getNumberOfArguments() const override
{
return 1;
}
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!isStringOrFixedString(arguments[0])
&& !isArray(arguments[0]))
throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<DataTypeNumber<ResultType>>();
}
bool useDefaultImplementationForConstants() const override { return true; }
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
{
const ColumnPtr column = block.getByPosition(arguments[0]).column;
if (const ColumnString * col = checkAndGetColumn<ColumnString>(column.get()))
{
auto col_res = ColumnVector<ResultType>::create();
typename ColumnVector<ResultType>::Container & vec_res = col_res->getData();
vec_res.resize(col->size());
Impl::vector(col->getChars(), col->getOffsets(), vec_res);
block.getByPosition(result).column = std::move(col_res);
}
else if (const ColumnFixedString * col = checkAndGetColumn<ColumnFixedString>(column.get()))
{
if (Impl::is_fixed_to_constant)
{
ResultType res = 0;
Impl::vector_fixed_to_constant(col->getChars(), col->getN(), res);
block.getByPosition(result).column = block.getByPosition(result).type->createColumnConst(col->size(), toField(res));
}
else
{
auto col_res = ColumnVector<ResultType>::create();
typename ColumnVector<ResultType>::Container & vec_res = col_res->getData();
vec_res.resize(col->size());
Impl::vector_fixed_to_vector(col->getChars(), col->getN(), vec_res);
block.getByPosition(result).column = std::move(col_res);
}
}
else if (const ColumnArray * col = checkAndGetColumn<ColumnArray>(column.get()))
{
auto col_res = ColumnVector<ResultType>::create();
typename ColumnVector<ResultType>::Container & vec_res = col_res->getData();
vec_res.resize(col->size());
Impl::array(col->getOffsets(), vec_res);
block.getByPosition(result).column = std::move(col_res);
}
else
throw Exception("Illegal column " + block.getByPosition(arguments[0]).column->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN);
}
};
}

View File

@ -0,0 +1,76 @@
#include <DataTypes/DataTypeString.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnFixedString.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
template <typename Impl, typename Name, bool is_injective = false>
class FunctionStringToString : public IFunction
{
public:
static constexpr auto name = Name::name;
static FunctionPtr create(const Context &)
{
return std::make_shared<FunctionStringToString>();
}
String getName() const override
{
return name;
}
size_t getNumberOfArguments() const override
{
return 1;
}
bool isInjective(const Block &) override
{
return is_injective;
}
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!isStringOrFixedString(arguments[0]))
throw Exception(
"Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return arguments[0];
}
bool useDefaultImplementationForConstants() const override { return true; }
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
{
const ColumnPtr column = block.getByPosition(arguments[0]).column;
if (const ColumnString * col = checkAndGetColumn<ColumnString>(column.get()))
{
auto col_res = ColumnString::create();
Impl::vector(col->getChars(), col->getOffsets(), col_res->getChars(), col_res->getOffsets());
block.getByPosition(result).column = std::move(col_res);
}
else if (const ColumnFixedString * col = checkAndGetColumn<ColumnFixedString>(column.get()))
{
auto col_res = ColumnFixedString::create(col->getN());
Impl::vector_fixed(col->getChars(), col->getN(), col_res->getChars());
block.getByPosition(result).column = std::move(col_res);
}
else
throw Exception(
"Illegal column " + block.getByPosition(arguments[0]).column->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN);
}
};
}

View File

@ -44,6 +44,7 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
extern const int TOO_LESS_ARGUMENTS_FOR_FUNCTION;
extern const int DECIMAL_OVERFLOW;
extern const int CANNOT_ADD_DIFFERENT_AGGREGATE_STATES;
}
@ -1130,6 +1131,123 @@ class FunctionBinaryArithmetic : public IFunction
return FunctionFactory::instance().get(function_name.str(), context);
}
bool isAggregateMultiply(const DataTypePtr & type0, const DataTypePtr & type1) const
{
if constexpr (!std::is_same_v<Op<UInt8, UInt8>, MultiplyImpl<UInt8, UInt8>>)
return false;
WhichDataType which0(type0);
WhichDataType which1(type1);
return (which0.isAggregateFunction() && which1.isNativeUInt())
|| (which0.isNativeUInt() && which1.isAggregateFunction());
}
bool isAggregateAddition(const DataTypePtr & type0, const DataTypePtr & type1) const
{
if constexpr (!std::is_same_v<Op<UInt8, UInt8>, PlusImpl<UInt8, UInt8>>)
return false;
WhichDataType which0(type0);
WhichDataType which1(type1);
return which0.isAggregateFunction() && which1.isAggregateFunction();
}
/// Multiply aggregation state by integer constant: by merging it with itself specified number of times.
void executeAggregateMultiply(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) const
{
ColumnNumbers new_arguments = arguments;
if (WhichDataType(block.getByPosition(new_arguments[1]).type).isAggregateFunction())
std::swap(new_arguments[0], new_arguments[1]);
if (!block.getByPosition(new_arguments[1]).column->isColumnConst())
throw Exception{"Illegal column " + block.getByPosition(new_arguments[1]).column->getName()
+ " of argument of aggregation state multiply. Should be integer constant", ErrorCodes::ILLEGAL_COLUMN};
const ColumnAggregateFunction * column = typeid_cast<const ColumnAggregateFunction *>(block.getByPosition(new_arguments[0]).column.get());
IAggregateFunction * function = column->getAggregateFunction().get();
auto arena = std::make_shared<Arena>();
auto column_to = ColumnAggregateFunction::create(column->getAggregateFunction(), Arenas(1, arena));
column_to->reserve(input_rows_count);
auto column_from = ColumnAggregateFunction::create(column->getAggregateFunction(), Arenas(1, arena));
column_from->reserve(input_rows_count);
for (size_t i = 0; i < input_rows_count; ++i)
{
column_to->insertDefault();
column_from->insertFrom(column->getData()[i]);
}
auto & vec_to = column_to->getData();
auto & vec_from = column_from->getData();
UInt64 m = typeid_cast<const ColumnConst *>(block.getByPosition(new_arguments[1]).column.get())->getValue<UInt64>();
/// We use exponentiation by squaring algorithm to perform multiplying aggregate states by N in O(log(N)) operations
/// https://en.wikipedia.org/wiki/Exponentiation_by_squaring
while (m)
{
if (m % 2)
{
for (size_t i = 0; i < input_rows_count; ++i)
function->merge(vec_to[i], vec_from[i], arena.get());
--m;
}
else
{
for (size_t i = 0; i < input_rows_count; ++i)
function->merge(vec_from[i], vec_from[i], arena.get());
m /= 2;
}
}
block.getByPosition(result).column = std::move(column_to);
}
/// Merge two aggregation states together.
void executeAggregateAddition(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) const
{
const ColumnAggregateFunction * columns[2];
for (size_t i = 0; i < 2; ++i)
columns[i] = typeid_cast<const ColumnAggregateFunction *>(block.getByPosition(arguments[i]).column.get());
auto column_to = ColumnAggregateFunction::create(columns[0]->getAggregateFunction());
column_to->reserve(input_rows_count);
for(size_t i = 0; i < input_rows_count; ++i)
{
column_to->insertFrom(columns[0]->getData()[i]);
column_to->insertMergeFrom(columns[1]->getData()[i]);
}
block.getByPosition(result).column = std::move(column_to);
}
void executeDateTimeIntervalPlusMinus(Block & block, const ColumnNumbers & arguments,
size_t result, size_t input_rows_count, const FunctionBuilderPtr & function_builder) const
{
ColumnNumbers new_arguments = arguments;
/// Interval argument must be second.
if (WhichDataType(block.getByPosition(arguments[0]).type).isInterval())
std::swap(new_arguments[0], new_arguments[1]);
/// Change interval argument type to its representation
Block new_block = block;
new_block.getByPosition(new_arguments[1]).type = std::make_shared<DataTypeNumber<DataTypeInterval::FieldType>>();
ColumnsWithTypeAndName new_arguments_with_type_and_name =
{new_block.getByPosition(new_arguments[0]), new_block.getByPosition(new_arguments[1])};
auto function = function_builder->build(new_arguments_with_type_and_name);
function->execute(new_block, new_arguments, result, input_rows_count);
block.getByPosition(result).column = new_block.getByPosition(result).column;
}
public:
static constexpr auto name = Name::name;
static FunctionPtr create(const Context & context) { return std::make_shared<FunctionBinaryArithmetic>(context); }
@ -1151,11 +1269,21 @@ public:
/// Special case when multiply aggregate function state
if (isAggregateMultiply(arguments[0], arguments[1]))
{
if (checkDataType<DataTypeAggregateFunction>(arguments[0].get()))
if (WhichDataType(arguments[0]).isAggregateFunction())
return arguments[0];
return arguments[1];
}
/// Special case - addition of two aggregate functions states
if (isAggregateAddition(arguments[0], arguments[1]))
{
if (!arguments[0]->equals(*arguments[1]))
throw Exception("Cannot add aggregate states of different functions: "
+ arguments[0]->getName() + " and " + arguments[1]->getName(), ErrorCodes::CANNOT_ADD_DIFFERENT_AGGREGATE_STATES);
return arguments[0];
}
/// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval.
if (auto function_builder = getFunctionForIntervalArithmetic(arguments[0], arguments[1]))
{
@ -1165,7 +1293,7 @@ public:
new_arguments[i].type = arguments[i];
/// Interval argument must be second.
if (checkDataType<DataTypeInterval>(new_arguments[0].type.get()))
if (WhichDataType(new_arguments[0].type).isInterval())
std::swap(new_arguments[0], new_arguments[1]);
/// Change interval argument to its representation
@ -1206,92 +1334,26 @@ public:
return type_res;
}
bool isAggregateMultiply(const DataTypePtr & type0, const DataTypePtr & type1) const
{
if constexpr (!std::is_same_v<Op<UInt8, UInt8>, MultiplyImpl<UInt8, UInt8>>)
return false;
auto is_uint_type = [](const DataTypePtr & type)
{
return checkDataType<DataTypeUInt8>(type.get()) || checkDataType<DataTypeUInt16>(type.get())
|| checkDataType<DataTypeUInt32>(type.get()) || checkDataType<DataTypeUInt64>(type.get());
};
return ((checkDataType<DataTypeAggregateFunction>(type0.get()) && is_uint_type(type1))
|| (is_uint_type(type0) && checkDataType<DataTypeAggregateFunction>(type1.get())));
}
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
{
/// Special case when multiply aggregate function state
if (isAggregateMultiply(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type))
{
ColumnNumbers new_arguments = arguments;
if (checkDataType<DataTypeAggregateFunction>(block.getByPosition(new_arguments[1]).type.get()))
std::swap(new_arguments[0], new_arguments[1]);
const ColumnAggregateFunction * column = typeid_cast<const ColumnAggregateFunction *>(block.getByPosition(new_arguments[0]).column.get());
IAggregateFunction * function = column->getAggregateFunction().get();
auto arena = std::make_shared<Arena>();
auto column_to = ColumnAggregateFunction::create(column->getAggregateFunction(), Arenas(1, arena));
column_to->reserve(input_rows_count);
auto column_from = ColumnAggregateFunction::create(column->getAggregateFunction(), Arenas(1, arena));
column_from->reserve(input_rows_count);
for (size_t i = 0; i < input_rows_count; ++i)
{
column_to->insertDefault();
column_from->insertFrom(column->getData()[i]);
executeAggregateMultiply(block, arguments, result, input_rows_count);
return;
}
auto & vec_to = column_to->getData();
auto & vec_from = column_from->getData();
UInt64 m = block.getByPosition(new_arguments[1]).column->getUInt(0);
/// We use exponentiation by squaring algorithm to perform multiplying aggregate states by N in O(log(N)) operations
/// https://en.wikipedia.org/wiki/Exponentiation_by_squaring
while (m)
/// Special case - addition of two aggregate functions states
if (isAggregateAddition(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type))
{
if (m % 2)
{
for (size_t i = 0; i < input_rows_count; ++i)
function->merge(vec_to[i], vec_from[i], arena.get());
--m;
}
else
{
for (size_t i = 0; i < input_rows_count; ++i)
function->merge(vec_from[i], vec_from[i], arena.get());
m /= 2;
}
}
block.getByPosition(result).column = std::move(column_to);
executeAggregateAddition(block, arguments, result, input_rows_count);
return;
}
/// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval.
if (auto function_builder = getFunctionForIntervalArithmetic(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type))
{
ColumnNumbers new_arguments = arguments;
/// Interval argument must be second.
if (checkDataType<DataTypeInterval>(block.getByPosition(arguments[0]).type.get()))
std::swap(new_arguments[0], new_arguments[1]);
/// Change interval argument type to its representation
Block new_block = block;
new_block.getByPosition(new_arguments[1]).type = std::make_shared<DataTypeNumber<DataTypeInterval::FieldType>>();
ColumnsWithTypeAndName new_arguments_with_type_and_name =
{new_block.getByPosition(new_arguments[0]), new_block.getByPosition(new_arguments[1])};
auto function = function_builder->build(new_arguments_with_type_and_name);
function->execute(new_block, new_arguments, result, input_rows_count);
block.getByPosition(result).column = new_block.getByPosition(result).column;
executeDateTimeIntervalPlusMinus(block, arguments, result, input_rows_count, function_builder);
return;
}
@ -1906,17 +1968,17 @@ public:
throw Exception{"Number of arguments for function " + getName() + " doesn't match: passed "
+ toString(arguments.size()) + ", should be at least 2.", ErrorCodes::TOO_LESS_ARGUMENTS_FOR_FUNCTION};
const auto first_arg = arguments.front().get();
const auto & first_arg = arguments.front();
if (!first_arg->isInteger())
if (!isInteger(first_arg))
throw Exception{"Illegal type " + first_arg->getName() + " of first argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
for (const auto i : ext::range(1, arguments.size()))
{
const auto pos_arg = arguments[i].get();
const auto & pos_arg = arguments[i];
if (!pos_arg->isUnsignedInteger())
if (!isUnsignedInteger(pos_arg))
throw Exception{"Illegal type " + pos_arg->getName() + " of " + toString(i) + " argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}

File diff suppressed because it is too large Load Diff

View File

@ -150,12 +150,12 @@ public:
", expected FixedString(" + toString(ipv6_bytes_length) + ")",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!checkDataType<DataTypeUInt8>(arguments[1].get()))
if (!WhichDataType(arguments[1]).isUInt8())
throw Exception("Illegal type " + arguments[1]->getName() +
" of argument 2 of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!checkDataType<DataTypeUInt8>(arguments[2].get()))
if (!WhichDataType(arguments[2]).isUInt8())
throw Exception("Illegal type " + arguments[2]->getName() +
" of argument 3 of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -266,7 +266,7 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!arguments[0]->isString())
if (!isString(arguments[0]))
throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -519,7 +519,7 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!checkDataType<DataTypeUInt32>(arguments[0].get()))
if (!WhichDataType(arguments[0]).isUInt32())
throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName() + ", expected UInt32",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -579,7 +579,7 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!arguments[0]->isString())
if (!isString(arguments[0]))
throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -714,7 +714,7 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!checkDataType<DataTypeUInt64>(arguments[0].get()))
if (!WhichDataType(arguments[0]).isUInt64())
throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName() + ", expected UInt64",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -843,7 +843,7 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!arguments[0]->isString())
if (!isString(arguments[0]))
throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -1006,7 +1006,7 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
/// String or FixedString(36)
if (!arguments[0]->isString())
if (!isString(arguments[0]))
{
const auto ptr = checkAndGetDataType<DataTypeFixedString>(arguments[0].get());
if (!ptr || ptr->getN() != uuid_text_length)
@ -1151,13 +1151,11 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!arguments[0]->isString()
&& !arguments[0]->isFixedString()
&& !arguments[0]->isDateOrDateTime()
&& !checkDataType<DataTypeUInt8>(arguments[0].get())
&& !checkDataType<DataTypeUInt16>(arguments[0].get())
&& !checkDataType<DataTypeUInt32>(arguments[0].get())
&& !checkDataType<DataTypeUInt64>(arguments[0].get()))
WhichDataType which(arguments[0]);
if (!which.isStringOrFixedString()
&& !which.isDateOrDateTime()
&& !which.isUInt())
throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -1370,7 +1368,7 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!arguments[0]->isString())
if (!isString(arguments[0]))
throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -1460,7 +1458,7 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!arguments[0]->isInteger())
if (!isInteger(arguments[0]))
throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -1543,7 +1541,7 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!arguments[0]->isStringOrFixedString())
if (!isStringOrFixedString(arguments[0]))
throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

View File

@ -817,7 +817,7 @@ private:
const IColumn * column_number = left_is_num ? col_left_untyped : col_right_untyped;
const IDataType * number_type = left_is_num ? left_type.get() : right_type.get();
DataTypeExtractor which(number_type);
WhichDataType which(number_type);
const bool legal_types = which.isDateOrDateTime() || which.isEnum() || which.isUUID();
@ -1077,8 +1077,8 @@ public:
/// Get result types by argument types. If the function does not apply to these arguments, throw an exception.
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
DataTypeExtractor left(arguments[0].get());
DataTypeExtractor right(arguments[1].get());
WhichDataType left(arguments[0].get());
WhichDataType right(arguments[1].get());
const DataTypeTuple * left_tuple = checkAndGetDataType<DataTypeTuple>(arguments[0].get());
const DataTypeTuple * right_tuple = checkAndGetDataType<DataTypeTuple>(arguments[1].get());
@ -1159,9 +1159,9 @@ public:
{
executeTuple(block, result, col_with_type_and_name_left, col_with_type_and_name_right, input_rows_count);
}
else if (isDecimal(left_type.get()) || isDecimal(right_type.get()))
else if (isDecimal(left_type) || isDecimal(right_type))
{
if (!allowDecimalComparison(left_type.get(), right_type.get()))
if (!allowDecimalComparison(left_type, right_type))
throw Exception("No operation " + getName() + " between " + left_type->getName() + " and " + right_type->getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -1193,7 +1193,7 @@ public:
auto isFloatingPoint = &typeIsEither<DataTypeFloat32, DataTypeFloat64>;
if ((isBigInteger(*types[0]) && isFloatingPoint(*types[1])) || (isBigInteger(*types[1]) && isFloatingPoint(*types[0])))
return false; /// TODO: implement (double, int_N where N > double's mantissa width)
return isCompilableType(types[0].get()) && isCompilableType(types[1].get());
return isCompilableType(types[0]) && isCompilableType(types[1]);
}
llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const DataTypes & types, ValuePlaceholders values) const override

View File

@ -1,8 +1,9 @@
#include <Functions/FunctionsConditional.h>
#include <Functions/FunctionsArray.h>
#include <Functions/FunctionsTransform.h>
#include <Functions/FunctionFactory.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnConst.h>
#include <DataTypes/getLeastSupertype.h>
#include <Interpreters/castColumn.h>
#include <vector>
@ -205,7 +206,7 @@ DataTypePtr FunctionMultiIf::getReturnTypeImpl(const DataTypes & args) const
nested_type = arg.get();
}
if (!checkDataType<DataTypeUInt8>(nested_type))
if (!WhichDataType(nested_type).isUInt8())
throw Exception{"Illegal type " + arg->getName() + " of argument (condition) "
"of function " + getName() + ". Must be UInt8.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
@ -251,22 +252,15 @@ DataTypePtr FunctionCaseWithExpression::getReturnTypeImpl(const DataTypes & args
/// See the comments in executeImpl() to understand why we actually have to
/// get the return type of a transform function.
/// Get the return types of the arrays that we pass to the transform function.
ColumnsWithTypeAndName src_array_types;
ColumnsWithTypeAndName dst_array_types;
/// Get the types of the arrays that we pass to the transform function.
DataTypes src_array_types;
DataTypes dst_array_types;
for (size_t i = 1; i < (args.size() - 1); ++i)
{
if ((i % 2) != 0)
src_array_types.push_back({nullptr, args[i], {}});
else
dst_array_types.push_back({nullptr, args[i], {}});
}
for (size_t i = 1; i < args.size() - 1; ++i)
((i % 2) ? src_array_types : dst_array_types).push_back(args[i]);
FunctionArray fun_array{context};
DataTypePtr src_array_type = fun_array.getReturnType(src_array_types);
DataTypePtr dst_array_type = fun_array.getReturnType(dst_array_types);
DataTypePtr src_array_type = std::make_shared<DataTypeArray>(getLeastSupertype(src_array_types));
DataTypePtr dst_array_type = std::make_shared<DataTypeArray>(getLeastSupertype(dst_array_types));
/// Finally get the return type of the transform function.
FunctionTransform fun_transform;
@ -291,29 +285,31 @@ void FunctionCaseWithExpression::executeImpl(Block & block, const ColumnNumbers
/// Create the arrays required by the transform function.
ColumnNumbers src_array_args;
ColumnsWithTypeAndName src_array_types;
ColumnsWithTypeAndName src_array_elems;
DataTypes src_array_types;
ColumnNumbers dst_array_args;
ColumnsWithTypeAndName dst_array_types;
ColumnsWithTypeAndName dst_array_elems;
DataTypes dst_array_types;
for (size_t i = 1; i < (args.size() - 1); ++i)
{
if ((i % 2) != 0)
if (i % 2)
{
src_array_args.push_back(args[i]);
src_array_types.push_back(block.getByPosition(args[i]));
src_array_elems.push_back(block.getByPosition(args[i]));
src_array_types.push_back(block.getByPosition(args[i]).type);
}
else
{
dst_array_args.push_back(args[i]);
dst_array_types.push_back(block.getByPosition(args[i]));
dst_array_elems.push_back(block.getByPosition(args[i]));
dst_array_types.push_back(block.getByPosition(args[i]).type);
}
}
FunctionArray fun_array{context};
DataTypePtr src_array_type = fun_array.getReturnType(src_array_types);
DataTypePtr dst_array_type = fun_array.getReturnType(dst_array_types);
DataTypePtr src_array_type = std::make_shared<DataTypeArray>(getLeastSupertype(src_array_types));
DataTypePtr dst_array_type = std::make_shared<DataTypeArray>(getLeastSupertype(dst_array_types));
Block temp_block = block;
@ -323,8 +319,10 @@ void FunctionCaseWithExpression::executeImpl(Block & block, const ColumnNumbers
size_t dst_array_pos = temp_block.columns();
temp_block.insert({nullptr, dst_array_type, ""});
fun_array.execute(temp_block, src_array_args, src_array_pos, input_rows_count);
fun_array.execute(temp_block, dst_array_args, dst_array_pos, input_rows_count);
auto fun_array = FunctionFactory::instance().get("array", context);
fun_array->build(src_array_elems)->execute(temp_block, src_array_args, src_array_pos, input_rows_count);
fun_array->build(dst_array_elems)->execute(temp_block, dst_array_args, dst_array_pos, input_rows_count);
/// Execute transform.
FunctionTransform fun_transform;

View File

@ -122,7 +122,7 @@ public:
bool isCompilableImpl(const DataTypes & types) const override
{
for (const auto & type : types)
if (!isCompilableType(removeNullable(type).get()))
if (!isCompilableType(removeNullable(type)))
return false;
return true;
}
@ -895,7 +895,7 @@ public:
return makeNullable(getReturnTypeImpl({
removeNullable(arguments[0]), arguments[1], arguments[2]}));
if (!checkDataType<DataTypeUInt8>(arguments[0].get()))
if (!WhichDataType(arguments[0]).isUInt8())
throw Exception("Illegal type " + arguments[0]->getName() + " of first argument (condition) of function if. Must be UInt8.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

View File

@ -106,7 +106,7 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!arguments[0]->isInteger())
if (!isInteger(arguments[0]))
throw Exception("Illegal type " + arguments[0]->getName() + " of the first argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -115,7 +115,7 @@ public:
+ ", got " + arguments[0]->getName(),
ErrorCodes::BAD_ARGUMENTS);
if (!arguments[1]->isInteger())
if (!isInteger(arguments[1]))
throw Exception("Illegal type " + arguments[1]->getName() + " of the second argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -178,21 +178,23 @@ private:
const IDataType * hash_type = block.getByPosition(arguments[0]).type.get();
auto res_col = ColumnVector<ResultType>::create();
if (checkDataType<DataTypeUInt8>(hash_type))
WhichDataType which(hash_type);
if (which.isUInt8())
executeType<UInt8>(hash_col, num_buckets, res_col.get());
else if (checkDataType<DataTypeUInt16>(hash_type))
else if (which.isUInt16())
executeType<UInt16>(hash_col, num_buckets, res_col.get());
else if (checkDataType<DataTypeUInt32>(hash_type))
else if (which.isUInt32())
executeType<UInt32>(hash_col, num_buckets, res_col.get());
else if (checkDataType<DataTypeUInt64>(hash_type))
else if (which.isUInt64())
executeType<UInt64>(hash_col, num_buckets, res_col.get());
else if (checkDataType<DataTypeInt8>(hash_type))
else if (which.isInt8())
executeType<Int8>(hash_col, num_buckets, res_col.get());
else if (checkDataType<DataTypeInt16>(hash_type))
else if (which.isInt16())
executeType<Int16>(hash_col, num_buckets, res_col.get());
else if (checkDataType<DataTypeInt32>(hash_type))
else if (which.isInt32())
executeType<Int32>(hash_col, num_buckets, res_col.get());
else if (checkDataType<DataTypeInt64>(hash_type))
else if (which.isInt64())
executeType<Int64>(hash_col, num_buckets, res_col.get());
else
throw Exception("Illegal type " + hash_type->getName() + " of the first argument of function " + getName(),

View File

@ -20,7 +20,7 @@ void throwExceptionForIncompletelyParsedValue(
else
message_buf << " at begin of string";
if (to_type.isNumber())
if (isNumber(to_type))
message_buf << ". Note: there are to" << to_type.getName() << "OrZero and to" << to_type.getName() << "OrNull functions, which returns zero/NULL instead of throwing exception.";
throw Exception(message_buf.str(), ErrorCodes::CANNOT_PARSE_TEXT);

View File

@ -802,7 +802,7 @@ public:
|| std::is_same_v<Name, NameToUnixTimestamp>;
if (!(to_date_or_time
|| (std::is_same_v<Name, NameToString> && checkDataType<DataTypeDateTime>(arguments[0].type.get()))))
|| (std::is_same_v<Name, NameToString> && WhichDataType(arguments[0].type).isDateTime())))
{
throw Exception("Number of arguments for function " + getName() + " doesn't match: passed "
+ toString(arguments.size()) + ", should be 1.",
@ -950,7 +950,7 @@ public:
+ toString(arguments.size()) + ", should be 1 or 2. Second argument (time zone) is optional only make sense for DateTime.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (!arguments[0].type->isStringOrFixedString())
if (!isStringOrFixedString(arguments[0].type))
throw Exception("Illegal type " + arguments[0].type->getName() + " of first argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -963,7 +963,7 @@ public:
+ toString(arguments.size()) + ", should be 1. Second argument makes sense only when converting to DateTime.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (!arguments[1].type->isString())
if (!isString(arguments[1].type))
throw Exception("Illegal type " + arguments[1].type->getName() + " of 2nd argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
@ -1020,11 +1020,11 @@ public:
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
if (!arguments[1].type->isUnsignedInteger())
if (!isUnsignedInteger(arguments[1].type))
throw Exception("Second argument for function " + getName() + " must be unsigned integer", ErrorCodes::ILLEGAL_COLUMN);
if (!arguments[1].column)
throw Exception("Second argument for function " + getName() + " must be constant", ErrorCodes::ILLEGAL_COLUMN);
if (!arguments[0].type->isStringOrFixedString())
if (!isStringOrFixedString(arguments[0].type))
throw Exception(getName() + " is only implemented for types String and FixedString", ErrorCodes::NOT_IMPLEMENTED);
const size_t n = arguments[1].column->getUInt(0);
@ -1140,8 +1140,8 @@ struct ToIntMonotonicity
}
/// If type is same, too. (Enum has separate case, because it is different data type)
if (checkDataType<DataTypeNumber<T>>(&type) ||
checkDataType<DataTypeEnum<T>>(&type))
if (checkAndGetDataType<DataTypeNumber<T>>(&type) ||
checkAndGetDataType<DataTypeEnum<T>>(&type))
return { true, true, true };
/// In other cases, if range is unbounded, we don't know, whether function is monotonic or not.
@ -1149,8 +1149,7 @@ struct ToIntMonotonicity
return {};
/// If converting from float, for monotonicity, arguments must fit in range of result type.
if (checkDataType<DataTypeFloat32>(&type)
|| checkDataType<DataTypeFloat64>(&type))
if (WhichDataType(type).isFloat())
{
Float64 left_float = left.get<Float64>();
Float64 right_float = right.get<Float64>();
@ -1460,7 +1459,7 @@ private:
static WrapperType createFixedStringWrapper(const DataTypePtr & from_type, const size_t N)
{
if (!from_type->isStringOrFixedString())
if (!isStringOrFixedString(from_type))
throw Exception{"CAST AS FixedString is only implemented for types String and FixedString", ErrorCodes::NOT_IMPLEMENTED};
return [N] (Block & block, const ColumnNumbers & arguments, const size_t result, size_t /*input_rows_count*/)
@ -1469,6 +1468,24 @@ private:
};
}
WrapperType createUUIDWrapper(const DataTypePtr & from_type, const DataTypeUUID * const, bool requested_result_is_nullable) const
{
if (requested_result_is_nullable)
throw Exception{"CAST AS Nullable(UUID) is not implemented", ErrorCodes::NOT_IMPLEMENTED};
FunctionPtr function = FunctionTo<DataTypeUUID>::Type::create(context);
/// Check conversion using underlying function
{
function->getReturnType(ColumnsWithTypeAndName(1, { nullptr, from_type, "" }));
}
return [function] (Block & block, const ColumnNumbers & arguments, const size_t result, size_t input_rows_count)
{
function->execute(block, arguments, result, input_rows_count);
};
}
template <typename FieldType>
WrapperType createDecimalWrapper(const DataTypePtr & from_type, const DataTypeDecimal<FieldType> * to_type) const
{
@ -1628,7 +1645,7 @@ private:
return createStringToEnumWrapper<ColumnString, EnumType>();
else if (checkAndGetDataType<DataTypeFixedString>(from_type.get()))
return createStringToEnumWrapper<ColumnFixedString, EnumType>();
else if (from_type->isNumber() || from_type->isEnum())
else if (isNumber(from_type) || isEnum(from_type))
{
auto function = Function::create(context);
@ -1878,7 +1895,7 @@ private:
{
if (from_type->equals(*to_type))
return createIdentityWrapper(from_type);
else if (checkDataType<DataTypeNothing>(from_type.get()))
else if (WhichDataType(from_type).isNothing())
return createNothingWrapper(to_type.get());
WrapperType ret;
@ -1920,6 +1937,14 @@ private:
ret = createDecimalWrapper(from_type, checkAndGetDataType<ToDataType>(to_type.get()));
return true;
}
if constexpr (std::is_same_v<ToDataType, DataTypeUUID>)
{
if (isStringOrFixedString(from_type))
{
ret = createUUIDWrapper(from_type, checkAndGetDataType<ToDataType>(to_type.get()), requested_result_is_nullable);
return true;
}
}
return false;
};
@ -2027,7 +2052,7 @@ private:
return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeString>(to_type))
return monotonicityForType(type);
else if (from_type->isEnum())
else if (isEnum(from_type))
{
if (const auto type = checkAndGetDataType<DataTypeEnum8>(to_type))
return monotonicityForType(type);

View File

@ -637,14 +637,14 @@ public:
{
if (arguments.size() == 1)
{
if (!arguments[0].type->isDateOrDateTime())
if (!isDateOrDateTime(arguments[0].type))
throw Exception("Illegal type " + arguments[0].type->getName() + " of argument of function " + getName() +
". Should be a date or a date with time", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
else if (arguments.size() == 2)
{
if (!checkDataType<DataTypeDateTime>(arguments[0].type.get())
|| !checkDataType<DataTypeString>(arguments[1].type.get()))
if (!WhichDataType(arguments[0].type).isDateTime()
|| !WhichDataType(arguments[1].type).isString())
throw Exception(
"Function " + getName() + " supports 1 or 2 arguments. The 1st argument "
"must be of type Date or DateTime. The 2nd argument (optional) must be "
@ -670,10 +670,11 @@ public:
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
{
const IDataType * from_type = block.getByPosition(arguments[0]).type.get();
WhichDataType which(from_type);
if (checkDataType<DataTypeDate>(from_type))
if (which.isDate())
DateTimeTransformImpl<DataTypeDate::FieldType, typename ToDataType::FieldType, Transform>::execute(block, arguments, result, input_rows_count);
else if (checkDataType<DataTypeDateTime>(from_type))
else if (which.isDateTime())
DateTimeTransformImpl<DataTypeDateTime::FieldType, typename ToDataType::FieldType, Transform>::execute(block, arguments, result, input_rows_count);
else
throw Exception("Illegal type " + block.getByPosition(arguments[0]).type->getName() + " of argument of function " + getName(),
@ -945,20 +946,20 @@ public:
+ toString(arguments.size()) + ", should be 2 or 3",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (!arguments[1].type->isNumber())
if (!isNumber(arguments[1].type))
throw Exception("Second argument for function " + getName() + " (delta) must be number",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (arguments.size() == 2)
{
if (!arguments[0].type->isDateOrDateTime())
if (!isDateOrDateTime(arguments[0].type))
throw Exception{"Illegal type " + arguments[0].type->getName() + " of argument of function " + getName() +
". Should be a date or a date with time", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
else
{
if (!checkDataType<DataTypeDateTime>(arguments[0].type.get())
|| !checkDataType<DataTypeString>(arguments[2].type.get()))
if (!WhichDataType(arguments[0].type).isDateTime()
|| !WhichDataType(arguments[2].type).isString())
throw Exception(
"Function " + getName() + " supports 2 or 3 arguments. The 1st argument "
"must be of type Date or DateTime. The 2nd argument must be number. "
@ -968,7 +969,7 @@ public:
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
if (checkDataType<DataTypeDate>(arguments[0].type.get()))
if (WhichDataType(arguments[0].type).isDate())
{
if (std::is_same_v<decltype(Transform::execute(DataTypeDate::FieldType(), 0, std::declval<DateLUTImpl>())), UInt16>)
return std::make_shared<DataTypeDate>();
@ -990,10 +991,11 @@ public:
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
{
const IDataType * from_type = block.getByPosition(arguments[0]).type.get();
WhichDataType which(from_type);
if (checkDataType<DataTypeDate>(from_type))
if (which.isDate())
DateTimeAddIntervalImpl<DataTypeDate::FieldType, Transform>::execute(block, arguments, result);
else if (checkDataType<DataTypeDateTime>(from_type))
else if (which.isDateTime())
DateTimeAddIntervalImpl<DataTypeDateTime::FieldType, Transform>::execute(block, arguments, result);
else
throw Exception("Illegal type " + block.getByPosition(arguments[0]).type->getName() + " of argument of function " + getName(),
@ -1032,19 +1034,19 @@ public:
+ toString(arguments.size()) + ", should be 3 or 4",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (!arguments[0]->isString())
if (!isString(arguments[0]))
throw Exception("First argument for function " + getName() + " (unit) must be String",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!arguments[1]->isDateOrDateTime())
if (!isDateOrDateTime(arguments[1]))
throw Exception("Second argument for function " + getName() + " must be Date or DateTime",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!arguments[2]->isDateOrDateTime())
if (!isDateOrDateTime(arguments[2]))
throw Exception("Third argument for function " + getName() + " must be Date or DateTime",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (arguments.size() == 4 && !arguments[3]->isString())
if (arguments.size() == 4 && !isString(arguments[3]))
throw Exception("Fourth argument for function " + getName() + " (timezone) must be String",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -1296,7 +1298,7 @@ public:
+ toString(arguments.size()) + ", should be 2",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (!checkDataType<DataTypeDateTime>(arguments[0].type.get()))
if (!WhichDataType(arguments[0].type).isDateTime())
throw Exception{"Illegal type " + arguments[0].type->getName() + " of argument of function " + getName() +
". Should be DateTime", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
@ -1326,7 +1328,7 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!checkDataType<DataTypeDateTime>(arguments[0].get()))
if (!WhichDataType(arguments[0]).isDateTime())
throw Exception("Illegal type " + arguments[0]->getName() + " of first argument of function " + getName() + ". Must be DateTime.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -1453,11 +1455,11 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!checkDataType<DataTypeDateTime>(arguments[0].get()))
if (!WhichDataType(arguments[0]).isDateTime())
throw Exception("Illegal type " + arguments[0]->getName() + " of first argument of function " + getName() + ". Must be DateTime.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!checkDataType<DataTypeUInt32>(arguments[1].get()))
if (!WhichDataType(arguments[1]).isUInt32())
throw Exception("Illegal type " + arguments[1]->getName() + " of second argument of function " + getName() + ". Must be UInt32.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

View File

@ -82,12 +82,12 @@ private:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!arguments[0]->isString())
if (!isString(arguments[0]))
throw Exception{"Illegal type " + arguments[0]->getName() + " of first argument of function " + getName()
+ ", expected a string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!checkDataType<DataTypeUInt64>(arguments[1].get()) &&
!checkDataType<DataTypeTuple>(arguments[1].get()))
if (!WhichDataType(arguments[1]).isUInt64() &&
!isTuple(arguments[1]))
throw Exception{"Illegal type " + arguments[1]->getName() + " of second argument of function " + getName()
+ ", must be UInt64 or tuple(...).", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
@ -230,27 +230,27 @@ private:
throw Exception{"Number of arguments for function " + getName() + " doesn't match: passed "
+ toString(arguments.size()) + ", should be 3 or 4.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
if (!arguments[0]->isString())
if (!isString(arguments[0]))
{
throw Exception{"Illegal type " + arguments[0]->getName() + " of first argument of function " + getName()
+ ", expected a string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
if (!arguments[1]->isString())
if (!isString(arguments[1]))
{
throw Exception{"Illegal type " + arguments[1]->getName() + " of second argument of function " + getName()
+ ", expected a string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
if (!checkDataType<DataTypeUInt64>(arguments[2].get()) &&
!checkDataType<DataTypeTuple>(arguments[2].get()))
if (!WhichDataType(arguments[2]).isUInt64() &&
!isTuple(arguments[2]))
{
throw Exception{"Illegal type " + arguments[2]->getName() + " of third argument of function " + getName()
+ ", must be UInt64 or tuple(...).", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
/// This is for the case of range dictionaries.
if (arguments.size() == 4 && !checkDataType<DataTypeDate>(arguments[3].get()))
if (arguments.size() == 4 && !WhichDataType(arguments[3]).isDate())
{
throw Exception{"Illegal type " + arguments[3]->getName() + " of fourth argument of function " + getName()
+ ", must be Date.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
@ -467,22 +467,22 @@ private:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!arguments[0]->isString())
if (!isString(arguments[0]))
throw Exception{"Illegal type " + arguments[0]->getName() + " of first argument of function " + getName() +
", expected a string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!arguments[1]->isString())
if (!isString(arguments[1]))
throw Exception{"Illegal type " + arguments[1]->getName() + " of second argument of function " + getName() +
", expected a string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!checkDataType<DataTypeUInt64>(arguments[2].get()) &&
!checkDataType<DataTypeTuple>(arguments[2].get()))
if (!WhichDataType(arguments[2]).isUInt64() &&
!isTuple(arguments[2]))
{
throw Exception{"Illegal type " + arguments[2]->getName() + " of third argument of function " + getName()
+ ", must be UInt64 or tuple(...).", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
if (!arguments[3]->isString())
if (!isString(arguments[3]))
throw Exception{"Illegal type " + arguments[3]->getName() + " of fourth argument of function " + getName() +
", must be String.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
@ -735,20 +735,20 @@ private:
if (arguments.size() != 3 && arguments.size() != 4)
throw Exception{"Function " + getName() + " takes 3 or 4 arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
if (!arguments[0]->isString())
if (!isString(arguments[0]))
throw Exception{"Illegal type " + arguments[0]->getName() + " of first argument of function " + getName()
+ ", expected a string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!arguments[1]->isString())
if (!isString(arguments[1]))
throw Exception{"Illegal type " + arguments[1]->getName() + " of second argument of function " + getName()
+ ", expected a string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!checkDataType<DataTypeUInt64>(arguments[2].get()) &&
!checkDataType<DataTypeTuple>(arguments[2].get()))
if (!WhichDataType(arguments[2]).isUInt64() &&
!isTuple(arguments[2]))
throw Exception{"Illegal type " + arguments[2]->getName() + " of third argument of function " + getName()
+ ", must be UInt64 or tuple(...).", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (arguments.size() == 4 && !checkDataType<DataTypeDate>(arguments[3].get()))
if (arguments.size() == 4 && !WhichDataType(arguments[3]).isDate())
throw Exception{"Illegal type " + arguments[3]->getName() + " of fourth argument of function " + getName()
+ ", must be Date.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
@ -1010,20 +1010,20 @@ private:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!arguments[0]->isString())
if (!isString(arguments[0]))
throw Exception{"Illegal type " + arguments[0]->getName() + " of first argument of function " + getName()
+ ", expected a string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!arguments[1]->isString())
if (!isString(arguments[1]))
throw Exception{"Illegal type " + arguments[1]->getName() + " of second argument of function " + getName()
+ ", expected a string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!checkDataType<DataTypeUInt64>(arguments[2].get()) &&
!checkDataType<DataTypeTuple>(arguments[2].get()))
if (!WhichDataType(arguments[2]).isUInt64() &&
!isTuple(arguments[2]))
throw Exception{"Illegal type " + arguments[2]->getName() + " of third argument of function " + getName()
+ ", must be UInt64 or tuple(...).", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!checkDataType<DataType>(arguments[3].get()))
if (!checkAndGetDataType<DataType>(arguments[3].get()))
throw Exception{"Illegal type " + arguments[3]->getName() + " of fourth argument of function " + getName()
+ ", must be " + String(DataType{}.getFamilyName()) + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
@ -1252,11 +1252,11 @@ private:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!arguments[0]->isString())
if (!isString(arguments[0]))
throw Exception{"Illegal type " + arguments[0]->getName() + " of first argument of function " + getName()
+ ", expected a string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!checkDataType<DataTypeUInt64>(arguments[1].get()))
if (!WhichDataType(arguments[1]).isUInt64())
throw Exception{"Illegal type " + arguments[1]->getName() + " of second argument of function " + getName()
+ ", must be UInt64.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
@ -1408,15 +1408,15 @@ private:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!arguments[0]->isString())
if (!isString(arguments[0]))
throw Exception{"Illegal type " + arguments[0]->getName() + " of first argument of function " + getName()
+ ", expected a string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!checkDataType<DataTypeUInt64>(arguments[1].get()))
if (!WhichDataType(arguments[1]).isUInt64())
throw Exception{"Illegal type " + arguments[1]->getName() + " of second argument of function " + getName()
+ ", must be UInt64.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!checkDataType<DataTypeUInt64>(arguments[2].get()))
if (!WhichDataType(arguments[2]).isUInt64())
throw Exception{"Illegal type " + arguments[2]->getName() + " of third argument of function " + getName()
+ ", must be UInt64.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};

View File

@ -32,7 +32,7 @@ DataTypePtr FunctionModelEvaluate::getReturnTypeImpl(const DataTypes & arguments
throw Exception("Function " + getName() + " expects at least 2 arguments",
ErrorCodes::TOO_LESS_ARGUMENTS_FOR_FUNCTION);
if (!arguments[0]->isString())
if (!isString(arguments[0]))
throw Exception("Illegal type " + arguments[0]->getName() + " of first argument of function " + getName()
+ ", expected a string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

View File

@ -116,7 +116,7 @@ public:
const auto type_x = arguments[0];
if (!type_x->isNumber())
if (!isNumber(type_x))
throw Exception{"Unsupported type " + type_x->getName() + " of first argument of function " + getName() + " must be a numeric type",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};

View File

@ -44,9 +44,9 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
const IDataType * type = arguments[0].get();
const DataTypePtr & type = arguments[0];
if (!type->isInteger())
if (!isInteger(type))
throw Exception("Cannot format " + type->getName() + " as bitmask string", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<DataTypeString>();
@ -139,7 +139,7 @@ public:
{
const IDataType & type = *arguments[0];
if (!type.isNumber())
if (!isNumber(type))
throw Exception("Cannot format " + type.getName() + " as size in bytes", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<DataTypeString>();

View File

@ -133,7 +133,7 @@ public:
for (auto j : ext::range(0, elements.size()))
{
if (!elements[j]->isNumber())
if (!isNumber(elements[j]))
{
throw Exception(getMsgPrefix(i) + " must contains numeric tuple at position " + toString(j + 1),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -162,7 +162,7 @@ public:
const Columns & tuple_columns = tuple_col->getColumns();
const DataTypes & tuple_types = typeid_cast<const DataTypeTuple &>(*block.getByPosition(arguments[0]).type).getElements();
bool use_float64 = checkDataType<DataTypeFloat64>(tuple_types[0].get()) || checkDataType<DataTypeFloat64>(tuple_types[1].get());
bool use_float64 = WhichDataType(tuple_types[0]).isFloat64() || WhichDataType(tuple_types[1]).isFloat64();
auto & result_column = block.safeGetByPosition(result).column;

View File

@ -61,7 +61,7 @@ private:
for (const auto arg_idx : ext::range(0, arguments.size()))
{
const auto arg = arguments[arg_idx].get();
if (!checkDataType<DataTypeFloat64>(arg))
if (!WhichDataType(arg).isFloat64())
throw Exception(
"Illegal type " + arg->getName() + " of argument " + std::to_string(arg_idx + 1) + " of function " + getName() + ". Must be Float64",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -213,7 +213,7 @@ private:
for (const auto arg_idx : ext::range(0, arguments.size()))
{
const auto arg = arguments[arg_idx].get();
if (!checkDataType<DataTypeFloat64>(arg))
if (!WhichDataType(arg).isFloat64())
{
throw Exception(
"Illegal type " + arg->getName() + " of argument " + std::to_string(arg_idx + 1) + " of function " + getName() + ". Must be Float64",

View File

@ -200,7 +200,7 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!arguments[0]->isString())
if (!isString(arguments[0]))
throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -303,17 +303,18 @@ public:
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
{
const IDataType * from_type = block.getByPosition(arguments[0]).type.get();
WhichDataType which(from_type);
if (checkDataType<DataTypeUInt8>(from_type)) executeType<UInt8>(block, arguments, result);
else if (checkDataType<DataTypeUInt16>(from_type)) executeType<UInt16>(block, arguments, result);
else if (checkDataType<DataTypeUInt32>(from_type)) executeType<UInt32>(block, arguments, result);
else if (checkDataType<DataTypeUInt64>(from_type)) executeType<UInt64>(block, arguments, result);
else if (checkDataType<DataTypeInt8>(from_type)) executeType<Int8>(block, arguments, result);
else if (checkDataType<DataTypeInt16>(from_type)) executeType<Int16>(block, arguments, result);
else if (checkDataType<DataTypeInt32>(from_type)) executeType<Int32>(block, arguments, result);
else if (checkDataType<DataTypeInt64>(from_type)) executeType<Int64>(block, arguments, result);
else if (checkDataType<DataTypeDate>(from_type)) executeType<UInt16>(block, arguments, result);
else if (checkDataType<DataTypeDateTime>(from_type)) executeType<UInt32>(block, arguments, result);
if (which.isUInt8()) executeType<UInt8>(block, arguments, result);
else if (which.isUInt16()) executeType<UInt16>(block, arguments, result);
else if (which.isUInt32()) executeType<UInt32>(block, arguments, result);
else if (which.isUInt64()) executeType<UInt64>(block, arguments, result);
else if (which.isInt8()) executeType<Int8>(block, arguments, result);
else if (which.isInt16()) executeType<Int16>(block, arguments, result);
else if (which.isInt32()) executeType<Int32>(block, arguments, result);
else if (which.isInt64()) executeType<Int64>(block, arguments, result);
else if (which.isDate()) executeType<UInt16>(block, arguments, result);
else if (which.isDateTime()) executeType<UInt32>(block, arguments, result);
else
throw Exception("Illegal type " + block.getByPosition(arguments[0]).type->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -479,23 +480,25 @@ private:
template <bool first>
void executeAny(const IDataType * from_type, const IColumn * icolumn, ColumnUInt64::Container & vec_to)
{
if (checkDataType<DataTypeUInt8>(from_type)) executeIntType<UInt8, first>(icolumn, vec_to);
else if (checkDataType<DataTypeUInt16>(from_type)) executeIntType<UInt16, first>(icolumn, vec_to);
else if (checkDataType<DataTypeUInt32>(from_type)) executeIntType<UInt32, first>(icolumn, vec_to);
else if (checkDataType<DataTypeUInt64>(from_type)) executeIntType<UInt64, first>(icolumn, vec_to);
else if (checkDataType<DataTypeInt8>(from_type)) executeIntType<Int8, first>(icolumn, vec_to);
else if (checkDataType<DataTypeInt16>(from_type)) executeIntType<Int16, first>(icolumn, vec_to);
else if (checkDataType<DataTypeInt32>(from_type)) executeIntType<Int32, first>(icolumn, vec_to);
else if (checkDataType<DataTypeInt64>(from_type)) executeIntType<Int64, first>(icolumn, vec_to);
else if (checkDataType<DataTypeEnum8>(from_type)) executeIntType<Int8, first>(icolumn, vec_to);
else if (checkDataType<DataTypeEnum16>(from_type)) executeIntType<Int16, first>(icolumn, vec_to);
else if (checkDataType<DataTypeDate>(from_type)) executeIntType<UInt16, first>(icolumn, vec_to);
else if (checkDataType<DataTypeDateTime>(from_type)) executeIntType<UInt32, first>(icolumn, vec_to);
else if (checkDataType<DataTypeFloat32>(from_type)) executeIntType<Float32, first>(icolumn, vec_to);
else if (checkDataType<DataTypeFloat64>(from_type)) executeIntType<Float64, first>(icolumn, vec_to);
else if (checkDataType<DataTypeString>(from_type)) executeString<first>(icolumn, vec_to);
else if (checkDataType<DataTypeFixedString>(from_type)) executeString<first>(icolumn, vec_to);
else if (checkDataType<DataTypeArray>(from_type)) executeArray<first>(from_type, icolumn, vec_to);
WhichDataType which(from_type);
if (which.isUInt8()) executeIntType<UInt8, first>(icolumn, vec_to);
else if (which.isUInt16()) executeIntType<UInt16, first>(icolumn, vec_to);
else if (which.isUInt32()) executeIntType<UInt32, first>(icolumn, vec_to);
else if (which.isUInt64()) executeIntType<UInt64, first>(icolumn, vec_to);
else if (which.isInt8()) executeIntType<Int8, first>(icolumn, vec_to);
else if (which.isInt16()) executeIntType<Int16, first>(icolumn, vec_to);
else if (which.isInt32()) executeIntType<Int32, first>(icolumn, vec_to);
else if (which.isInt64()) executeIntType<Int64, first>(icolumn, vec_to);
else if (which.isEnum8()) executeIntType<Int8, first>(icolumn, vec_to);
else if (which.isEnum16()) executeIntType<Int16, first>(icolumn, vec_to);
else if (which.isDate()) executeIntType<UInt16, first>(icolumn, vec_to);
else if (which.isDateTime()) executeIntType<UInt32, first>(icolumn, vec_to);
else if (which.isFloat32()) executeIntType<Float32, first>(icolumn, vec_to);
else if (which.isFloat64()) executeIntType<Float64, first>(icolumn, vec_to);
else if (which.isString()) executeString<first>(icolumn, vec_to);
else if (which.isFixedString()) executeString<first>(icolumn, vec_to);
else if (which.isArray()) executeArray<first>(from_type, icolumn, vec_to);
else
throw Exception("Unexpected type " + from_type->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -602,23 +605,23 @@ public:
const ColumnWithTypeAndName & col = block.getByPosition(arguments[0]);
const IDataType * from_type = col.type.get();
const IColumn * icolumn = col.column.get();
WhichDataType which(from_type);
if (checkDataType<DataTypeUInt8>(from_type)) executeIntType<UInt8>(icolumn, vec_to);
else if (checkDataType<DataTypeUInt16>(from_type)) executeIntType<UInt16>(icolumn, vec_to);
else if (checkDataType<DataTypeUInt32>(from_type)) executeIntType<UInt32>(icolumn, vec_to);
else if (checkDataType<DataTypeUInt64>(from_type)) executeIntType<UInt64>(icolumn, vec_to);
else if (checkDataType<DataTypeInt8>(from_type)) executeIntType<Int8>(icolumn, vec_to);
else if (checkDataType<DataTypeInt16>(from_type)) executeIntType<Int16>(icolumn, vec_to);
else if (checkDataType<DataTypeInt32>(from_type)) executeIntType<Int32>(icolumn, vec_to);
else if (checkDataType<DataTypeInt64>(from_type)) executeIntType<Int64>(icolumn, vec_to);
else if (checkDataType<DataTypeEnum8>(from_type)) executeIntType<Int8>(icolumn, vec_to);
else if (checkDataType<DataTypeEnum16>(from_type)) executeIntType<Int16>(icolumn, vec_to);
else if (checkDataType<DataTypeDate>(from_type)) executeIntType<UInt16>(icolumn, vec_to);
else if (checkDataType<DataTypeDateTime>(from_type)) executeIntType<UInt32>(icolumn, vec_to);
else if (checkDataType<DataTypeFloat32>(from_type)) executeIntType<Float32>(icolumn, vec_to);
else if (checkDataType<DataTypeFloat64>(from_type)) executeIntType<Float64>(icolumn, vec_to);
else if (checkDataType<DataTypeString>(from_type)) executeString(icolumn, vec_to);
else if (checkDataType<DataTypeFixedString>(from_type)) executeString(icolumn, vec_to);
if (which.isUInt8()) executeIntType<UInt8>(icolumn, vec_to);
else if (which.isUInt16()) executeIntType<UInt16>(icolumn, vec_to);
else if (which.isUInt32()) executeIntType<UInt32>(icolumn, vec_to);
else if (which.isUInt64()) executeIntType<UInt64>(icolumn, vec_to);
else if (which.isInt8()) executeIntType<Int8>(icolumn, vec_to);
else if (which.isInt16()) executeIntType<Int16>(icolumn, vec_to);
else if (which.isInt32()) executeIntType<Int32>(icolumn, vec_to);
else if (which.isInt64()) executeIntType<Int64>(icolumn, vec_to);
else if (which.isEnum8()) executeIntType<Int8>(icolumn, vec_to);
else if (which.isEnum16()) executeIntType<Int16>(icolumn, vec_to);
else if (which.isDate()) executeIntType<UInt16>(icolumn, vec_to);
else if (which.isDateTime()) executeIntType<UInt32>(icolumn, vec_to);
else if (which.isFloat32()) executeIntType<Float32>(icolumn, vec_to);
else if (which.isFloat64()) executeIntType<Float64>(icolumn, vec_to);
else if (which.isStringOrFixedString()) executeString(icolumn, vec_to);
else
throw Exception("Unexpected type " + from_type->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -843,13 +846,13 @@ public:
toString(arg_count) + ", should be 1 or 2.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
const auto first_arg = arguments.front().get();
if (!checkDataType<DataTypeString>(first_arg))
if (!WhichDataType(first_arg).isString())
throw Exception{"Illegal type " + first_arg->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (arg_count == 2)
{
const auto second_arg = arguments.back().get();
if (!second_arg->isInteger())
const auto & second_arg = arguments.back();
if (!isInteger(second_arg))
throw Exception{"Illegal type " + second_arg->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}

View File

@ -310,20 +310,15 @@ struct ArraySumImpl
static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypePtr & /*array_element*/)
{
if (checkDataType<DataTypeUInt8>(&*expression_return) ||
checkDataType<DataTypeUInt16>(&*expression_return) ||
checkDataType<DataTypeUInt32>(&*expression_return) ||
checkDataType<DataTypeUInt64>(&*expression_return))
WhichDataType which(expression_return);
if (which.isNativeUInt())
return std::make_shared<DataTypeUInt64>();
if (checkDataType<DataTypeInt8>(&*expression_return) ||
checkDataType<DataTypeInt16>(&*expression_return) ||
checkDataType<DataTypeInt32>(&*expression_return) ||
checkDataType<DataTypeInt64>(&*expression_return))
if (which.isNativeInt())
return std::make_shared<DataTypeInt64>();
if (checkDataType<DataTypeFloat32>(&*expression_return) ||
checkDataType<DataTypeFloat64>(&*expression_return))
if (which.isFloat())
return std::make_shared<DataTypeFloat64>();
throw Exception("arraySum cannot add values of type " + expression_return->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -602,20 +597,15 @@ struct ArrayCumSumImpl
static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypePtr & /*array_element*/)
{
if (checkDataType<DataTypeUInt8>(&*expression_return) ||
checkDataType<DataTypeUInt16>(&*expression_return) ||
checkDataType<DataTypeUInt32>(&*expression_return) ||
checkDataType<DataTypeUInt64>(&*expression_return))
WhichDataType which(expression_return);
if (which.isNativeUInt())
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt64>());
if (checkDataType<DataTypeInt8>(&*expression_return) ||
checkDataType<DataTypeInt16>(&*expression_return) ||
checkDataType<DataTypeInt32>(&*expression_return) ||
checkDataType<DataTypeInt64>(&*expression_return))
if (which.isNativeInt())
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeInt64>());
if (checkDataType<DataTypeFloat32>(&*expression_return) ||
checkDataType<DataTypeFloat64>(&*expression_return))
if (which.isFloat())
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeFloat64>());
throw Exception("arrayCumSum cannot add values of type " + expression_return->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -824,7 +814,7 @@ public:
DataTypePtr nested_type = array_type->getNestedType();
if (Impl::needBoolean() && !checkDataType<DataTypeUInt8>(&*nested_type))
if (Impl::needBoolean() && !WhichDataType(nested_type).isUInt8())
throw Exception("The only argument for function " + getName() + " must be array of UInt8. Found "
+ arguments[0].type->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -845,7 +835,7 @@ public:
/// The types of the remaining arguments are already checked in getLambdaArgumentTypes.
DataTypePtr return_type = data_type_function->getReturnType();
if (Impl::needBoolean() && !checkDataType<DataTypeUInt8>(&*return_type))
if (Impl::needBoolean() && !WhichDataType(return_type).isUInt8())
throw Exception("Expression for function " + getName() + " must return UInt8, found "
+ return_type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

View File

@ -309,8 +309,8 @@ public:
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
for (size_t i = 0; i < arguments.size(); ++i)
if (!(arguments[i]->isNumber()
|| (Impl::specialImplementationForNulls() && (arguments[i]->onlyNull() || removeNullable(arguments[i])->isNumber()))))
if (!(isNumber(arguments[i])
|| (Impl::specialImplementationForNulls() && (arguments[i]->onlyNull() || isNumber(removeNullable(arguments[i]))))))
throw Exception("Illegal type ("
+ arguments[i]->getName()
+ ") of " + toString(i + 1) + " argument of function " + getName(),
@ -488,7 +488,7 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!arguments[0]->isNumber())
if (!isNumber(arguments[0]))
throw Exception("Illegal type ("
+ arguments[0]->getName()
+ ") of argument of function " + getName(),

View File

@ -78,7 +78,7 @@ private:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!arguments.front()->isNumber())
if (!isNumber(arguments.front()))
throw Exception{"Illegal type " + arguments.front()->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
return std::make_shared<DataTypeFloat64>();
@ -199,7 +199,7 @@ private:
{
const auto check_argument_type = [this] (const IDataType * arg)
{
if (!arg->isNumber())
if (!isNumber(arg))
throw Exception{"Illegal type " + arg->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
};

File diff suppressed because it is too large Load Diff

View File

@ -11,32 +11,6 @@
namespace DB
{
/** Creates an array, multiplying the column (the first argument) by the number of elements in the array (the second argument).
*/
class FunctionReplicate : public IFunction
{
public:
static constexpr auto name = "replicate";
static FunctionPtr create(const Context & context);
String getName() const override
{
return name;
}
size_t getNumberOfArguments() const override
{
return 2;
}
bool useDefaultImplementationForNulls() const override { return false; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override;
};
/// Executes expression. Uses for lambda functions implementation. Can't be created from factory.
class FunctionExpression : public IFunctionBase, public IPreparedFunction,
public std::enable_shared_from_this<FunctionExpression>

View File

@ -1,390 +0,0 @@
#include <Functions/FunctionsNull.h>
#include <Functions/FunctionsLogical.h>
#include <Functions/FunctionsConditional.h>
#include <Functions/FunctionFactory.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeNothing.h>
#include <Columns/ColumnNullable.h>
#include <cstdlib>
#include <string>
#include <memory>
namespace DB
{
void registerFunctionsNull(FunctionFactory & factory)
{
factory.registerFunction<FunctionIsNull>();
factory.registerFunction<FunctionIsNotNull>();
factory.registerFunction<FunctionCoalesce>(FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionIfNull>(FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionNullIf>(FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionAssumeNotNull>();
factory.registerFunction<FunctionToNullable>();
}
/// Implementation of isNull.
FunctionPtr FunctionIsNull::create(const Context &)
{
return std::make_shared<FunctionIsNull>();
}
std::string FunctionIsNull::getName() const
{
return name;
}
DataTypePtr FunctionIsNull::getReturnTypeImpl(const DataTypes &) const
{
return std::make_shared<DataTypeUInt8>();
}
void FunctionIsNull::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/)
{
const ColumnWithTypeAndName & elem = block.getByPosition(arguments[0]);
if (elem.column->isColumnNullable())
{
/// Merely return the embedded null map.
block.getByPosition(result).column = static_cast<const ColumnNullable &>(*elem.column).getNullMapColumnPtr();
}
else
{
/// Since no element is nullable, return a zero-constant column representing
/// a zero-filled null map.
block.getByPosition(result).column = DataTypeUInt8().createColumnConst(elem.column->size(), UInt64(0));
}
}
/// Implementation of isNotNull.
FunctionPtr FunctionIsNotNull::create(const Context &)
{
return std::make_shared<FunctionIsNotNull>();
}
std::string FunctionIsNotNull::getName() const
{
return name;
}
DataTypePtr FunctionIsNotNull::getReturnTypeImpl(const DataTypes &) const
{
return std::make_shared<DataTypeUInt8>();
}
void FunctionIsNotNull::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count)
{
Block temp_block
{
block.getByPosition(arguments[0]),
{
nullptr,
std::make_shared<DataTypeUInt8>(),
""
},
{
nullptr,
std::make_shared<DataTypeUInt8>(),
""
}
};
FunctionIsNull{}.execute(temp_block, {0}, 1, input_rows_count);
FunctionNot{}.execute(temp_block, {1}, 2, input_rows_count);
block.getByPosition(result).column = std::move(temp_block.getByPosition(2).column);
}
/// Implementation of coalesce.
FunctionPtr FunctionCoalesce::create(const Context & context)
{
return std::make_shared<FunctionCoalesce>(context);
}
std::string FunctionCoalesce::getName() const
{
return name;
}
DataTypePtr FunctionCoalesce::getReturnTypeImpl(const DataTypes & arguments) const
{
/// Skip all NULL arguments. If any argument is non-Nullable, skip all next arguments.
DataTypes filtered_args;
filtered_args.reserve(arguments.size());
for (const auto & arg : arguments)
{
if (arg->onlyNull())
continue;
filtered_args.push_back(arg);
if (!arg->isNullable())
break;
}
DataTypes new_args;
for (size_t i = 0; i < filtered_args.size(); ++i)
{
bool is_last = i + 1 == filtered_args.size();
if (is_last)
{
new_args.push_back(filtered_args[i]);
}
else
{
new_args.push_back(std::make_shared<DataTypeUInt8>());
new_args.push_back(removeNullable(filtered_args[i]));
}
}
if (new_args.empty())
return std::make_shared<DataTypeNullable>(std::make_shared<DataTypeNothing>());
if (new_args.size() == 1)
return new_args.front();
auto res = FunctionMultiIf{context}.getReturnTypeImpl(new_args);
/// if last argument is not nullable, result should be also not nullable
if (!new_args.back()->isNullable() && res->isNullable())
res = removeNullable(res);
return res;
}
void FunctionCoalesce::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count)
{
/// coalesce(arg0, arg1, ..., argN) is essentially
/// multiIf(isNotNull(arg0), assumeNotNull(arg0), isNotNull(arg1), assumeNotNull(arg1), ..., argN)
/// with constant NULL arguments removed.
ColumnNumbers filtered_args;
filtered_args.reserve(arguments.size());
for (const auto & arg : arguments)
{
const auto & type = block.getByPosition(arg).type;
if (type->onlyNull())
continue;
filtered_args.push_back(arg);
if (!type->isNullable())
break;
}
FunctionIsNotNull is_not_null;
FunctionAssumeNotNull assume_not_null;
ColumnNumbers multi_if_args;
Block temp_block = block;
for (size_t i = 0; i < filtered_args.size(); ++i)
{
size_t res_pos = temp_block.columns();
bool is_last = i + 1 == filtered_args.size();
if (is_last)
{
multi_if_args.push_back(filtered_args[i]);
}
else
{
temp_block.insert({nullptr, std::make_shared<DataTypeUInt8>(), ""});
is_not_null.execute(temp_block, {filtered_args[i]}, res_pos, input_rows_count);
temp_block.insert({nullptr, removeNullable(block.getByPosition(filtered_args[i]).type), ""});
assume_not_null.execute(temp_block, {filtered_args[i]}, res_pos + 1, input_rows_count);
multi_if_args.push_back(res_pos);
multi_if_args.push_back(res_pos + 1);
}
}
/// If all arguments appeared to be NULL.
if (multi_if_args.empty())
{
block.getByPosition(result).column = block.getByPosition(result).type->createColumnConstWithDefaultValue(input_rows_count);
return;
}
if (multi_if_args.size() == 1)
{
block.getByPosition(result).column = block.getByPosition(multi_if_args.front()).column;
return;
}
FunctionMultiIf{context}.execute(temp_block, multi_if_args, result, input_rows_count);
ColumnPtr res = std::move(temp_block.getByPosition(result).column);
/// if last argument is not nullable, result should be also not nullable
if (!block.getByPosition(multi_if_args.back()).column->isColumnNullable() && res->isColumnNullable())
res = static_cast<const ColumnNullable &>(*res).getNestedColumnPtr();
block.getByPosition(result).column = std::move(res);
}
/// Implementation of ifNull.
FunctionPtr FunctionIfNull::create(const Context &)
{
return std::make_shared<FunctionIfNull>();
}
std::string FunctionIfNull::getName() const
{
return name;
}
DataTypePtr FunctionIfNull::getReturnTypeImpl(const DataTypes & arguments) const
{
if (arguments[0]->onlyNull())
return arguments[1];
if (!arguments[0]->isNullable())
return arguments[0];
return FunctionIf{}.getReturnTypeImpl({std::make_shared<DataTypeUInt8>(), removeNullable(arguments[0]), arguments[1]});
}
void FunctionIfNull::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count)
{
/// Always null.
if (block.getByPosition(arguments[0]).type->onlyNull())
{
block.getByPosition(result).column = block.getByPosition(arguments[1]).column;
return;
}
/// Could not contain nulls, so nullIf makes no sense.
if (!block.getByPosition(arguments[0]).type->isNullable())
{
block.getByPosition(result).column = block.getByPosition(arguments[0]).column;
return;
}
/// ifNull(col1, col2) == if(isNotNull(col1), assumeNotNull(col1), col2)
Block temp_block = block;
size_t is_not_null_pos = temp_block.columns();
temp_block.insert({nullptr, std::make_shared<DataTypeUInt8>(), ""});
size_t assume_not_null_pos = temp_block.columns();
temp_block.insert({nullptr, removeNullable(block.getByPosition(arguments[0]).type), ""});
FunctionIsNotNull{}.execute(temp_block, {arguments[0]}, is_not_null_pos, input_rows_count);
FunctionAssumeNotNull{}.execute(temp_block, {arguments[0]}, assume_not_null_pos, input_rows_count);
FunctionIf{}.execute(temp_block, {is_not_null_pos, assume_not_null_pos, arguments[1]}, result, input_rows_count);
block.getByPosition(result).column = std::move(temp_block.getByPosition(result).column);
}
/// Implementation of nullIf.
FunctionPtr FunctionNullIf::create(const Context & context)
{
return std::make_shared<FunctionNullIf>(context);
}
FunctionNullIf::FunctionNullIf(const Context & context) : context(context) {}
std::string FunctionNullIf::getName() const
{
return name;
}
DataTypePtr FunctionNullIf::getReturnTypeImpl(const DataTypes & arguments) const
{
return FunctionIf{}.getReturnTypeImpl({std::make_shared<DataTypeUInt8>(), makeNullable(arguments[0]), arguments[0]});
}
void FunctionNullIf::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count)
{
/// nullIf(col1, col2) == if(col1 == col2, NULL, col1)
Block temp_block = block;
size_t res_pos = temp_block.columns();
temp_block.insert({nullptr, std::make_shared<DataTypeUInt8>(), ""});
{
auto equals_func = FunctionFactory::instance().get("equals", context)->build(
{block.getByPosition(arguments[0]), block.getByPosition(arguments[1])});
equals_func->execute(temp_block, {arguments[0], arguments[1]}, res_pos, input_rows_count);
}
/// Argument corresponding to the NULL value.
size_t null_pos = temp_block.columns();
/// Append a NULL column.
ColumnWithTypeAndName null_elem;
null_elem.type = block.getByPosition(result).type;
null_elem.column = null_elem.type->createColumnConstWithDefaultValue(input_rows_count);
null_elem.name = "NULL";
temp_block.insert(null_elem);
FunctionIf{}.execute(temp_block, {res_pos, null_pos, arguments[0]}, result, input_rows_count);
block.getByPosition(result).column = std::move(temp_block.getByPosition(result).column);
}
/// Implementation of assumeNotNull.
FunctionPtr FunctionAssumeNotNull::create(const Context &)
{
return std::make_shared<FunctionAssumeNotNull>();
}
std::string FunctionAssumeNotNull::getName() const
{
return name;
}
DataTypePtr FunctionAssumeNotNull::getReturnTypeImpl(const DataTypes & arguments) const
{
return removeNullable(arguments[0]);
}
void FunctionAssumeNotNull::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/)
{
const ColumnPtr & col = block.getByPosition(arguments[0]).column;
ColumnPtr & res_col = block.getByPosition(result).column;
if (col->isColumnNullable())
{
const ColumnNullable & nullable_col = static_cast<const ColumnNullable &>(*col);
res_col = nullable_col.getNestedColumnPtr();
}
else
res_col = col;
}
/// Implementation of toNullable.
FunctionPtr FunctionToNullable::create(const Context &)
{
return std::make_shared<FunctionToNullable>();
}
std::string FunctionToNullable::getName() const
{
return name;
}
DataTypePtr FunctionToNullable::getReturnTypeImpl(const DataTypes & arguments) const
{
return makeNullable(arguments[0]);
}
void FunctionToNullable::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/)
{
block.getByPosition(result).column = makeNullable(block.getByPosition(arguments[0]).column);
}
}

View File

@ -1,137 +0,0 @@
#pragma once
#include <Functions/IFunction.h>
#include <Functions/FunctionHelpers.h>
#include <DataTypes/IDataType.h>
#include <Core/ColumnNumbers.h>
namespace DB
{
class Block;
class Context;
/// Implements the function isNull which returns true if a value
/// is null, false otherwise.
class FunctionIsNull : public IFunction
{
public:
static constexpr auto name = "isNull";
static FunctionPtr create(const Context & context);
std::string getName() const override;
size_t getNumberOfArguments() const override { return 1; }
bool useDefaultImplementationForNulls() const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override;
};
/// Implements the function isNotNull which returns true if a value
/// is not null, false otherwise.
class FunctionIsNotNull : public IFunction
{
public:
static constexpr auto name = "isNotNull";
static FunctionPtr create(const Context & context);
std::string getName() const override;
size_t getNumberOfArguments() const override { return 1; }
bool useDefaultImplementationForNulls() const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override;
};
/// Implements the function coalesce which takes a set of arguments and
/// returns the value of the leftmost non-null argument. If no such value is
/// found, coalesce() returns NULL.
class FunctionCoalesce : public IFunction
{
public:
static constexpr auto name = "coalesce";
static FunctionPtr create(const Context & context);
FunctionCoalesce(const Context & context) : context(context) {}
std::string getName() const override;
bool useDefaultImplementationForNulls() const override { return false; }
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override;
private:
const Context & context;
};
/// Implements the function ifNull which takes 2 arguments and returns
/// the value of the 1st argument if it is not null. Otherwise it returns
/// the value of the 2nd argument.
class FunctionIfNull : public IFunction
{
public:
static constexpr auto name = "ifNull";
static FunctionPtr create(const Context & context);
std::string getName() const override;
size_t getNumberOfArguments() const override { return 2; }
bool useDefaultImplementationForNulls() const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override;
};
/// Implements the function nullIf which takes 2 arguments and returns
/// NULL if both arguments have the same value. Otherwise it returns the
/// value of the first argument.
class FunctionNullIf : public IFunction
{
private:
const Context & context;
public:
static constexpr auto name = "nullIf";
static FunctionPtr create(const Context & context);
FunctionNullIf(const Context & context);
std::string getName() const override;
size_t getNumberOfArguments() const override { return 2; }
bool useDefaultImplementationForNulls() const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override;
};
/// Implements the function assumeNotNull which takes 1 argument and works as follows:
/// - if the argument is a nullable column, return its embedded column;
/// - otherwise return the original argument.
/// NOTE: assumeNotNull may not be called with the NULL value.
class FunctionAssumeNotNull : public IFunction
{
public:
static constexpr auto name = "assumeNotNull";
static FunctionPtr create(const Context & context);
std::string getName() const override;
size_t getNumberOfArguments() const override { return 1; }
bool useDefaultImplementationForNulls() const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override;
};
/// If value is not Nullable or NULL, wraps it to Nullable.
class FunctionToNullable : public IFunction
{
public:
static constexpr auto name = "toNullable";
static FunctionPtr create(const Context & context);
std::string getName() const override;
size_t getNumberOfArguments() const override { return 1; }
bool useDefaultImplementationForNulls() const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override;
};
}

View File

@ -167,7 +167,7 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
const IDataType & type = *arguments[0];
if (!type.isStringOrFixedString())
if (!isStringOrFixedString(type))
throw Exception("Cannot reinterpret " + type.getName() + " as " + ToDataType().getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<ToDataType>();

View File

@ -578,7 +578,7 @@ public:
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
for (const auto & type : arguments)
if (!type->isNumber())
if (!isNumber(type))
throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

File diff suppressed because it is too large Load Diff

View File

@ -1,204 +0,0 @@
#pragma once
#include <Poco/UTF8Encoding.h>
#include <Poco/Unicode.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnString.h>
#include <Common/typeid_cast.h>
#include <DataTypes/DataTypeFixedString.h>
#include <DataTypes/DataTypeString.h>
#include <Functions/IFunction.h>
#include <Functions/FunctionHelpers.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
}
/** String functions
*
* length, empty, notEmpty,
* concat, substring, lower, upper, reverse
* lengthUTF8, substringUTF8, lowerUTF8, upperUTF8, reverseUTF8
*
* s -> UInt8: empty, notEmpty
* s -> UInt64: length, lengthUTF8
* s -> s: lower, upper, lowerUTF8, upperUTF8, reverse, reverseUTF8
* s, s -> s: concat
* s, c1, c2 -> s: substring, substringUTF8
* s, c1, c2, s2 -> s: replace, replaceUTF8
*
* The search functions for strings and regular expressions are located separately.
* URL functions are located separately.
* String encoding functions, converting to other types are located separately.
*
* The functions length, empty, notEmpty, reverse also work with arrays.
*/
/// xor or do nothing
template <bool>
UInt8 xor_or_identity(const UInt8 c, const int mask)
{
return c ^ mask;
}
template <>
inline UInt8 xor_or_identity<false>(const UInt8 c, const int)
{
return c;
}
/// It is caller's responsibility to ensure the presence of a valid cyrillic sequence in array
template <bool to_lower>
inline void UTF8CyrillicToCase(const UInt8 *& src, UInt8 *& dst)
{
if (src[0] == 0xD0u && (src[1] >= 0x80u && src[1] <= 0x8Fu))
{
/// ЀЁЂЃЄЅІЇЈЉЊЋЌЍЎЏ
*dst++ = xor_or_identity<to_lower>(*src++, 0x1);
*dst++ = xor_or_identity<to_lower>(*src++, 0x10);
}
else if (src[0] == 0xD1u && (src[1] >= 0x90u && src[1] <= 0x9Fu))
{
/// ѐёђѓєѕіїјљњћќѝўџ
*dst++ = xor_or_identity<!to_lower>(*src++, 0x1);
*dst++ = xor_or_identity<!to_lower>(*src++, 0x10);
}
else if (src[0] == 0xD0u && (src[1] >= 0x90u && src[1] <= 0x9Fu))
{
/// А
*dst++ = *src++;
*dst++ = xor_or_identity<to_lower>(*src++, 0x20);
}
else if (src[0] == 0xD0u && (src[1] >= 0xB0u && src[1] <= 0xBFu))
{
/// а-п
*dst++ = *src++;
*dst++ = xor_or_identity<!to_lower>(*src++, 0x20);
}
else if (src[0] == 0xD0u && (src[1] >= 0xA0u && src[1] <= 0xAFu))
{
/// Р
*dst++ = xor_or_identity<to_lower>(*src++, 0x1);
*dst++ = xor_or_identity<to_lower>(*src++, 0x20);
}
else if (src[0] == 0xD1u && (src[1] >= 0x80u && src[1] <= 0x8Fu))
{
/// р
*dst++ = xor_or_identity<!to_lower>(*src++, 0x1);
*dst++ = xor_or_identity<!to_lower>(*src++, 0x20);
}
}
/** If the string contains UTF-8 encoded text, convert it to the lower (upper) case.
* Note: It is assumed that after the character is converted to another case,
* the length of its multibyte sequence in UTF-8 does not change.
* Otherwise, the behavior is undefined.
*/
template <char not_case_lower_bound,
char not_case_upper_bound,
int to_case(int),
void cyrillic_to_case(const UInt8 *&, UInt8 *&)>
struct LowerUpperUTF8Impl
{
static void vector(const ColumnString::Chars_t & data,
const ColumnString::Offsets & offsets,
ColumnString::Chars_t & res_data,
ColumnString::Offsets & res_offsets);
static void vector_fixed(const ColumnString::Chars_t & data, size_t n, ColumnString::Chars_t & res_data);
static void constant(const std::string & data, std::string & res_data);
/** Converts a single code point starting at `src` to desired case, storing result starting at `dst`.
* `src` and `dst` are incremented by corresponding sequence lengths. */
static void toCase(const UInt8 *& src, const UInt8 * src_end, UInt8 *& dst);
private:
static constexpr auto ascii_upper_bound = '\x7f';
static constexpr auto flip_case_mask = 'A' ^ 'a';
static void array(const UInt8 * src, const UInt8 * src_end, UInt8 * dst);
};
template <typename Impl, typename Name, bool is_injective = false>
class FunctionStringToString : public IFunction
{
public:
static constexpr auto name = Name::name;
static FunctionPtr create(const Context &)
{
return std::make_shared<FunctionStringToString>();
}
String getName() const override
{
return name;
}
size_t getNumberOfArguments() const override
{
return 1;
}
bool isInjective(const Block &) override
{
return is_injective;
}
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!arguments[0]->isStringOrFixedString())
throw Exception(
"Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return arguments[0];
}
bool useDefaultImplementationForConstants() const override { return true; }
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
{
const ColumnPtr column = block.getByPosition(arguments[0]).column;
if (const ColumnString * col = checkAndGetColumn<ColumnString>(column.get()))
{
auto col_res = ColumnString::create();
Impl::vector(col->getChars(), col->getOffsets(), col_res->getChars(), col_res->getOffsets());
block.getByPosition(result).column = std::move(col_res);
}
else if (const ColumnFixedString * col = checkAndGetColumn<ColumnFixedString>(column.get()))
{
auto col_res = ColumnFixedString::create(col->getN());
Impl::vector_fixed(col->getChars(), col->getN(), col_res->getChars());
block.getByPosition(result).column = std::move(col_res);
}
else
throw Exception(
"Illegal column " + block.getByPosition(arguments[0]).column->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN);
}
};
struct NameLowerUTF8
{
static constexpr auto name = "lowerUTF8";
};
struct NameUpperUTF8
{
static constexpr auto name = "upperUTF8";
};
using FunctionLowerUTF8 = FunctionStringToString<LowerUpperUTF8Impl<'A', 'Z', Poco::Unicode::toLower, UTF8CyrillicToCase<true>>, NameLowerUTF8>;
using FunctionUpperUTF8 = FunctionStringToString<LowerUpperUTF8Impl<'a', 'z', Poco::Unicode::toUpper, UTF8CyrillicToCase<false>>, NameUpperUTF8>;
}

View File

@ -66,7 +66,7 @@ public:
/// Check the type of the function's arguments.
static void checkArguments(const DataTypes & arguments)
{
if (!arguments[0]->isString())
if (!isString(arguments[0]))
throw Exception("Illegal type " + arguments[0]->getName() + " of first argument of function " + getName() + ". Must be String.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
@ -124,11 +124,11 @@ public:
static void checkArguments(const DataTypes & arguments)
{
if (!arguments[0]->isString())
if (!isString(arguments[0]))
throw Exception("Illegal type " + arguments[0]->getName() + " of first argument of function " + getName() + ". Must be String.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!arguments[1]->isString())
if (!isString(arguments[1]))
throw Exception("Illegal type " + arguments[1]->getName() + " of second argument of function " + getName() + ". Must be String.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
@ -504,11 +504,11 @@ public:
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(arguments[0].get());
if (!array_type || !array_type->getNestedType()->isString())
if (!array_type || !isString(array_type->getNestedType()))
throw Exception("First argument for function " + getName() + " must be array of strings.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (arguments.size() == 2
&& !arguments[1]->isString())
&& !isString(arguments[1]))
throw Exception("Second argument for function " + getName() + " must be constant string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<DataTypeString>();

View File

@ -954,15 +954,15 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!arguments[0]->isStringOrFixedString())
if (!isStringOrFixedString(arguments[0]))
throw Exception("Illegal type " + arguments[0]->getName() + " of first argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!arguments[1]->isStringOrFixedString())
if (!isStringOrFixedString(arguments[1]))
throw Exception("Illegal type " + arguments[1]->getName() + " of second argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!arguments[2]->isStringOrFixedString())
if (!isStringOrFixedString(arguments[2]))
throw Exception("Illegal type " + arguments[2]->getName() + " of third argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

View File

@ -65,11 +65,11 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!arguments[0]->isString())
if (!isString(arguments[0]))
throw Exception(
"Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!arguments[1]->isString())
if (!isString(arguments[1]))
throw Exception(
"Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -149,11 +149,11 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!arguments[0]->isString())
if (!isString(arguments[0]))
throw Exception(
"Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!arguments[1]->isString())
if (!isString(arguments[1]))
throw Exception(
"Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

View File

@ -77,7 +77,7 @@ public:
const DataTypePtr & type_x = arguments[0];
if (!type_x->isValueRepresentedByNumber() && !type_x->isString())
if (!type_x->isValueRepresentedByNumber() && !isString(type_x))
throw Exception{"Unsupported type " + type_x->getName()
+ " of first argument of function " + getName()
+ ", must be numeric type or Date/DateTime or String", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
@ -91,7 +91,7 @@ public:
const auto type_arr_from_nested = type_arr_from->getNestedType();
if ((type_x->isValueRepresentedByNumber() != type_arr_from_nested->isValueRepresentedByNumber())
|| (!!type_x->isString() != !!type_arr_from_nested->isString()))
|| (isString(type_x) != isString(type_arr_from_nested)))
{
throw Exception{"First argument and elements of array of second argument of function " + getName()
+ " must have compatible types: both numeric or both strings.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
@ -108,7 +108,7 @@ public:
if (args_size == 3)
{
if ((type_x->isValueRepresentedByNumber() != type_arr_to_nested->isValueRepresentedByNumber())
|| (!!type_x->isString() != !!checkDataType<DataTypeString>(type_arr_to_nested.get())))
|| (isString(type_x) != isString(type_arr_to_nested)))
throw Exception{"Function " + getName()
+ " has signature: transform(T, Array(T), Array(U), U) -> U; or transform(T, Array(T), Array(T)) -> T; where T and U are types.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
@ -119,13 +119,16 @@ public:
{
const DataTypePtr & type_default = arguments[3];
if (!type_default->isValueRepresentedByNumber() && !type_default->isString())
if (!type_default->isValueRepresentedByNumber() && !isString(type_default))
throw Exception{"Unsupported type " + type_default->getName()
+ " of fourth argument (default value) of function " + getName()
+ ", must be numeric type or Date/DateTime or String", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
bool default_is_string = WhichDataType(type_default).isString();
bool nested_is_string = WhichDataType(type_arr_to_nested).isString();
if ((type_default->isValueRepresentedByNumber() != type_arr_to_nested->isValueRepresentedByNumber())
|| (!!checkDataType<DataTypeString>(type_default.get()) != !!checkDataType<DataTypeString>(type_arr_to_nested.get())))
|| (default_is_string != nested_is_string))
throw Exception{"Function " + getName()
+ " have signature: transform(T, Array(T), Array(U), U) -> U; or transform(T, Array(T), Array(T)) -> T; where T and U are types.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};

View File

@ -7,7 +7,7 @@
#include <common/find_first_symbols.h>
#include <common/StringRef.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/FunctionsString.h>
#include <Functions/FunctionStringToString.h>
#include <Functions/FunctionsStringArray.h>
#include <port/memrchr.h>
@ -584,7 +584,7 @@ public:
static void checkArguments(const DataTypes & arguments)
{
if (!arguments[0]->isString())
if (!isString(arguments[0]))
throw Exception("Illegal type " + arguments[0]->getName() + " of first argument of function " + getName() + ". Must be String.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
@ -669,7 +669,7 @@ public:
static void checkArguments(const DataTypes & arguments)
{
if (!arguments[0]->isString())
if (!isString(arguments[0]))
throw Exception("Illegal type " + arguments[0]->getName() + " of first argument of function " + getName() + ". Must be String.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
@ -746,7 +746,7 @@ public:
static void checkArguments(const DataTypes & arguments)
{
if (!arguments[0]->isString())
if (!isString(arguments[0]))
throw Exception("Illegal type " + arguments[0]->getName() + " of first argument of function " + getName() + ". Must be String.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
@ -841,7 +841,7 @@ public:
static void checkArguments(const DataTypes & arguments)
{
if (!arguments[0]->isString())
if (!isString(arguments[0]))
throw Exception("Illegal type " + arguments[0]->getName() + " of first argument of function " + getName() + ". Must be String.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}

View File

@ -1,7 +1,6 @@
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionsVisitParam.h>
#include <Functions/FunctionsStringSearch.h>
#include <Functions/FunctionsString.h>
#include <Functions/FunctionsURL.h>
@ -26,7 +25,6 @@ using FunctionVisitParamExtractRaw = FunctionsStringSearchToString<ExtractParamT
using FunctionVisitParamExtractString = FunctionsStringSearchToString<ExtractParamToStringImpl<ExtractString>, NameVisitParamExtractString>;
void registerFunctionsVisitParam(FunctionFactory & factory)
{
factory.registerFunction<FunctionVisitParamHas>();

View File

@ -37,15 +37,12 @@ std::unique_ptr<IArraySink> createArraySink(ColumnArray & col, size_t column_siz
void concat(const std::vector<std::unique_ptr<IArraySource>> & sources, IArraySink & sink);
void sliceFromLeftConstantOffsetUnbounded(IArraySource & src, IArraySink & sink, size_t offset);
void sliceFromLeftConstantOffsetBounded(IArraySource & src, IArraySink & sink, size_t offset, ssize_t length);
void sliceFromRightConstantOffsetUnbounded(IArraySource & src, IArraySink & sink, size_t offset);
void sliceFromRightConstantOffsetBounded(IArraySource & src, IArraySink & sink, size_t offset, ssize_t length);
void sliceDynamicOffsetUnbounded(IArraySource & src, IArraySink & sink, const IColumn & offset_column);
void sliceDynamicOffsetBounded(IArraySource & src, IArraySink & sink, const IColumn & offset_column, const IColumn & length_column);
void sliceHas(IArraySource & first, IArraySource & second, bool all, ColumnUInt8 & result);
@ -55,5 +52,6 @@ void push(IArraySource & array_source, IValueSource & value_source, IArraySink &
void resizeDynamicSize(IArraySource & array_source, IValueSource & value_source, IArraySink & sink, const IColumn & size_column);
void resizeConstantSize(IArraySource & array_source, IValueSource & value_source, IArraySink & sink, ssize_t size);
}

View File

@ -0,0 +1,67 @@
#include <Columns/ColumnString.h>
namespace DB
{
template <char not_case_lower_bound, char not_case_upper_bound>
struct LowerUpperImpl
{
static void vector(const ColumnString::Chars_t & data,
const ColumnString::Offsets & offsets,
ColumnString::Chars_t & res_data,
ColumnString::Offsets & res_offsets)
{
res_data.resize(data.size());
res_offsets.assign(offsets);
array(data.data(), data.data() + data.size(), res_data.data());
}
static void vector_fixed(const ColumnString::Chars_t & data, size_t /*n*/, ColumnString::Chars_t & res_data)
{
res_data.resize(data.size());
array(data.data(), data.data() + data.size(), res_data.data());
}
private:
static void array(const UInt8 * src, const UInt8 * src_end, UInt8 * dst)
{
const auto flip_case_mask = 'A' ^ 'a';
#if __SSE2__
const auto bytes_sse = sizeof(__m128i);
const auto src_end_sse = src_end - (src_end - src) % bytes_sse;
const auto v_not_case_lower_bound = _mm_set1_epi8(not_case_lower_bound - 1);
const auto v_not_case_upper_bound = _mm_set1_epi8(not_case_upper_bound + 1);
const auto v_flip_case_mask = _mm_set1_epi8(flip_case_mask);
for (; src < src_end_sse; src += bytes_sse, dst += bytes_sse)
{
/// load 16 sequential 8-bit characters
const auto chars = _mm_loadu_si128(reinterpret_cast<const __m128i *>(src));
/// find which 8-bit sequences belong to range [case_lower_bound, case_upper_bound]
const auto is_not_case
= _mm_and_si128(_mm_cmpgt_epi8(chars, v_not_case_lower_bound), _mm_cmplt_epi8(chars, v_not_case_upper_bound));
/// keep `flip_case_mask` only where necessary, zero out elsewhere
const auto xor_mask = _mm_and_si128(v_flip_case_mask, is_not_case);
/// flip case by applying calculated mask
const auto cased_chars = _mm_xor_si128(chars, xor_mask);
/// store result back to destination
_mm_storeu_si128(reinterpret_cast<__m128i *>(dst), cased_chars);
}
#endif
for (; src < src_end; ++src, ++dst)
if (*src >= not_case_lower_bound && *src <= not_case_upper_bound)
*dst = *src ^ flip_case_mask;
else
*dst = *src;
}
};
}

View File

@ -0,0 +1,229 @@
#include <Columns/ColumnString.h>
#include <Poco/UTF8Encoding.h>
#if __SSE2__
#include <emmintrin.h>
#endif
namespace DB
{
namespace
{
/// xor or do nothing
template <bool>
UInt8 xor_or_identity(const UInt8 c, const int mask)
{
return c ^ mask;
}
template <>
inline UInt8 xor_or_identity<false>(const UInt8 c, const int)
{
return c;
}
/// It is caller's responsibility to ensure the presence of a valid cyrillic sequence in array
template <bool to_lower>
inline void UTF8CyrillicToCase(const UInt8 *& src, UInt8 *& dst)
{
if (src[0] == 0xD0u && (src[1] >= 0x80u && src[1] <= 0x8Fu))
{
/// ЀЁЂЃЄЅІЇЈЉЊЋЌЍЎЏ
*dst++ = xor_or_identity<to_lower>(*src++, 0x1);
*dst++ = xor_or_identity<to_lower>(*src++, 0x10);
}
else if (src[0] == 0xD1u && (src[1] >= 0x90u && src[1] <= 0x9Fu))
{
/// ѐёђѓєѕіїјљњћќѝўџ
*dst++ = xor_or_identity<!to_lower>(*src++, 0x1);
*dst++ = xor_or_identity<!to_lower>(*src++, 0x10);
}
else if (src[0] == 0xD0u && (src[1] >= 0x90u && src[1] <= 0x9Fu))
{
/// А
*dst++ = *src++;
*dst++ = xor_or_identity<to_lower>(*src++, 0x20);
}
else if (src[0] == 0xD0u && (src[1] >= 0xB0u && src[1] <= 0xBFu))
{
/// а-п
*dst++ = *src++;
*dst++ = xor_or_identity<!to_lower>(*src++, 0x20);
}
else if (src[0] == 0xD0u && (src[1] >= 0xA0u && src[1] <= 0xAFu))
{
/// Р
*dst++ = xor_or_identity<to_lower>(*src++, 0x1);
*dst++ = xor_or_identity<to_lower>(*src++, 0x20);
}
else if (src[0] == 0xD1u && (src[1] >= 0x80u && src[1] <= 0x8Fu))
{
/// р
*dst++ = xor_or_identity<!to_lower>(*src++, 0x1);
*dst++ = xor_or_identity<!to_lower>(*src++, 0x20);
}
}
}
/** If the string contains UTF-8 encoded text, convert it to the lower (upper) case.
* Note: It is assumed that after the character is converted to another case,
* the length of its multibyte sequence in UTF-8 does not change.
* Otherwise, the behavior is undefined.
*/
template <char not_case_lower_bound,
char not_case_upper_bound,
int to_case(int),
void cyrillic_to_case(const UInt8 *&, UInt8 *&)>
struct LowerUpperUTF8Impl
{
static void vector(
const ColumnString::Chars_t & data,
const ColumnString::Offsets & offsets,
ColumnString::Chars_t & res_data,
ColumnString::Offsets & res_offsets)
{
res_data.resize(data.size());
res_offsets.assign(offsets);
array(data.data(), data.data() + data.size(), res_data.data());
}
static void vector_fixed(const ColumnString::Chars_t & data, size_t /*n*/, ColumnString::Chars_t & res_data)
{
res_data.resize(data.size());
array(data.data(), data.data() + data.size(), res_data.data());
}
static void constant(const std::string & data, std::string & res_data)
{
res_data.resize(data.size());
array(reinterpret_cast<const UInt8 *>(data.data()),
reinterpret_cast<const UInt8 *>(data.data() + data.size()),
reinterpret_cast<UInt8 *>(res_data.data()));
}
/** Converts a single code point starting at `src` to desired case, storing result starting at `dst`.
* `src` and `dst` are incremented by corresponding sequence lengths. */
static void toCase(const UInt8 *& src, const UInt8 * src_end, UInt8 *& dst)
{
if (src[0] <= ascii_upper_bound)
{
if (*src >= not_case_lower_bound && *src <= not_case_upper_bound)
*dst++ = *src++ ^ flip_case_mask;
else
*dst++ = *src++;
}
else if (src + 1 < src_end
&& ((src[0] == 0xD0u && (src[1] >= 0x80u && src[1] <= 0xBFu)) || (src[0] == 0xD1u && (src[1] >= 0x80u && src[1] <= 0x9Fu))))
{
cyrillic_to_case(src, dst);
}
else if (src + 1 < src_end && src[0] == 0xC2u)
{
/// Punctuation U+0080 - U+00BF, UTF-8: C2 80 - C2 BF
*dst++ = *src++;
*dst++ = *src++;
}
else if (src + 2 < src_end && src[0] == 0xE2u)
{
/// Characters U+2000 - U+2FFF, UTF-8: E2 80 80 - E2 BF BF
*dst++ = *src++;
*dst++ = *src++;
*dst++ = *src++;
}
else
{
static const Poco::UTF8Encoding utf8;
if (const auto chars = utf8.convert(to_case(utf8.convert(src)), dst, src_end - src))
{
src += chars;
dst += chars;
}
else
{
++src;
++dst;
}
}
}
private:
static constexpr auto ascii_upper_bound = '\x7f';
static constexpr auto flip_case_mask = 'A' ^ 'a';
static void array(const UInt8 * src, const UInt8 * src_end, UInt8 * dst)
{
#if __SSE2__
const auto bytes_sse = sizeof(__m128i);
auto src_end_sse = src + (src_end - src) / bytes_sse * bytes_sse;
/// SSE2 packed comparison operate on signed types, hence compare (c < 0) instead of (c > 0x7f)
const auto v_zero = _mm_setzero_si128();
const auto v_not_case_lower_bound = _mm_set1_epi8(not_case_lower_bound - 1);
const auto v_not_case_upper_bound = _mm_set1_epi8(not_case_upper_bound + 1);
const auto v_flip_case_mask = _mm_set1_epi8(flip_case_mask);
while (src < src_end_sse)
{
const auto chars = _mm_loadu_si128(reinterpret_cast<const __m128i *>(src));
/// check for ASCII
const auto is_not_ascii = _mm_cmplt_epi8(chars, v_zero);
const auto mask_is_not_ascii = _mm_movemask_epi8(is_not_ascii);
/// ASCII
if (mask_is_not_ascii == 0)
{
const auto is_not_case
= _mm_and_si128(_mm_cmpgt_epi8(chars, v_not_case_lower_bound), _mm_cmplt_epi8(chars, v_not_case_upper_bound));
const auto mask_is_not_case = _mm_movemask_epi8(is_not_case);
/// everything in correct case ASCII
if (mask_is_not_case == 0)
_mm_storeu_si128(reinterpret_cast<__m128i *>(dst), chars);
else
{
/// ASCII in mixed case
/// keep `flip_case_mask` only where necessary, zero out elsewhere
const auto xor_mask = _mm_and_si128(v_flip_case_mask, is_not_case);
/// flip case by applying calculated mask
const auto cased_chars = _mm_xor_si128(chars, xor_mask);
/// store result back to destination
_mm_storeu_si128(reinterpret_cast<__m128i *>(dst), cased_chars);
}
src += bytes_sse;
dst += bytes_sse;
}
else
{
/// UTF-8
const auto expected_end = src + bytes_sse;
while (src < expected_end)
toCase(src, src_end, dst);
/// adjust src_end_sse by pushing it forward or backward
const auto diff = src - expected_end;
if (diff != 0)
{
if (src_end_sse + diff < src_end)
src_end_sse += diff;
else
src_end_sse -= bytes_sse - diff;
}
}
}
#endif
/// handle remaining symbols
while (src < src_end)
toCase(src, src_end, dst);
}
};
}

View File

@ -0,0 +1,116 @@
#include <Columns/ColumnString.h>
#include <DataTypes/DataTypeString.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include <ext/range.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int BAD_ARGUMENTS;
}
class FunctionAppendTrailingCharIfAbsent : public IFunction
{
public:
static constexpr auto name = "appendTrailingCharIfAbsent";
static FunctionPtr create(const Context &)
{
return std::make_shared<FunctionAppendTrailingCharIfAbsent>();
}
String getName() const override
{
return name;
}
private:
size_t getNumberOfArguments() const override
{
return 2;
}
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!isString(arguments[0]))
throw Exception{"Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!isString(arguments[1]))
throw Exception{"Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
return std::make_shared<DataTypeString>();
}
bool useDefaultImplementationForConstants() const override { return true; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; }
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
{
const auto & column = block.getByPosition(arguments[0]).column;
const auto & column_char = block.getByPosition(arguments[1]).column;
if (!checkColumnConst<ColumnString>(column_char.get()))
throw Exception{"Second argument of function " + getName() + " must be a constant string", ErrorCodes::ILLEGAL_COLUMN};
String trailing_char_str = static_cast<const ColumnConst &>(*column_char).getValue<String>();
if (trailing_char_str.size() != 1)
throw Exception{"Second argument of function " + getName() + " must be a one-character string", ErrorCodes::BAD_ARGUMENTS};
if (const auto col = checkAndGetColumn<ColumnString>(column.get()))
{
auto col_res = ColumnString::create();
const auto & src_data = col->getChars();
const auto & src_offsets = col->getOffsets();
auto & dst_data = col_res->getChars();
auto & dst_offsets = col_res->getOffsets();
const auto size = src_offsets.size();
dst_data.resize(src_data.size() + size);
dst_offsets.resize(size);
ColumnString::Offset src_offset{};
ColumnString::Offset dst_offset{};
for (const auto i : ext::range(0, size))
{
const auto src_length = src_offsets[i] - src_offset;
memcpySmallAllowReadWriteOverflow15(&dst_data[dst_offset], &src_data[src_offset], src_length);
src_offset = src_offsets[i];
dst_offset += src_length;
if (src_length > 1 && dst_data[dst_offset - 2] != trailing_char_str.front())
{
dst_data[dst_offset - 1] = trailing_char_str.front();
dst_data[dst_offset] = 0;
++dst_offset;
}
dst_offsets[i] = dst_offset;
}
dst_data.resize_assume_reserved(dst_offset);
block.getByPosition(result).column = std::move(col_res);
}
else
throw Exception{"Illegal column " + block.getByPosition(arguments[0]).column->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN};
}
};
void registerFunctionAppendTrailingCharIfAbsent(FunctionFactory & factory)
{
factory.registerFunction<FunctionAppendTrailingCharIfAbsent>();
}
}

View File

@ -0,0 +1,119 @@
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/getLeastSupertype.h>
#include <Columns/ColumnArray.h>
#include <Interpreters/castColumn.h>
namespace DB
{
/// array(c1, c2, ...) - create an array.
class FunctionArray : public IFunction
{
public:
static constexpr auto name = "array";
static FunctionPtr create(const Context & context)
{
return std::make_shared<FunctionArray>(context);
}
FunctionArray(const Context & context)
: context(context)
{
}
bool useDefaultImplementationForNulls() const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
return std::make_shared<DataTypeArray>(getLeastSupertype(arguments));
}
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
{
size_t num_elements = arguments.size();
if (num_elements == 0)
{
/// We should return constant empty array.
block.getByPosition(result).column = block.getByPosition(result).type->createColumnConstWithDefaultValue(input_rows_count);
return;
}
const DataTypePtr & return_type = block.getByPosition(result).type;
const DataTypePtr & elem_type = static_cast<const DataTypeArray &>(*return_type).getNestedType();
size_t block_size = input_rows_count;
/** If part of columns have not same type as common type of all elements of array,
* then convert them to common type.
* If part of columns are constants,
* then convert them to full columns.
*/
Columns columns_holder(num_elements);
const IColumn * columns[num_elements];
for (size_t i = 0; i < num_elements; ++i)
{
const auto & arg = block.getByPosition(arguments[i]);
ColumnPtr preprocessed_column = arg.column;
if (!arg.type->equals(*elem_type))
preprocessed_column = castColumn(arg, elem_type, context);
if (ColumnPtr materialized_column = preprocessed_column->convertToFullColumnIfConst())
preprocessed_column = materialized_column;
columns_holder[i] = std::move(preprocessed_column);
columns[i] = columns_holder[i].get();
}
/// Create and fill the result array.
auto out = ColumnArray::create(elem_type->createColumn());
IColumn & out_data = out->getData();
IColumn::Offsets & out_offsets = out->getOffsets();
out_data.reserve(block_size * num_elements);
out_offsets.resize(block_size);
IColumn::Offset current_offset = 0;
for (size_t i = 0; i < block_size; ++i)
{
for (size_t j = 0; j < num_elements; ++j)
out_data.insertFrom(*columns[j], i);
current_offset += num_elements;
out_offsets[i] = current_offset;
}
block.getByPosition(result).column = std::move(out);
}
private:
String getName() const override
{
return name;
}
bool addField(DataTypePtr type_res, const Field & f, Array & arr) const;
private:
const Context & context;
};
void registerFunctionArray(FunctionFactory & factory)
{
factory.registerFunction<FunctionArray>();
}
}

View File

@ -0,0 +1,117 @@
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <Functions/GatherUtils/GatherUtils.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/getLeastSupertype.h>
#include <Interpreters/castColumn.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnConst.h>
#include <Common/typeid_cast.h>
#include <ext/range.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
/// arrayConcat(arr1, ...) - concatenate arrays.
class FunctionArrayConcat : public IFunction
{
public:
static constexpr auto name = "arrayConcat";
static FunctionPtr create(const Context & context) { return std::make_shared<FunctionArrayConcat>(context); }
FunctionArrayConcat(const Context & context) : context(context) {}
String getName() const override { return name; }
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (arguments.empty())
throw Exception{"Function " + getName() + " requires at least one argument.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
for (auto i : ext::range(0, arguments.size()))
{
auto array_type = typeid_cast<const DataTypeArray *>(arguments[i].get());
if (!array_type)
throw Exception("Argument " + std::to_string(i) + " for function " + getName() + " must be an array but it has type "
+ arguments[i]->getName() + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
return getLeastSupertype(arguments);
}
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
{
const DataTypePtr & return_type = block.getByPosition(result).type;
if (return_type->onlyNull())
{
block.getByPosition(result).column = return_type->createColumnConstWithDefaultValue(input_rows_count);
return;
}
auto result_column = return_type->createColumn();
size_t rows = input_rows_count;
size_t num_args = arguments.size();
Columns preprocessed_columns(num_args);
for (size_t i = 0; i < num_args; ++i)
{
const ColumnWithTypeAndName & arg = block.getByPosition(arguments[i]);
ColumnPtr preprocessed_column = arg.column;
if (!arg.type->equals(*return_type))
preprocessed_column = castColumn(arg, return_type, context);
preprocessed_columns[i] = std::move(preprocessed_column);
}
std::vector<std::unique_ptr<GatherUtils::IArraySource>> sources;
for (auto & argument_column : preprocessed_columns)
{
bool is_const = false;
if (auto argument_column_const = typeid_cast<const ColumnConst *>(argument_column.get()))
{
is_const = true;
argument_column = argument_column_const->getDataColumnPtr();
}
if (auto argument_column_array = typeid_cast<const ColumnArray *>(argument_column.get()))
sources.emplace_back(GatherUtils::createArraySource(*argument_column_array, is_const, rows));
else
throw Exception{"Arguments for function " + getName() + " must be arrays.", ErrorCodes::LOGICAL_ERROR};
}
auto sink = GatherUtils::createArraySink(typeid_cast<ColumnArray &>(*result_column), rows);
GatherUtils::concat(sources, *sink);
block.getByPosition(result).column = std::move(result_column);
}
bool useDefaultImplementationForConstants() const override { return true; }
private:
const Context & context;
};
void registerFunctionArrayConcat(FunctionFactory & factory)
{
factory.registerFunction<FunctionArrayConcat>();
}
}

View File

@ -0,0 +1,302 @@
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeNullable.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnString.h>
#include <Common/HashTable/ClearableHashSet.h>
#include <Common/SipHash.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
/// Find different elements in an array.
class FunctionArrayDistinct : public IFunction
{
public:
static constexpr auto name = "arrayDistinct";
static FunctionPtr create(const Context &)
{
return std::make_shared<FunctionArrayDistinct>();
}
String getName() const override
{
return name;
}
bool isVariadic() const override { return false; }
size_t getNumberOfArguments() const override { return 1; }
bool useDefaultImplementationForConstants() const override { return true; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(arguments[0].get());
if (!array_type)
throw Exception("Argument for function " + getName() + " must be array but it "
" has type " + arguments[0]->getName() + ".",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
auto nested_type = removeNullable(array_type->getNestedType());
return std::make_shared<DataTypeArray>(nested_type);
}
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override;
private:
/// Initially allocate a piece of memory for 512 elements. NOTE: This is just a guess.
static constexpr size_t INITIAL_SIZE_DEGREE = 9;
template <typename T>
bool executeNumber(
const IColumn & src_data,
const ColumnArray::Offsets & src_offsets,
IColumn & res_data_col,
ColumnArray::Offsets & res_offsets,
const ColumnNullable * nullable_col);
bool executeString(
const IColumn & src_data,
const ColumnArray::Offsets & src_offsets,
IColumn & res_data_col,
ColumnArray::Offsets & res_offsets,
const ColumnNullable * nullable_col);
void executeHashed(
const IColumn & src_data,
const ColumnArray::Offsets & src_offsets,
IColumn & res_data_col,
ColumnArray::Offsets & res_offsets,
const ColumnNullable * nullable_col);
};
void FunctionArrayDistinct::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/)
{
ColumnPtr array_ptr = block.getByPosition(arguments[0]).column;
const ColumnArray * array = checkAndGetColumn<ColumnArray>(array_ptr.get());
const auto & return_type = block.getByPosition(result).type;
auto res_ptr = return_type->createColumn();
ColumnArray & res = static_cast<ColumnArray &>(*res_ptr);
const IColumn & src_data = array->getData();
const ColumnArray::Offsets & offsets = array->getOffsets();
IColumn & res_data = res.getData();
ColumnArray::Offsets & res_offsets = res.getOffsets();
const ColumnNullable * nullable_col = nullptr;
const IColumn * inner_col;
if (src_data.isColumnNullable())
{
nullable_col = static_cast<const ColumnNullable *>(&src_data);
inner_col = &nullable_col->getNestedColumn();
}
else
{
inner_col = &src_data;
}
if (!(executeNumber<UInt8>(*inner_col, offsets, res_data, res_offsets, nullable_col)
|| executeNumber<UInt16>(*inner_col, offsets, res_data, res_offsets, nullable_col)
|| executeNumber<UInt32>(*inner_col, offsets, res_data, res_offsets, nullable_col)
|| executeNumber<UInt64>(*inner_col, offsets, res_data, res_offsets, nullable_col)
|| executeNumber<Int8>(*inner_col, offsets, res_data, res_offsets, nullable_col)
|| executeNumber<Int16>(*inner_col, offsets, res_data, res_offsets, nullable_col)
|| executeNumber<Int32>(*inner_col, offsets, res_data, res_offsets, nullable_col)
|| executeNumber<Int64>(*inner_col, offsets, res_data, res_offsets, nullable_col)
|| executeNumber<Float32>(*inner_col, offsets, res_data, res_offsets, nullable_col)
|| executeNumber<Float64>(*inner_col, offsets, res_data, res_offsets, nullable_col)
|| executeString(*inner_col, offsets, res_data, res_offsets, nullable_col)))
executeHashed(*inner_col, offsets, res_data, res_offsets, nullable_col);
block.getByPosition(result).column = std::move(res_ptr);
}
template <typename T>
bool FunctionArrayDistinct::executeNumber(
const IColumn & src_data,
const ColumnArray::Offsets & src_offsets,
IColumn & res_data_col,
ColumnArray::Offsets & res_offsets,
const ColumnNullable * nullable_col)
{
const ColumnVector<T> * src_data_concrete = checkAndGetColumn<ColumnVector<T>>(&src_data);
if (!src_data_concrete)
{
return false;
}
const PaddedPODArray<T> & values = src_data_concrete->getData();
PaddedPODArray<T> & res_data = typeid_cast<ColumnVector<T> &>(res_data_col).getData();
const PaddedPODArray<UInt8> * src_null_map = nullptr;
if (nullable_col)
src_null_map = &static_cast<const ColumnUInt8 *>(&nullable_col->getNullMapColumn())->getData();
using Set = ClearableHashSet<T,
DefaultHash<T>,
HashTableGrower<INITIAL_SIZE_DEGREE>,
HashTableAllocatorWithStackMemory<(1ULL << INITIAL_SIZE_DEGREE) * sizeof(T)>>;
Set set;
ColumnArray::Offset prev_src_offset = 0;
ColumnArray::Offset res_offset = 0;
for (ColumnArray::Offset i = 0; i < src_offsets.size(); ++i)
{
set.clear();
ColumnArray::Offset curr_src_offset = src_offsets[i];
for (ColumnArray::Offset j = prev_src_offset; j < curr_src_offset; ++j)
{
if (nullable_col && (*src_null_map)[j])
continue;
if (set.find(values[j]) == set.end())
{
res_data.emplace_back(values[j]);
set.insert(values[j]);
}
}
res_offset += set.size();
res_offsets.emplace_back(res_offset);
prev_src_offset = curr_src_offset;
}
return true;
}
bool FunctionArrayDistinct::executeString(
const IColumn & src_data,
const ColumnArray::Offsets & src_offsets,
IColumn & res_data_col,
ColumnArray::Offsets & res_offsets,
const ColumnNullable * nullable_col)
{
const ColumnString * src_data_concrete = checkAndGetColumn<ColumnString>(&src_data);
if (!src_data_concrete)
return false;
ColumnString & res_data_column_string = typeid_cast<ColumnString &>(res_data_col);
using Set = ClearableHashSet<StringRef,
StringRefHash,
HashTableGrower<INITIAL_SIZE_DEGREE>,
HashTableAllocatorWithStackMemory<(1ULL << INITIAL_SIZE_DEGREE) * sizeof(StringRef)>>;
const PaddedPODArray<UInt8> * src_null_map = nullptr;
if (nullable_col)
src_null_map = &static_cast<const ColumnUInt8 *>(&nullable_col->getNullMapColumn())->getData();
Set set;
ColumnArray::Offset prev_src_offset = 0;
ColumnArray::Offset res_offset = 0;
for (ColumnArray::Offset i = 0; i < src_offsets.size(); ++i)
{
set.clear();
ColumnArray::Offset curr_src_offset = src_offsets[i];
for (ColumnArray::Offset j = prev_src_offset; j < curr_src_offset; ++j)
{
if (nullable_col && (*src_null_map)[j])
continue;
StringRef str_ref = src_data_concrete->getDataAt(j);
if (set.find(str_ref) == set.end())
{
set.insert(str_ref);
res_data_column_string.insertData(str_ref.data, str_ref.size);
}
}
res_offset += set.size();
res_offsets.emplace_back(res_offset);
prev_src_offset = curr_src_offset;
}
return true;
}
void FunctionArrayDistinct::executeHashed(
const IColumn & src_data,
const ColumnArray::Offsets & src_offsets,
IColumn & res_data_col,
ColumnArray::Offsets & res_offsets,
const ColumnNullable * nullable_col)
{
using Set = ClearableHashSet<UInt128, UInt128TrivialHash, HashTableGrower<INITIAL_SIZE_DEGREE>,
HashTableAllocatorWithStackMemory<(1ULL << INITIAL_SIZE_DEGREE) * sizeof(UInt128)>>;
const PaddedPODArray<UInt8> * src_null_map = nullptr;
if (nullable_col)
src_null_map = &static_cast<const ColumnUInt8 *>(&nullable_col->getNullMapColumn())->getData();
Set set;
ColumnArray::Offset prev_src_offset = 0;
ColumnArray::Offset res_offset = 0;
for (ColumnArray::Offset i = 0; i < src_offsets.size(); ++i)
{
set.clear();
ColumnArray::Offset curr_src_offset = src_offsets[i];
for (ColumnArray::Offset j = prev_src_offset; j < curr_src_offset; ++j)
{
if (nullable_col && (*src_null_map)[j])
continue;
UInt128 hash;
SipHash hash_function;
src_data.updateHashWithValue(j, hash_function);
hash_function.get128(reinterpret_cast<char *>(&hash));
if (set.find(hash) == set.end())
{
set.insert(hash);
res_data_col.insertFrom(src_data, j);
}
}
res_offset += set.size();
res_offsets.emplace_back(res_offset);
prev_src_offset = curr_src_offset;
}
}
void registerFunctionArrayDistinct(FunctionFactory & factory)
{
factory.registerFunction<FunctionArrayDistinct>();
}
}

View File

@ -0,0 +1,886 @@
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeTuple.h>
#include <Core/ColumnNumbers.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnTuple.h>
#include <Common/typeid_cast.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int ILLEGAL_COLUMN;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ZERO_ARRAY_OR_TUPLE_INDEX;
}
namespace ArrayImpl
{
class NullMapBuilder;
}
/** arrayElement(arr, i) - get the array element by index. If index is not constant and out of range - return default value of data type.
* The index begins with 1. Also, the index can be negative - then it is counted from the end of the array.
*/
class FunctionArrayElement : public IFunction
{
public:
static constexpr auto name = "arrayElement";
static FunctionPtr create(const Context & context);
String getName() const override;
bool useDefaultImplementationForConstants() const override { return true; }
size_t getNumberOfArguments() const override { return 2; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override;
private:
void perform(Block & block, const ColumnNumbers & arguments, size_t result,
ArrayImpl::NullMapBuilder & builder, size_t input_rows_count);
template <typename DataType>
bool executeNumberConst(Block & block, const ColumnNumbers & arguments, size_t result, const Field & index,
ArrayImpl::NullMapBuilder & builder);
template <typename IndexType, typename DataType>
bool executeNumber(Block & block, const ColumnNumbers & arguments, size_t result, const PaddedPODArray<IndexType> & indices,
ArrayImpl::NullMapBuilder & builder);
bool executeStringConst(Block & block, const ColumnNumbers & arguments, size_t result, const Field & index,
ArrayImpl::NullMapBuilder & builder);
template <typename IndexType>
bool executeString(Block & block, const ColumnNumbers & arguments, size_t result, const PaddedPODArray<IndexType> & indices,
ArrayImpl::NullMapBuilder & builder);
bool executeGenericConst(Block & block, const ColumnNumbers & arguments, size_t result, const Field & index,
ArrayImpl::NullMapBuilder & builder);
template <typename IndexType>
bool executeGeneric(Block & block, const ColumnNumbers & arguments, size_t result, const PaddedPODArray<IndexType> & indices,
ArrayImpl::NullMapBuilder & builder);
template <typename IndexType>
bool executeConst(Block & block, const ColumnNumbers & arguments, size_t result,
const PaddedPODArray <IndexType> & indices, ArrayImpl::NullMapBuilder & builder,
size_t input_rows_count);
template <typename IndexType>
bool executeArgument(Block & block, const ColumnNumbers & arguments, size_t result,
ArrayImpl::NullMapBuilder & builder, size_t input_rows_count);
/** For a tuple array, the function is evaluated component-wise for each element of the tuple.
*/
bool executeTuple(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count);
};
namespace ArrayImpl
{
class NullMapBuilder
{
public:
operator bool() const { return src_null_map; }
bool operator!() const { return !src_null_map; }
void initSource(const UInt8 * src_null_map_)
{
src_null_map = src_null_map_;
}
void initSink(size_t size)
{
auto sink = ColumnUInt8::create(size);
sink_null_map = sink->getData().data();
sink_null_map_holder = std::move(sink);
}
void update(size_t from)
{
sink_null_map[index] = bool(src_null_map && src_null_map[from]);
++index;
}
void update()
{
sink_null_map[index] = bool(src_null_map);
++index;
}
ColumnPtr getNullMapColumnPtr() && { return std::move(sink_null_map_holder); }
private:
const UInt8 * src_null_map = nullptr;
UInt8 * sink_null_map = nullptr;
MutableColumnPtr sink_null_map_holder;
size_t index = 0;
};
}
namespace
{
template <typename T>
struct ArrayElementNumImpl
{
/** Implementation for constant index.
* If negative = false - index is from beginning of array, started from 0.
* If negative = true - index is from end of array, started from 0.
*/
template <bool negative>
static void vectorConst(
const PaddedPODArray<T> & data, const ColumnArray::Offsets & offsets,
const ColumnArray::Offset index,
PaddedPODArray<T> & result, ArrayImpl::NullMapBuilder & builder)
{
size_t size = offsets.size();
result.resize(size);
ColumnArray::Offset current_offset = 0;
for (size_t i = 0; i < size; ++i)
{
size_t array_size = offsets[i] - current_offset;
if (index < array_size)
{
size_t j = !negative ? (current_offset + index) : (offsets[i] - index - 1);
result[i] = data[j];
if (builder)
builder.update(j);
}
else
{
result[i] = T();
if (builder)
builder.update();
}
current_offset = offsets[i];
}
}
/** Implementation for non-constant index.
*/
template <typename TIndex>
static void vector(
const PaddedPODArray<T> & data, const ColumnArray::Offsets & offsets,
const PaddedPODArray<TIndex> & indices,
PaddedPODArray<T> & result, ArrayImpl::NullMapBuilder & builder)
{
size_t size = offsets.size();
result.resize(size);
ColumnArray::Offset current_offset = 0;
for (size_t i = 0; i < size; ++i)
{
size_t array_size = offsets[i] - current_offset;
TIndex index = indices[i];
if (index > 0 && static_cast<size_t>(index) <= array_size)
{
size_t j = current_offset + index - 1;
result[i] = data[j];
if (builder)
builder.update(j);
}
else if (index < 0 && static_cast<size_t>(-index) <= array_size)
{
size_t j = offsets[i] + index;
result[i] = data[j];
if (builder)
builder.update(j);
}
else
{
result[i] = T();
if (builder)
builder.update();
}
current_offset = offsets[i];
}
}
};
struct ArrayElementStringImpl
{
template <bool negative>
static void vectorConst(
const ColumnString::Chars_t & data, const ColumnArray::Offsets & offsets, const ColumnString::Offsets & string_offsets,
const ColumnArray::Offset index,
ColumnString::Chars_t & result_data, ColumnArray::Offsets & result_offsets,
ArrayImpl::NullMapBuilder & builder)
{
size_t size = offsets.size();
result_offsets.resize(size);
result_data.reserve(data.size());
ColumnArray::Offset current_offset = 0;
ColumnArray::Offset current_result_offset = 0;
for (size_t i = 0; i < size; ++i)
{
size_t array_size = offsets[i] - current_offset;
if (index < array_size)
{
size_t adjusted_index = !negative ? index : (array_size - index - 1);
size_t j = current_offset + adjusted_index;
if (builder)
builder.update(j);
ColumnArray::Offset string_pos = current_offset == 0 && adjusted_index == 0
? 0
: string_offsets[current_offset + adjusted_index - 1];
ColumnArray::Offset string_size = string_offsets[current_offset + adjusted_index] - string_pos;
result_data.resize(current_result_offset + string_size);
memcpySmallAllowReadWriteOverflow15(&result_data[current_result_offset], &data[string_pos], string_size);
current_result_offset += string_size;
result_offsets[i] = current_result_offset;
}
else
{
/// Insert an empty row.
result_data.resize(current_result_offset + 1);
result_data[current_result_offset] = 0;
current_result_offset += 1;
result_offsets[i] = current_result_offset;
if (builder)
builder.update();
}
current_offset = offsets[i];
}
}
/** Implementation for non-constant index.
*/
template <typename TIndex>
static void vector(
const ColumnString::Chars_t & data, const ColumnArray::Offsets & offsets, const ColumnString::Offsets & string_offsets,
const PaddedPODArray<TIndex> & indices,
ColumnString::Chars_t & result_data, ColumnArray::Offsets & result_offsets,
ArrayImpl::NullMapBuilder & builder)
{
size_t size = offsets.size();
result_offsets.resize(size);
result_data.reserve(data.size());
ColumnArray::Offset current_offset = 0;
ColumnArray::Offset current_result_offset = 0;
for (size_t i = 0; i < size; ++i)
{
size_t array_size = offsets[i] - current_offset;
size_t adjusted_index; /// index in array from zero
TIndex index = indices[i];
if (index > 0 && static_cast<size_t>(index) <= array_size)
adjusted_index = index - 1;
else if (index < 0 && static_cast<size_t>(-index) <= array_size)
adjusted_index = array_size + index;
else
adjusted_index = array_size; /// means no element should be taken
if (adjusted_index < array_size)
{
size_t j = current_offset + adjusted_index;
if (builder)
builder.update(j);
ColumnArray::Offset string_pos = current_offset == 0 && adjusted_index == 0
? 0
: string_offsets[current_offset + adjusted_index - 1];
ColumnArray::Offset string_size = string_offsets[current_offset + adjusted_index] - string_pos;
result_data.resize(current_result_offset + string_size);
memcpySmallAllowReadWriteOverflow15(&result_data[current_result_offset], &data[string_pos], string_size);
current_result_offset += string_size;
result_offsets[i] = current_result_offset;
}
else
{
/// Insert empty string
result_data.resize(current_result_offset + 1);
result_data[current_result_offset] = 0;
current_result_offset += 1;
result_offsets[i] = current_result_offset;
if (builder)
builder.update();
}
current_offset = offsets[i];
}
}
};
/// Generic implementation for other nested types.
struct ArrayElementGenericImpl
{
template <bool negative>
static void vectorConst(
const IColumn & data, const ColumnArray::Offsets & offsets,
const ColumnArray::Offset index,
IColumn & result, ArrayImpl::NullMapBuilder & builder)
{
size_t size = offsets.size();
result.reserve(size);
ColumnArray::Offset current_offset = 0;
for (size_t i = 0; i < size; ++i)
{
size_t array_size = offsets[i] - current_offset;
if (index < array_size)
{
size_t j = !negative ? current_offset + index : offsets[i] - index - 1;
result.insertFrom(data, j);
if (builder)
builder.update(j);
}
else
{
result.insertDefault();
if (builder)
builder.update();
}
current_offset = offsets[i];
}
}
/** Implementation for non-constant index.
*/
template <typename TIndex>
static void vector(
const IColumn & data, const ColumnArray::Offsets & offsets,
const PaddedPODArray<TIndex> & indices,
IColumn & result, ArrayImpl::NullMapBuilder & builder)
{
size_t size = offsets.size();
result.reserve(size);
ColumnArray::Offset current_offset = 0;
for (size_t i = 0; i < size; ++i)
{
size_t array_size = offsets[i] - current_offset;
TIndex index = indices[i];
if (index > 0 && static_cast<size_t>(index) <= array_size)
{
size_t j = current_offset + index - 1;
result.insertFrom(data, j);
if (builder)
builder.update(j);
}
else if (index < 0 && static_cast<size_t>(-index) <= array_size)
{
size_t j = offsets[i] + index;
result.insertFrom(data, j);
if (builder)
builder.update(j);
}
else
{
result.insertDefault();
if (builder)
builder.update();
}
current_offset = offsets[i];
}
}
};
}
FunctionPtr FunctionArrayElement::create(const Context &)
{
return std::make_shared<FunctionArrayElement>();
}
template <typename DataType>
bool FunctionArrayElement::executeNumberConst(Block & block, const ColumnNumbers & arguments, size_t result, const Field & index,
ArrayImpl::NullMapBuilder & builder)
{
const ColumnArray * col_array = checkAndGetColumn<ColumnArray>(block.getByPosition(arguments[0]).column.get());
if (!col_array)
return false;
const ColumnVector<DataType> * col_nested = checkAndGetColumn<ColumnVector<DataType>>(&col_array->getData());
if (!col_nested)
return false;
auto col_res = ColumnVector<DataType>::create();
if (index.getType() == Field::Types::UInt64)
ArrayElementNumImpl<DataType>::template vectorConst<false>(
col_nested->getData(), col_array->getOffsets(), safeGet<UInt64>(index) - 1, col_res->getData(), builder);
else if (index.getType() == Field::Types::Int64)
ArrayElementNumImpl<DataType>::template vectorConst<true>(
col_nested->getData(), col_array->getOffsets(), -safeGet<Int64>(index) - 1, col_res->getData(), builder);
else
throw Exception("Illegal type of array index", ErrorCodes::LOGICAL_ERROR);
block.getByPosition(result).column = std::move(col_res);
return true;
}
template <typename IndexType, typename DataType>
bool FunctionArrayElement::executeNumber(Block & block, const ColumnNumbers & arguments, size_t result, const PaddedPODArray<IndexType> & indices,
ArrayImpl::NullMapBuilder & builder)
{
const ColumnArray * col_array = checkAndGetColumn<ColumnArray>(block.getByPosition(arguments[0]).column.get());
if (!col_array)
return false;
const ColumnVector<DataType> * col_nested = checkAndGetColumn<ColumnVector<DataType>>(&col_array->getData());
if (!col_nested)
return false;
auto col_res = ColumnVector<DataType>::create();
ArrayElementNumImpl<DataType>::template vector<IndexType>(
col_nested->getData(), col_array->getOffsets(), indices, col_res->getData(), builder);
block.getByPosition(result).column = std::move(col_res);
return true;
}
bool FunctionArrayElement::executeStringConst(Block & block, const ColumnNumbers & arguments, size_t result, const Field & index,
ArrayImpl::NullMapBuilder & builder)
{
const ColumnArray * col_array = checkAndGetColumn<ColumnArray>(block.getByPosition(arguments[0]).column.get());
if (!col_array)
return false;
const ColumnString * col_nested = checkAndGetColumn<ColumnString>(&col_array->getData());
if (!col_nested)
return false;
auto col_res = ColumnString::create();
if (index.getType() == Field::Types::UInt64)
ArrayElementStringImpl::vectorConst<false>(
col_nested->getChars(),
col_array->getOffsets(),
col_nested->getOffsets(),
safeGet<UInt64>(index) - 1,
col_res->getChars(),
col_res->getOffsets(),
builder);
else if (index.getType() == Field::Types::Int64)
ArrayElementStringImpl::vectorConst<true>(
col_nested->getChars(),
col_array->getOffsets(),
col_nested->getOffsets(),
-safeGet<Int64>(index) - 1,
col_res->getChars(),
col_res->getOffsets(),
builder);
else
throw Exception("Illegal type of array index", ErrorCodes::LOGICAL_ERROR);
block.getByPosition(result).column = std::move(col_res);
return true;
}
template <typename IndexType>
bool FunctionArrayElement::executeString(Block & block, const ColumnNumbers & arguments, size_t result, const PaddedPODArray<IndexType> & indices,
ArrayImpl::NullMapBuilder & builder)
{
const ColumnArray * col_array = checkAndGetColumn<ColumnArray>(block.getByPosition(arguments[0]).column.get());
if (!col_array)
return false;
const ColumnString * col_nested = checkAndGetColumn<ColumnString>(&col_array->getData());
if (!col_nested)
return false;
auto col_res = ColumnString::create();
ArrayElementStringImpl::vector<IndexType>(
col_nested->getChars(),
col_array->getOffsets(),
col_nested->getOffsets(),
indices,
col_res->getChars(),
col_res->getOffsets(),
builder);
block.getByPosition(result).column = std::move(col_res);
return true;
}
bool FunctionArrayElement::executeGenericConst(Block & block, const ColumnNumbers & arguments, size_t result, const Field & index,
ArrayImpl::NullMapBuilder & builder)
{
const ColumnArray * col_array = checkAndGetColumn<ColumnArray>(block.getByPosition(arguments[0]).column.get());
if (!col_array)
return false;
const auto & col_nested = col_array->getData();
auto col_res = col_nested.cloneEmpty();
if (index.getType() == Field::Types::UInt64)
ArrayElementGenericImpl::vectorConst<false>(
col_nested, col_array->getOffsets(), safeGet<UInt64>(index) - 1, *col_res, builder);
else if (index.getType() == Field::Types::Int64)
ArrayElementGenericImpl::vectorConst<true>(
col_nested, col_array->getOffsets(), -safeGet<Int64>(index) - 1, *col_res, builder);
else
throw Exception("Illegal type of array index", ErrorCodes::LOGICAL_ERROR);
block.getByPosition(result).column = std::move(col_res);
return true;
}
template <typename IndexType>
bool FunctionArrayElement::executeGeneric(Block & block, const ColumnNumbers & arguments, size_t result, const PaddedPODArray<IndexType> & indices,
ArrayImpl::NullMapBuilder & builder)
{
const ColumnArray * col_array = checkAndGetColumn<ColumnArray>(block.getByPosition(arguments[0]).column.get());
if (!col_array)
return false;
const auto & col_nested = col_array->getData();
auto col_res = col_nested.cloneEmpty();
ArrayElementGenericImpl::vector<IndexType>(
col_nested, col_array->getOffsets(), indices, *col_res, builder);
block.getByPosition(result).column = std::move(col_res);
return true;
}
template <typename IndexType>
bool FunctionArrayElement::executeConst(Block & block, const ColumnNumbers & arguments, size_t result,
const PaddedPODArray <IndexType> & indices, ArrayImpl::NullMapBuilder & builder,
size_t input_rows_count)
{
const ColumnArray * col_array = checkAndGetColumnConstData<ColumnArray>(block.getByPosition(arguments[0]).column.get());
if (!col_array)
return false;
auto res = block.getByPosition(result).type->createColumn();
size_t rows = input_rows_count;
const IColumn & array_elements = col_array->getData();
size_t array_size = array_elements.size();
for (size_t i = 0; i < rows; ++i)
{
IndexType index = indices[i];
if (index > 0 && static_cast<size_t>(index) <= array_size)
{
size_t j = index - 1;
res->insertFrom(array_elements, j);
if (builder)
builder.update(j);
}
else if (index < 0 && static_cast<size_t>(-index) <= array_size)
{
size_t j = array_size + index;
res->insertFrom(array_elements, j);
if (builder)
builder.update(j);
}
else
{
res->insertDefault();
if (builder)
builder.update();
}
}
block.getByPosition(result).column = std::move(res);
return true;
}
template <typename IndexType>
bool FunctionArrayElement::executeArgument(Block & block, const ColumnNumbers & arguments, size_t result,
ArrayImpl::NullMapBuilder & builder, size_t input_rows_count)
{
auto index = checkAndGetColumn<ColumnVector<IndexType>>(block.getByPosition(arguments[1]).column.get());
if (!index)
return false;
const auto & index_data = index->getData();
if (builder)
builder.initSink(index_data.size());
if (!( executeNumber<IndexType, UInt8>(block, arguments, result, index_data, builder)
|| executeNumber<IndexType, UInt16>(block, arguments, result, index_data, builder)
|| executeNumber<IndexType, UInt32>(block, arguments, result, index_data, builder)
|| executeNumber<IndexType, UInt64>(block, arguments, result, index_data, builder)
|| executeNumber<IndexType, Int8>(block, arguments, result, index_data, builder)
|| executeNumber<IndexType, Int16>(block, arguments, result, index_data, builder)
|| executeNumber<IndexType, Int32>(block, arguments, result, index_data, builder)
|| executeNumber<IndexType, Int64>(block, arguments, result, index_data, builder)
|| executeNumber<IndexType, Float32>(block, arguments, result, index_data, builder)
|| executeNumber<IndexType, Float64>(block, arguments, result, index_data, builder)
|| executeConst<IndexType>(block, arguments, result, index_data, builder, input_rows_count)
|| executeString<IndexType>(block, arguments, result, index_data, builder)
|| executeGeneric<IndexType>(block, arguments, result, index_data, builder)))
throw Exception("Illegal column " + block.getByPosition(arguments[0]).column->getName()
+ " of first argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN);
return true;
}
bool FunctionArrayElement::executeTuple(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count)
{
const ColumnArray * col_array = typeid_cast<const ColumnArray *>(block.getByPosition(arguments[0]).column.get());
if (!col_array)
return false;
const ColumnTuple * col_nested = typeid_cast<const ColumnTuple *>(&col_array->getData());
if (!col_nested)
return false;
const Columns & tuple_columns = col_nested->getColumns();
size_t tuple_size = tuple_columns.size();
const DataTypes & tuple_types = typeid_cast<const DataTypeTuple &>(
*typeid_cast<const DataTypeArray &>(*block.getByPosition(arguments[0]).type).getNestedType()).getElements();
/** We will calculate the function for the tuple of the internals of the array.
* To do this, create a temporary block.
* It will consist of the following columns
* - the index of the array to be taken;
* - an array of the first elements of the tuples;
* - the result of taking the elements by the index for an array of the first elements of the tuples;
* - array of the second elements of the tuples;
* - result of taking elements by index for an array of second elements of tuples;
* ...
*/
Block block_of_temporary_results;
block_of_temporary_results.insert(block.getByPosition(arguments[1]));
/// results of taking elements by index for arrays from each element of the tuples;
Columns result_tuple_columns;
for (size_t i = 0; i < tuple_size; ++i)
{
ColumnWithTypeAndName array_of_tuple_section;
array_of_tuple_section.column = ColumnArray::create(tuple_columns[i], col_array->getOffsetsPtr());
array_of_tuple_section.type = std::make_shared<DataTypeArray>(tuple_types[i]);
block_of_temporary_results.insert(array_of_tuple_section);
ColumnWithTypeAndName array_elements_of_tuple_section;
array_elements_of_tuple_section.type = getReturnTypeImpl(
{block_of_temporary_results.getByPosition(i * 2 + 1).type, block_of_temporary_results.getByPosition(0).type});
block_of_temporary_results.insert(array_elements_of_tuple_section);
executeImpl(block_of_temporary_results, ColumnNumbers{i * 2 + 1, 0}, i * 2 + 2, input_rows_count);
result_tuple_columns.emplace_back(std::move(block_of_temporary_results.getByPosition(i * 2 + 2).column));
}
block.getByPosition(result).column = ColumnTuple::create(result_tuple_columns);
return true;
}
String FunctionArrayElement::getName() const
{
return name;
}
DataTypePtr FunctionArrayElement::getReturnTypeImpl(const DataTypes & arguments) const
{
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(arguments[0].get());
if (!array_type)
throw Exception("First argument for function " + getName() + " must be array.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!isInteger(arguments[1]))
throw Exception("Second argument for function " + getName() + " must be integer.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return array_type->getNestedType();
}
void FunctionArrayElement::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count)
{
/// Check nullability.
bool is_array_of_nullable = false;
const ColumnArray * col_array = nullptr;
const ColumnArray * col_const_array = nullptr;
col_array = checkAndGetColumn<ColumnArray>(block.getByPosition(arguments[0]).column.get());
if (col_array)
is_array_of_nullable = col_array->getData().isColumnNullable();
else
{
col_const_array = checkAndGetColumnConstData<ColumnArray>(block.getByPosition(arguments[0]).column.get());
if (col_const_array)
is_array_of_nullable = col_const_array->getData().isColumnNullable();
else
throw Exception("Illegal column " + block.getByPosition(arguments[0]).column->getName()
+ " of first argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN);
}
if (!is_array_of_nullable)
{
ArrayImpl::NullMapBuilder builder;
perform(block, arguments, result, builder, input_rows_count);
}
else
{
/// Perform initializations.
ArrayImpl::NullMapBuilder builder;
Block source_block;
const auto & input_type = typeid_cast<const DataTypeNullable &>(*typeid_cast<const DataTypeArray &>(*block.getByPosition(arguments[0]).type).getNestedType()).getNestedType();
const auto & tmp_ret_type = typeid_cast<const DataTypeNullable &>(*block.getByPosition(result).type).getNestedType();
if (col_array)
{
const auto & nullable_col = typeid_cast<const ColumnNullable &>(col_array->getData());
const auto & nested_col = nullable_col.getNestedColumnPtr();
/// Put nested_col inside a ColumnArray.
source_block =
{
{
ColumnArray::create(nested_col, col_array->getOffsetsPtr()),
std::make_shared<DataTypeArray>(input_type),
""
},
block.getByPosition(arguments[1]),
{
nullptr,
tmp_ret_type,
""
}
};
builder.initSource(nullable_col.getNullMapData().data());
}
else
{
/// ColumnConst(ColumnArray(ColumnNullable(...)))
const auto & nullable_col = static_cast<const ColumnNullable &>(col_const_array->getData());
const auto & nested_col = nullable_col.getNestedColumnPtr();
source_block =
{
{
ColumnConst::create(ColumnArray::create(nested_col, col_const_array->getOffsetsPtr()), input_rows_count),
std::make_shared<DataTypeArray>(input_type),
""
},
block.getByPosition(arguments[1]),
{
nullptr,
tmp_ret_type,
""
}
};
builder.initSource(nullable_col.getNullMapData().data());
}
perform(source_block, {0, 1}, 2, builder, input_rows_count);
/// Store the result.
const ColumnWithTypeAndName & source_col = source_block.getByPosition(2);
ColumnWithTypeAndName & dest_col = block.getByPosition(result);
dest_col.column = ColumnNullable::create(source_col.column, builder ? std::move(builder).getNullMapColumnPtr() : ColumnUInt8::create());
}
}
void FunctionArrayElement::perform(Block & block, const ColumnNumbers & arguments, size_t result,
ArrayImpl::NullMapBuilder & builder, size_t input_rows_count)
{
if (executeTuple(block, arguments, result, input_rows_count))
{
}
else if (!block.getByPosition(arguments[1]).column->isColumnConst())
{
if (!(executeArgument<UInt8>(block, arguments, result, builder, input_rows_count)
|| executeArgument<UInt16>(block, arguments, result, builder, input_rows_count)
|| executeArgument<UInt32>(block, arguments, result, builder, input_rows_count)
|| executeArgument<UInt64>(block, arguments, result, builder, input_rows_count)
|| executeArgument<Int8>(block, arguments, result, builder, input_rows_count)
|| executeArgument<Int16>(block, arguments, result, builder, input_rows_count)
|| executeArgument<Int32>(block, arguments, result, builder, input_rows_count)
|| executeArgument<Int64>(block, arguments, result, builder, input_rows_count)))
throw Exception("Second argument for function " + getName() + " must must have UInt or Int type.",
ErrorCodes::ILLEGAL_COLUMN);
}
else
{
Field index = (*block.getByPosition(arguments[1]).column)[0];
if (builder)
builder.initSink(input_rows_count);
if (index == UInt64(0))
throw Exception("Array indices is 1-based", ErrorCodes::ZERO_ARRAY_OR_TUPLE_INDEX);
if (!( executeNumberConst<UInt8>(block, arguments, result, index, builder)
|| executeNumberConst<UInt16>(block, arguments, result, index, builder)
|| executeNumberConst<UInt32>(block, arguments, result, index, builder)
|| executeNumberConst<UInt64>(block, arguments, result, index, builder)
|| executeNumberConst<Int8>(block, arguments, result, index, builder)
|| executeNumberConst<Int16>(block, arguments, result, index, builder)
|| executeNumberConst<Int32>(block, arguments, result, index, builder)
|| executeNumberConst<Int64>(block, arguments, result, index, builder)
|| executeNumberConst<Float32>(block, arguments, result, index, builder)
|| executeNumberConst<Float64>(block, arguments, result, index, builder)
|| executeStringConst (block, arguments, result, index, builder)
|| executeGenericConst (block, arguments, result, index, builder)))
throw Exception("Illegal column " + block.getByPosition(arguments[0]).column->getName()
+ " of first argument of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN);
}
}
void registerFunctionArrayElement(FunctionFactory & factory)
{
factory.registerFunction<FunctionArrayElement>();
}
}

View File

@ -0,0 +1,85 @@
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnsNumber.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
/// arrayEnumerate(arr) - Returns the array [1,2,3,..., length(arr)]
class FunctionArrayEnumerate : public IFunction
{
public:
static constexpr auto name = "arrayEnumerate";
static FunctionPtr create(const Context &)
{
return std::make_shared<FunctionArrayEnumerate>();
}
String getName() const override
{
return name;
}
size_t getNumberOfArguments() const override { return 1; }
bool useDefaultImplementationForConstants() const override { return true; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(arguments[0].get());
if (!array_type)
throw Exception("First argument for function " + getName() + " must be an array but it has type "
+ arguments[0]->getName() + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt32>());
}
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t) override
{
if (const ColumnArray * array = checkAndGetColumn<ColumnArray>(block.getByPosition(arguments[0]).column.get()))
{
const ColumnArray::Offsets & offsets = array->getOffsets();
auto res_nested = ColumnUInt32::create();
ColumnUInt32::Container & res_values = res_nested->getData();
res_values.resize(array->getData().size());
ColumnArray::Offset prev_off = 0;
for (ColumnArray::Offset i = 0; i < offsets.size(); ++i)
{
ColumnArray::Offset off = offsets[i];
for (ColumnArray::Offset j = prev_off; j < off; ++j)
res_values[j] = j - prev_off + 1;
prev_off = off;
}
block.getByPosition(result).column = ColumnArray::create(std::move(res_nested), array->getOffsetsPtr());
}
else
{
throw Exception("Illegal column " + block.getByPosition(arguments[0]).column->getName()
+ " of first argument of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN);
}
}
};
void registerFunctionArrayEnumerate(FunctionFactory & factory)
{
factory.registerFunction<FunctionArrayEnumerate>();
}
}

View File

@ -0,0 +1,22 @@
#include <Functions/arrayEnumerateExtended.h>
#include <Functions/FunctionFactory.h>
namespace DB
{
class FunctionArrayEnumerateDense : public FunctionArrayEnumerateExtended<FunctionArrayEnumerateDense>
{
using Base = FunctionArrayEnumerateExtended<FunctionArrayEnumerateDense>;
public:
static constexpr auto name = "arrayEnumerateDense";
using Base::create;
};
void registerFunctionArrayEnumerateDense(FunctionFactory & factory)
{
factory.registerFunction<FunctionArrayEnumerateDense>();
}
}

Some files were not shown because too many files have changed in this diff Show More