Better polling

This commit is contained in:
alesapin 2020-11-26 11:59:23 +03:00
parent 6f6476ad95
commit 0b079cdeb8

View File

@ -28,6 +28,9 @@ namespace ErrorCodes
extern const int UNEXPECTED_PACKET_FROM_CLIENT; extern const int UNEXPECTED_PACKET_FROM_CLIENT;
} }
static constexpr UInt8 RESPONSE_BYTE = 1;
static constexpr UInt8 WATCH_RESPONSE_BYTE = 2;
struct SocketInterruptablePollWrapper struct SocketInterruptablePollWrapper
{ {
int sockfd; int sockfd;
@ -39,13 +42,12 @@ struct SocketInterruptablePollWrapper
epoll_event pipe_event{}; epoll_event pipe_event{};
#endif #endif
enum class PollStatus using PollStatus = size_t;
{ static constexpr PollStatus TIMEOUT = 0x0;
HAS_DATA, static constexpr PollStatus HAS_REQUEST = 0x1;
TIMEOUT, static constexpr PollStatus HAS_RESPONSE = 0x2;
INTERRUPTED, static constexpr PollStatus HAS_WATCH_RESPONSE = 0x4;
ERROR, static constexpr PollStatus ERROR = 0x8;
};
using InterruptCallback = std::function<void()>; using InterruptCallback = std::function<void()>;
@ -76,20 +78,21 @@ struct SocketInterruptablePollWrapper
#endif #endif
} }
int getInterruptFD() const int getResponseFD() const
{ {
return pipe.fds_rw[1]; return pipe.fds_rw[1];
} }
PollStatus poll(Poco::Timespan remaining_time) PollStatus poll(Poco::Timespan remaining_time)
{ {
std::array<int, 2> outputs = {-1, -1};
#if defined(POCO_HAVE_FD_EPOLL) #if defined(POCO_HAVE_FD_EPOLL)
int rc; int rc;
epoll_event evout{}; epoll_event evout[2];
do do
{ {
Poco::Timestamp start; Poco::Timestamp start;
rc = epoll_wait(epollfd, &evout, 1, remaining_time.totalMilliseconds()); rc = epoll_wait(epollfd, evout, 2, remaining_time.totalMilliseconds());
if (rc < 0 && errno == EINTR) if (rc < 0 && errno == EINTR)
{ {
Poco::Timestamp end; Poco::Timestamp end;
@ -102,7 +105,10 @@ struct SocketInterruptablePollWrapper
} }
while (rc < 0 && errno == EINTR); while (rc < 0 && errno == EINTR);
int out_fd = evout.data.fd; if (rc >= 1 && evout[0].events & EPOLLIN)
outputs[0] = evout[0].data.fd;
if (rc == 2 && evout[1].events & EPOLLIN)
outputs[1] = evout[1].data.fd;
#else #else
pollfd poll_buf[2]; pollfd poll_buf[2];
poll_buf[0].fd = sockfd; poll_buf[0].fd = sockfd;
@ -126,24 +132,46 @@ struct SocketInterruptablePollWrapper
} }
} }
while (rc < 0 && errno == POCO_EINTR); while (rc < 0 && errno == POCO_EINTR);
int out_fd = -1; if (rc >= 1 && poll_buf[0].revents & POLLIN)
if (poll_buf[0].revents & POLLIN) outputs[0] = sockfd;
out_fd = sockfd; if (rc == 2 && poll_buf[1].revents & POLLIN)
else if (poll_buf[1].revents & POLLIN) outputs[1] = pipe.fds_rw[0];
out_fd = pipe.fds_rw[0];
#endif #endif
PollStatus result = TIMEOUT;
if (rc < 0) if (rc < 0)
return PollStatus::ERROR;
else if (rc == 0)
return PollStatus::TIMEOUT;
else if (out_fd == pipe.fds_rw[0])
{ {
UInt64 bytes; return ERROR;
if (read(pipe.fds_rw[0], &bytes, sizeof(bytes)) < 0)
throwFromErrno("Cannot read from pipe", ErrorCodes::SYSTEM_ERROR);
return PollStatus::INTERRUPTED;
} }
return PollStatus::HAS_DATA; else if (rc == 0)
{
return result;
}
else
{
for (size_t i = 0; i < outputs.size(); ++i)
{
int fd = outputs[i];
if (fd != -1)
{
if (fd == sockfd)
result |= HAS_REQUEST;
else
{
UInt8 byte;
if (read(pipe.fds_rw[0], &byte, sizeof(byte)) < 0)
throwFromErrno("Cannot read from pipe", ErrorCodes::SYSTEM_ERROR);
if (byte == WATCH_RESPONSE_BYTE)
result |= HAS_WATCH_RESPONSE;
else if (byte == RESPONSE_BYTE)
result |= HAS_RESPONSE;
else
throw Exception("Unexpected byte received from signaling pipe", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT);
}
}
}
}
return result;
} }
#if defined(POCO_HAVE_FD_EPOLL) #if defined(POCO_HAVE_FD_EPOLL)
@ -254,7 +282,7 @@ void TestKeeperTCPHandler::runImpl()
using namespace std::chrono_literals; using namespace std::chrono_literals;
auto state = poll_wrapper->poll(session_timeout); auto state = poll_wrapper->poll(session_timeout);
if (state == SocketInterruptablePollWrapper::PollStatus::HAS_DATA) if (state & SocketInterruptablePollWrapper::HAS_REQUEST)
{ {
auto received_op = receiveRequest(); auto received_op = receiveRequest();
if (received_op == Coordination::OpNum::Close) if (received_op == Coordination::OpNum::Close)
@ -267,17 +295,22 @@ void TestKeeperTCPHandler::runImpl()
session_stopwatch.restart(); session_stopwatch.restart();
} }
} }
else if (state == SocketInterruptablePollWrapper::PollStatus::INTERRUPTED)
if (state & SocketInterruptablePollWrapper::HAS_RESPONSE)
{ {
while (!responses.empty()) while (!responses.empty())
{ {
if (responses.front().wait_for(0ms) != std::future_status::ready) if (responses.front().wait_for(0s) != std::future_status::ready)
break; break;
auto response = responses.front().get(); auto response = responses.front().get();
response->write(*out); response->write(*out);
responses.pop(); responses.pop();
} }
}
if (state & SocketInterruptablePollWrapper::HAS_WATCH_RESPONSE)
{
for (auto it = watch_responses.begin(); it != watch_responses.end();) for (auto it = watch_responses.begin(); it != watch_responses.end();)
{ {
if (it->wait_for(0s) == std::future_status::ready) if (it->wait_for(0s) == std::future_status::ready)
@ -293,7 +326,8 @@ void TestKeeperTCPHandler::runImpl()
} }
} }
} }
else if (state == SocketInterruptablePollWrapper::PollStatus::ERROR)
if (state == SocketInterruptablePollWrapper::ERROR)
{ {
throw Exception("Exception happened while reading from socket", ErrorCodes::SYSTEM_ERROR); throw Exception("Exception happened while reading from socket", ErrorCodes::SYSTEM_ERROR);
} }
@ -326,25 +360,23 @@ Coordination::OpNum TestKeeperTCPHandler::receiveRequest()
Coordination::ZooKeeperRequestPtr request = Coordination::ZooKeeperRequestFactory::instance().get(opnum); Coordination::ZooKeeperRequestPtr request = Coordination::ZooKeeperRequestFactory::instance().get(opnum);
request->xid = xid; request->xid = xid;
request->readImpl(*in); request->readImpl(*in);
int interrupt_fd = poll_wrapper->getInterruptFD(); int response_fd = poll_wrapper->getResponseFD();
if (opnum != Coordination::OpNum::Close) if (opnum != Coordination::OpNum::Close)
{ {
auto promise = std::make_shared<std::promise<Coordination::ZooKeeperResponsePtr>>(); auto promise = std::make_shared<std::promise<Coordination::ZooKeeperResponsePtr>>();
zkutil::ResponseCallback callback = [interrupt_fd, promise] (const Coordination::ZooKeeperResponsePtr & response) zkutil::ResponseCallback callback = [response_fd, promise] (const Coordination::ZooKeeperResponsePtr & response)
{ {
promise->set_value(response); promise->set_value(response);
UInt64 bytes = 1; [[maybe_unused]] int result = write(response_fd, &RESPONSE_BYTE, sizeof(RESPONSE_BYTE));
[[maybe_unused]] int result = write(interrupt_fd, &bytes, sizeof(bytes));
}; };
if (request->has_watch) if (request->has_watch)
{ {
auto watch_promise = std::make_shared<std::promise<Coordination::ZooKeeperResponsePtr>>(); auto watch_promise = std::make_shared<std::promise<Coordination::ZooKeeperResponsePtr>>();
zkutil::ResponseCallback watch_callback = [interrupt_fd, watch_promise] (const Coordination::ZooKeeperResponsePtr & response) zkutil::ResponseCallback watch_callback = [response_fd, watch_promise] (const Coordination::ZooKeeperResponsePtr & response)
{ {
watch_promise->set_value(response); watch_promise->set_value(response);
UInt64 bytes = 1; [[maybe_unused]] int result = write(response_fd, &WATCH_RESPONSE_BYTE, sizeof(WATCH_RESPONSE_BYTE));
[[maybe_unused]] int result = write(interrupt_fd, &bytes, sizeof(bytes));
}; };
test_keeper_storage->putRequest(request, session_id, callback, watch_callback); test_keeper_storage->putRequest(request, session_id, callback, watch_callback);
responses.push(promise->get_future()); responses.push(promise->get_future());