Fix handling of SSL_ERROR_WANT_READ/SSL_ERROR_WANT_WRITE with zero timeout

Previously if you were using socket without timeout it wasn't able to
handle SSL_ERROR_WANT_READ/SSL_ERROR_WANT_WRITE, and even though sockets
without timeouts is an odd thing (but it is possible - [1]), it still
may be possible somewhere.

  [1]: https://github.com/ClickHouse/ClickHouse/pull/65917

Signed-off-by: Azat Khuzhin <a.khuzhin@semrush.com>
This commit is contained in:
Azat Khuzhin 2024-07-01 19:17:06 +02:00
parent 63b031479c
commit ef24f51789
2 changed files with 24 additions and 12 deletions

View File

@ -235,8 +235,6 @@ namespace Net
/// Note that simply closing a socket is not sufficient /// Note that simply closing a socket is not sufficient
/// to be able to re-use it again. /// to be able to re-use it again.
Poco::Timespan getMaxTimeout();
private: private:
SecureSocketImpl(const SecureSocketImpl &); SecureSocketImpl(const SecureSocketImpl &);
SecureSocketImpl & operator=(const SecureSocketImpl &); SecureSocketImpl & operator=(const SecureSocketImpl &);
@ -250,6 +248,9 @@ namespace Net
Session::Ptr _pSession; Session::Ptr _pSession;
friend class SecureStreamSocketImpl; friend class SecureStreamSocketImpl;
Poco::Timespan getMaxTimeoutOrLimit();
//// Return max(send, receive) if non zero, otherwise maximum timeout
}; };

View File

@ -199,7 +199,7 @@ void SecureSocketImpl::connectSSL(bool performHandshake)
if (performHandshake && _pSocket->getBlocking()) if (performHandshake && _pSocket->getBlocking())
{ {
int ret; int ret;
Poco::Timespan remaining_time = getMaxTimeout(); Poco::Timespan remaining_time = getMaxTimeoutOrLimit();
do do
{ {
RemainingTimeCounter counter(remaining_time); RemainingTimeCounter counter(remaining_time);
@ -302,7 +302,7 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags)
return rc; return rc;
} }
Poco::Timespan remaining_time = getMaxTimeout(); Poco::Timespan remaining_time = getMaxTimeoutOrLimit();
do do
{ {
RemainingTimeCounter counter(remaining_time); RemainingTimeCounter counter(remaining_time);
@ -338,7 +338,7 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags)
return rc; return rc;
} }
Poco::Timespan remaining_time = getMaxTimeout(); Poco::Timespan remaining_time = getMaxTimeoutOrLimit();
do do
{ {
/// SSL record may consist of several TCP packets, /// SSL record may consist of several TCP packets,
@ -372,7 +372,7 @@ int SecureSocketImpl::completeHandshake()
poco_check_ptr (_pSSL); poco_check_ptr (_pSSL);
int rc; int rc;
Poco::Timespan remaining_time = getMaxTimeout(); Poco::Timespan remaining_time = getMaxTimeoutOrLimit();
do do
{ {
RemainingTimeCounter counter(remaining_time); RemainingTimeCounter counter(remaining_time);
@ -453,18 +453,29 @@ X509* SecureSocketImpl::peerCertificate() const
return 0; return 0;
} }
Poco::Timespan SecureSocketImpl::getMaxTimeout() Poco::Timespan SecureSocketImpl::getMaxTimeoutOrLimit()
{ {
std::lock_guard<std::recursive_mutex> lock(_mutex); std::lock_guard<std::recursive_mutex> lock(_mutex);
Poco::Timespan remaining_time = _pSocket->getReceiveTimeout(); Poco::Timespan remaining_time = _pSocket->getReceiveTimeout();
Poco::Timespan send_timeout = _pSocket->getSendTimeout(); Poco::Timespan send_timeout = _pSocket->getSendTimeout();
if (remaining_time < send_timeout) if (remaining_time < send_timeout)
remaining_time = send_timeout; remaining_time = send_timeout;
/// zero SO_SNDTIMEO/SO_RCVTIMEO works as no timeout, let's replicate this
///
/// NOTE: we cannot use INT64_MAX (std::numeric_limits<Poco::Timespan::TimeDiff>::max()),
/// since it will be later passed to poll() which accept int timeout, and
/// even though poll() accepts milliseconds and Timespan() accepts
/// microseconds, let's use smaller maximum value just to avoid some possible
/// issues, this should be enough anyway (it is ~24 days).
if (remaining_time == 0)
remaining_time = Poco::Timespan(std::numeric_limits<int>::max());
return remaining_time; return remaining_time;
} }
bool SecureSocketImpl::mustRetry(int rc, Poco::Timespan& remaining_time) bool SecureSocketImpl::mustRetry(int rc, Poco::Timespan& remaining_time)
{ {
if (remaining_time == 0)
return false;
std::lock_guard<std::recursive_mutex> lock(_mutex); std::lock_guard<std::recursive_mutex> lock(_mutex);
if (rc <= 0) if (rc <= 0)
{ {
@ -475,9 +486,7 @@ bool SecureSocketImpl::mustRetry(int rc, Poco::Timespan& remaining_time)
case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_READ:
if (_pSocket->getBlocking()) if (_pSocket->getBlocking())
{ {
/// Level-triggered mode of epoll_wait is used, so if SSL_read don't read all available data from socket, if (_pSocket->pollImpl(remaining_time, Poco::Net::Socket::SELECT_READ))
/// epoll_wait returns true without waiting for new data even if remaining_time == 0
if (_pSocket->pollImpl(remaining_time, Poco::Net::Socket::SELECT_READ) && remaining_time != 0)
return true; return true;
else else
throw Poco::TimeoutException(); throw Poco::TimeoutException();
@ -486,13 +495,15 @@ bool SecureSocketImpl::mustRetry(int rc, Poco::Timespan& remaining_time)
case SSL_ERROR_WANT_WRITE: case SSL_ERROR_WANT_WRITE:
if (_pSocket->getBlocking()) if (_pSocket->getBlocking())
{ {
/// The same as for SSL_ERROR_WANT_READ if (_pSocket->pollImpl(remaining_time, Poco::Net::Socket::SELECT_WRITE))
if (_pSocket->pollImpl(remaining_time, Poco::Net::Socket::SELECT_WRITE) && remaining_time != 0)
return true; return true;
else else
throw Poco::TimeoutException(); throw Poco::TimeoutException();
} }
break; break;
/// NOTE: POCO_EINTR is the same as SSL_ERROR_WANT_READ (at least in
/// OpenSSL), so this likely dead code, but let's leave it for
/// compatibility with other implementations
case SSL_ERROR_SYSCALL: case SSL_ERROR_SYSCALL:
return socketError == POCO_EAGAIN || socketError == POCO_EINTR; return socketError == POCO_EAGAIN || socketError == POCO_EINTR;
default: default: