Less clumsy code with interruptable code

This commit is contained in:
alesapin 2020-11-24 17:02:55 +03:00
parent c1a7e4f5fa
commit 0d52cfb1be
4 changed files with 196 additions and 55 deletions

View File

@ -708,29 +708,34 @@ void TestKeeperStorage::putCloseRequest(const Coordination::ZooKeeperRequestPtr
throw Exception("Cannot push request to queue within operation timeout", ErrorCodes::LOGICAL_ERROR);
}
TestKeeperStorage::ResponsePair TestKeeperStorage::putRequest(const Coordination::ZooKeeperRequestPtr & request, int64_t session_id)
void TestKeeperStorage::putRequest(const Coordination::ZooKeeperRequestPtr & request, int64_t session_id, ResponseCallback callback)
{
auto promise = std::make_shared<std::promise<Coordination::ZooKeeperResponsePtr>>();
auto future = promise->get_future();
TestKeeperStorageRequestPtr storage_request = TestKeeperWrapperFactory::instance().get(request);
RequestInfo request_info;
request_info.time = clock::now();
request_info.request = storage_request;
request_info.session_id = session_id;
request_info.response_callback = [promise] (const Coordination::ZooKeeperResponsePtr & response) { promise->set_value(response); };
std::optional<AsyncResponse> watch_future;
request_info.response_callback = callback;
std::lock_guard lock(push_request_mutex);
if (!requests_queue.tryPush(std::move(request_info), operation_timeout.totalMilliseconds()))
throw Exception("Cannot push request to queue within operation timeout", ErrorCodes::LOGICAL_ERROR);
}
void TestKeeperStorage::putRequest(const Coordination::ZooKeeperRequestPtr & request, int64_t session_id, ResponseCallback callback, ResponseCallback watch_callback)
{
TestKeeperStorageRequestPtr storage_request = TestKeeperWrapperFactory::instance().get(request);
RequestInfo request_info;
request_info.time = clock::now();
request_info.request = storage_request;
request_info.session_id = session_id;
request_info.response_callback = callback;
if (request->has_watch)
{
auto watch_promise = std::make_shared<std::promise<Coordination::ZooKeeperResponsePtr>>();
watch_future.emplace(watch_promise->get_future());
request_info.watch_callback = [watch_promise] (const Coordination::ZooKeeperResponsePtr & response) { watch_promise->set_value(response); };
}
request_info.watch_callback = watch_callback;
std::lock_guard lock(push_request_mutex);
if (!requests_queue.tryPush(std::move(request_info), operation_timeout.totalMilliseconds()))
throw Exception("Cannot push request to queue within operation timeout", ErrorCodes::LOGICAL_ERROR);
return ResponsePair{std::move(future), std::move(watch_future)};
}
TestKeeperStorage::~TestKeeperStorage()

View File

@ -79,7 +79,9 @@ public:
AsyncResponse response;
std::optional<AsyncResponse> watch_response;
};
ResponsePair putRequest(const Coordination::ZooKeeperRequestPtr & request, int64_t session_id);
void putRequest(const Coordination::ZooKeeperRequestPtr & request, int64_t session_id, ResponseCallback callback);
void putRequest(const Coordination::ZooKeeperRequestPtr & request, int64_t session_id, ResponseCallback callback, ResponseCallback watch_callback);
void putCloseRequest(const Coordination::ZooKeeperRequestPtr & request, int64_t session_id);
int64_t getSessionID()

View File

@ -11,12 +11,126 @@
#include <common/logger_useful.h>
#include <chrono>
#include <sys/eventfd.h>
#include <sys/epoll.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int SYSTEM_ERROR;
}
struct SocketInterruptablePollWrapper
{
int sockfd;
int efd;
enum class PollStatus
{
HAS_DATA,
TIMEOUT,
INTERRUPTED,
ERROR,
};
using InterruptCallback = std::function<void()>;
SocketInterruptablePollWrapper(const Poco::Net::StreamSocket & poco_socket_)
: sockfd(poco_socket_.impl()->sockfd())
, efd(eventfd(0, EFD_NONBLOCK))
{
if (efd < 0)
throwFromErrno("Cannot create eventfd file descriptor", ErrorCodes::SYSTEM_ERROR);
}
~SocketInterruptablePollWrapper()
{
if (efd >= 0)
{
close(efd);
efd = -1;
}
}
void interruptPoll()
{
UInt64 bytes = 1;
int ret = write(efd, &bytes, sizeof(bytes));
if (ret < 0)
throwFromErrno("Cannot write into eventfd descriptor", ErrorCodes::SYSTEM_ERROR);
}
PollStatus poll(Poco::Timespan& remainingTime)
{
int epollfd = epoll_create(2);
if (epollfd < 0)
throwFromErrno("Cannot epoll_create", ErrorCodes::SYSTEM_ERROR);
epoll_event socket_event{};
socket_event.events = EPOLLIN | EPOLLERR;
socket_event.data.fd = sockfd;
if (epoll_ctl(epollfd, EPOLL_CTL_ADD, sockfd, &socket_event) < 0)
{
::close(epollfd);
throwFromErrno("Cannot insert socket into epoll queue", ErrorCodes::SYSTEM_ERROR);
}
epoll_event efd_event{};
efd_event.events = EPOLLIN | EPOLLERR;
efd_event.data.fd = efd;
if (epoll_ctl(epollfd, EPOLL_CTL_ADD, efd, &efd_event) < 0)
{
::close(epollfd);
throwFromErrno("Cannot insert socket into epoll queue", ErrorCodes::SYSTEM_ERROR);
}
int rc;
epoll_event evout{};
do
{
Poco::Timestamp start;
rc = epoll_wait(epollfd, &evout, 1, remainingTime.totalMilliseconds());
if (rc < 0 && errno == EINTR)
{
Poco::Timestamp end;
Poco::Timespan waited = end - start;
if (waited < remainingTime)
remainingTime -= waited;
else
remainingTime = 0;
}
}
while (rc < 0 && errno == EINTR);
::close(epollfd);
if (rc < 0)
return PollStatus::ERROR;
else if (rc == 0)
return PollStatus::TIMEOUT;
else if (evout.data.fd == efd)
{
UInt64 bytes;
if (read(efd, &bytes, sizeof(bytes)) < 0)
throwFromErrno("Cannot read from eventfd", ErrorCodes::SYSTEM_ERROR);
return PollStatus::INTERRUPTED;
}
return PollStatus::HAS_DATA;
}
};
TestKeeperTCPHandler::TestKeeperTCPHandler(IServer & server_, const Poco::Net::StreamSocket & socket_)
: Poco::Net::TCPServerConnection(socket_)
, server(server_)
, log(&Poco::Logger::get("TestKeeperTCPHandler"))
, global_context(server.context())
, test_keeper_storage(global_context.getTestKeeperStorage())
, operation_timeout(0, Coordination::DEFAULT_OPERATION_TIMEOUT_MS * 1000)
, session_timeout(0, Coordination::DEFAULT_SESSION_TIMEOUT_MS * 1000)
, session_id(test_keeper_storage->getSessionID())
, poll_wrapper(std::make_unique<SocketInterruptablePollWrapper>(socket_))
{
}
void TestKeeperTCPHandler::sendHandshake()
@ -104,34 +218,11 @@ void TestKeeperTCPHandler::runImpl()
while (true)
{
using namespace std::chrono_literals;
while (!responses.empty())
{
if (responses.front().wait_for(10ms) != std::future_status::ready)
break;
auto response = responses.front().get();
response->write(*out);
responses.pop();
}
Poco::Timespan poll_wait = responses.empty() ? session_timeout.totalMicroseconds() - session_stopwatch.elapsedMicroseconds() : session_timeout;
for (auto it = watch_responses.begin(); it != watch_responses.end();)
{
if (it->wait_for(0s) == std::future_status::ready)
{
auto response = it->get();
if (response->error == Coordination::Error::ZOK)
response->write(*out);
it = watch_responses.erase(it);
}
else
{
++it;
}
}
Int64 poll_wait = responses.empty() ? session_timeout.totalMicroseconds() - session_stopwatch.elapsedMicroseconds() : 10000;
if (in->poll(poll_wait))
auto state = poll_wrapper->poll(poll_wait);
if (state == SocketInterruptablePollWrapper::PollStatus::HAS_DATA)
{
auto received_op = receiveRequest();
if (received_op == Coordination::OpNum::Close)
@ -144,6 +235,36 @@ void TestKeeperTCPHandler::runImpl()
session_stopwatch.restart();
}
}
else if (state == SocketInterruptablePollWrapper::PollStatus::INTERRUPTED)
{
while (!responses.empty())
{
if (responses.front().wait_for(0ms) != std::future_status::ready)
break;
auto response = responses.front().get();
response->write(*out);
responses.pop();
}
for (auto it = watch_responses.begin(); it != watch_responses.end();)
{
if (it->wait_for(0s) == std::future_status::ready)
{
auto response = it->get();
if (response->error == Coordination::Error::ZOK)
response->write(*out);
it = watch_responses.erase(it);
}
else
{
++it;
}
}
}
else if (state == SocketInterruptablePollWrapper::PollStatus::ERROR)
{
throw Exception("Exception happened while reading from socket", ErrorCodes::SYSTEM_ERROR);
}
if (session_stopwatch.elapsedMicroseconds() > static_cast<UInt64>(session_timeout.totalMicroseconds()))
{
@ -175,10 +296,30 @@ Coordination::OpNum TestKeeperTCPHandler::receiveRequest()
request->readImpl(*in);
if (opnum != Coordination::OpNum::Close)
{
auto request_future_responses = test_keeper_storage->putRequest(request, session_id);
responses.push(std::move(request_future_responses.response));
if (request_future_responses.watch_response)
watch_responses.emplace_back(std::move(*request_future_responses.watch_response));
auto promise = std::make_shared<std::promise<Coordination::ZooKeeperResponsePtr>>();
zkutil::ResponseCallback callback = [this, promise] (const Coordination::ZooKeeperResponsePtr & response)
{
promise->set_value(response);
poll_wrapper->interruptPoll();
};
if (request->has_watch)
{
auto watch_promise = std::make_shared<std::promise<Coordination::ZooKeeperResponsePtr>>();
zkutil::ResponseCallback watch_callback = [this, watch_promise] (const Coordination::ZooKeeperResponsePtr & response)
{
watch_promise->set_value(response);
poll_wrapper->interruptPoll();
};
test_keeper_storage->putRequest(request, session_id, callback, watch_callback);
responses.push(promise->get_future());
watch_responses.emplace_back(watch_promise->get_future());
}
else
{
test_keeper_storage->putRequest(request, session_id, callback);
responses.push(promise->get_future());
}
}
else
{

View File

@ -14,21 +14,13 @@
namespace DB
{
struct SocketInterruptablePollWrapper;
using SocketInterruptablePollWrapperPtr = std::unique_ptr<SocketInterruptablePollWrapper>;
class TestKeeperTCPHandler : public Poco::Net::TCPServerConnection
{
public:
TestKeeperTCPHandler(IServer & server_, const Poco::Net::StreamSocket & socket_)
: Poco::Net::TCPServerConnection(socket_)
, server(server_)
, log(&Poco::Logger::get("TestKeeperTCPHandler"))
, global_context(server.context())
, test_keeper_storage(global_context.getTestKeeperStorage())
, operation_timeout(0, Coordination::DEFAULT_OPERATION_TIMEOUT_MS * 1000)
, session_timeout(0, Coordination::DEFAULT_SESSION_TIMEOUT_MS * 1000)
, session_id(test_keeper_storage->getSessionID())
{
}
TestKeeperTCPHandler(IServer & server_, const Poco::Net::StreamSocket & socket_);
void run() override;
private:
IServer & server;
@ -39,6 +31,7 @@ private:
Poco::Timespan session_timeout;
int64_t session_id;
Stopwatch session_stopwatch;
SocketInterruptablePollWrapperPtr poll_wrapper;
std::queue<zkutil::TestKeeperStorage::AsyncResponse> responses;
std::vector<zkutil::TestKeeperStorage::AsyncResponse> watch_responses;