import collections
import copy
import logging
import random
import struct
import time
import kafka.errors as Errors
from kafka.future import Future
from kafka.net.metrics import KafkaConnectionMetrics
from kafka.protocol.metadata import ApiVersionsRequest
from kafka.protocol.sasl import SaslAuthenticateRequest, SaslHandshakeRequest, SaslBytesRequest
from kafka.protocol.broker_version_data import BrokerVersionData
from kafka.protocol.parser import KafkaProtocol
from kafka.sasl import get_sasl_mechanism
from kafka.version import __version__
log = logging.getLogger(__name__)
class KafkaConnection:
    """Protocol-side state for a single Kafka broker connection.

    Implements the transport callback surface (connection_made,
    data_received, eof_received, connection_lost, pause_writing,
    resume_writing) and layers request handling on top of it:

    - requests issued while initializing / re-authenticating are buffered
      and flushed later by send_buffered();
    - every request that expects a response is tracked in
      ``in_flight_requests`` with its own timeout timer;
    - writes are gated by a set of named pause reasons (``self.paused``),
      e.g. 'max_in_flight', 'buffer' and 'throttle';
    - initialize() negotiates ApiVersions and, when the security protocol is
      a SASL one, performs SASL authentication.
    """

    DEFAULT_CONFIG = {
        'client_id': 'kafka-python-' + __version__,
        'client_software_name': 'kafka-python',
        'client_software_version': __version__,
        'max_in_flight_requests_per_connection': 5,
        'receive_message_max_bytes': 1000000,
        'request_timeout_ms': 30000,
        'security_protocol': 'PLAINTEXT',
        'sasl_mechanism': None,
        'sasl_plain_username': None,
        'sasl_plain_password': None,
        'sasl_kerberos_name': None,
        'sasl_kerberos_service_name': 'kafka',
        'sasl_kerberos_domain_name': None,
        'sasl_oauth_token_provider': None,
        'metrics': None,
        'metric_group_prefix': '',
    }

    def __init__(self, net, node_id=None, broker_version_data=None, **configs):
        """Create a connection object in the 'initializing' state.

        Arguments:
            net: scheduler/loop object; this class calls its ``call_at``,
                ``call_later`` and ``unschedule`` methods for timers.
            node_id: broker node id, used for logging and metrics.
            broker_version_data: optional pre-configured version data; when
                None it is populated by the ApiVersions check.
            **configs: overrides for keys present in DEFAULT_CONFIG. Keys
                that are not in DEFAULT_CONFIG are ignored.
        """
        self.config = copy.copy(self.DEFAULT_CONFIG)
        for key in self.config:
            if key in configs:
                self.config[key] = configs[key]
        self.node_id = node_id
        self.net = net
        self.transport = None
        self.parser = None
        # Requests accepted before the connection is usable (or during SASL
        # re-authentication): (request, future, timeout_at) tuples.
        self._request_buffer = collections.deque()
        # Set of reasons for which writes are currently paused.
        self.paused = set()
        self.connected = False
        self.initializing = True
        self._init_future = Future()
        self._close_future = Future()
        # FIFO of (correlation_id, future, sent_time, timeout_at, timeout_task).
        self.in_flight_requests = collections.deque()
        self.broker_version_data = broker_version_data
        self._api_versions_idx = ApiVersionsRequest.max_version # version of ApiVersionsRequest to try on first connect
        self._throttle_time = 0
        self._reauth = SaslReauthenticator(self)
        if self.config['metrics']:
            self._sensors = KafkaConnectionMetrics(
                self.config['metrics'], self.config['metric_group_prefix'], node_id)
        else:
            self._sensors = None
        # Any pending request futures are failed when init fails or when the
        # connection closes (successfully or not).
        self._init_future.add_errback(self.fail_in_flight_requests)
        self._close_future.add_both(self.fail_in_flight_requests)

    @property
    def broker_version(self):
        """Broker version from broker_version_data, or None if not yet known."""
        if self.broker_version_data is None:
            return None
        return self.broker_version_data.broker_version

    @property
    def closed(self):
        """True when the connection is neither connected nor initializing."""
        return not self.connected and not self.initializing

    def __str__(self):
        if self.initializing:
            state = 'initializing'
        elif not self.connected:
            state = 'disconnected'
        elif self.paused:
            state = 'paused'
        else:
            state = 'connected'
        host_port = ' host=[%s]' % self.transport.host_port() if self.transport else ''
        broker_version = self.broker_version if self.broker_version is not None else 'unknown'
        return f'<KafkaConnection node_id={self.node_id}{host_port} broker_version={broker_version} ({state})>'

    @property
    def init_future(self):
        """Future resolved by _init_complete() or failed by connection_lost()."""
        return self._init_future

    def __await__(self):
        """Awaiting the connection waits on init_future and yields the connection."""
        yield self.init_future
        return self

    @property
    def close_future(self):
        """Future resolved (or failed with the error) by connection_lost()."""
        return self._close_future

    def _timeout_at(self, now=None, timeout_ms=None):
        """Return an absolute monotonic deadline.

        Arguments:
            now: base monotonic time; defaults to time.monotonic().
            timeout_ms: explicit timeout; when None the configured
                'request_timeout_ms' is used.
        """
        if now is None:
            now = time.monotonic()
        if timeout_ms is not None:
            return now + timeout_ms / 1000
        else:
            try:
                return now + self._timeout_secs
            except AttributeError:
                # Lazily cache the default timeout in seconds on first use.
                self._timeout_secs = self.config['request_timeout_ms'] / 1000
                return now + self._timeout_secs

    def send_request(self, request, request_timeout_ms=None):
        """Public entry point for sending a request.

        The request is buffered while the connection is initializing or
        re-authenticating, failed immediately when the connection is paused
        or not connected, and otherwise sent right away.

        Returns:
            Future for the response.
        """
        future = Future()
        timeout_at = self._timeout_at(timeout_ms=request_timeout_ms)
        if self.initializing or self._reauth.is_reauthenticating:
            self._request_buffer.append((request, future, timeout_at))
            return future
        elif self.paused:
            return future.failure(Errors.NodeNotReadyError(f'Node paused: {self.paused}'))
        elif not self.connected:
            return future.failure(Errors.KafkaConnectionError('Node not connected'))
        else:
            self._send_request(request, future=future, timeout_at=timeout_at)
            return future

    def _send_request(self, request, future=None, timeout_at=None):
        """Encode a request, register it as in-flight and write it out.

        Unlike send_request(), this does not consult the initializing /
        re-authenticating / paused gates, so it is also used internally for
        the handshake requests and for flushing the request buffer.

        Arguments:
            request: request object; its API_VERSION is filled in from
                broker_version_data when it is None.
            future: future to resolve; a new one is created when None.
            timeout_at: absolute monotonic deadline; defaults to the
                configured request timeout measured from now.

        Returns:
            The future (already failed if the connection is closed, the
            version is incompatible, or the deadline has already passed).
        """
        if future is None:
            future = Future()
        if self.closed:
            return future.failure(Errors.KafkaConnectionError('closed'))
        if request.API_VERSION is None:
            try:
                request.API_VERSION = self.broker_version_data.api_version(request)
            except Errors.IncompatibleBrokerVersion as exc:
                future.failure(exc)
                return future
        sent_time = time.monotonic()
        if timeout_at is None:
            timeout_at = self._timeout_at(now=sent_time)
        if timeout_at <= sent_time:
            future.failure(Errors.KafkaTimeoutError())
            return future
        correlation_id = self.parser.send_request(request)
        log.debug('%s Request %d: %s', self, correlation_id, request)
        if request.expect_response():
            # Each in-flight request owns its own timer so heterogeneous
            # per-request timeouts (e.g. JoinGroup with a rebalance-sized
            # deadline interleaved with default-timeout MetadataRequests)
            # don't require monotonic-deadline FIFO ordering.
            timeout_task = self.net.call_at(
                timeout_at,
                lambda: self._request_timed_out(future, sent_time, timeout_at))
            self.in_flight_requests.append(
                (correlation_id, future, sent_time, timeout_at, timeout_task))
        else:
            future.success(None)
        # Write the current request's bytes before checking max_in_flight.
        # Otherwise with max_in_flight=1, the first request would be added to
        # in_flight_requests (len==1), trip the >= check, pause, and never be
        # written to the transport - hanging forever.
        if not self.paused:
            self.transport.write(self.parser.send_bytes())
        if len(self.in_flight_requests) >= self.config['max_in_flight_requests_per_connection']:
            self.pause('max_in_flight')
        return future

    def send_buffered(self):
        """Flush every buffered request through _send_request(), in FIFO order."""
        while self._request_buffer:
            request, future, timeout_at = self._request_buffer.popleft()
            self._send_request(request, future=future, timeout_at=timeout_at)

    def _request_timed_out(self, future, sent_at, timeout_at):
        """Timer callback: close the connection if the request is still pending."""
        # Defensive: a response and its timer can both be dispatched within a
        # single _poll_once iteration; if data_received resolved the future
        # first, skip the connection-close.
        if self.closed or future.is_done:
            return
        timeout_ms = (timeout_at - sent_at) * 1000
        log.warning('%s: Request timed out after %d ms. Closing connection.', self, timeout_ms)
        self.close(Errors.RequestTimedOutError('Request timed out after %d ms' % timeout_ms))

    def data_received(self, data):
        """ Called when some data is received.

        Feeds the bytes to the parser and matches each decoded response to
        the oldest in-flight request. A response with no matching in-flight
        request, or with an unexpected correlation id, closes the connection.
        """
        if self.closed:
            log.debug('%s: Ignoring %d bytes received by closed connection', self, len(data))
            return
        responses = self.parser.receive_bytes(data)
        # augment responses w/ correlation_id, future, and timestamp
        for resp_correlation_id, response in responses:
            if not self.in_flight_requests:
                return self.close(Errors.KafkaConnectionError('Received response with no in-flight-requests!'))
            # Peek before popping: on a mismatch the head request must stay in
            # the deque so that closing the connection fails its future (and
            # unschedules its timer) via fail_in_flight_requests. Popping
            # first would leave that future unresolved forever.
            if self.in_flight_requests[0][0] != resp_correlation_id:
                return self.close(Errors.KafkaConnectionError('Received unrecognized correlation id'))
            (_, future, sent_time, _timeout_at, timeout_task) = self.in_flight_requests.popleft()
            self.net.unschedule(timeout_task)
            latency_ms = (time.monotonic() - sent_time) * 1000
            if self._sensors:
                self._sensors.request_time.record(latency_ms)
            log.debug('%s: Response %d (%s ms): %s', self, resp_correlation_id, latency_ms, response)
            self._maybe_throttle(response)
            future.success(response)
            if 'max_in_flight' in self.paused and len(self.in_flight_requests) < self.config['max_in_flight_requests_per_connection']:
                self.unpause('max_in_flight')
            self._reauth.on_response_processed()

    def eof_received(self):
        """ Called when the other end calls write_eof() or equivalent.
        If this returns a false value (including None), the transport
        will close itself. If it returns a true value, closing the
        transport is up to the protocol.
        """
        return False

    def connection_lost(self, exc):
        """ Called when the connection is lost or closed.
        The argument is an exception object or None (the latter
        meaning a regular EOF is received or the connection was
        aborted or closed).
        """
        self.connected = self.initializing = False
        self.transport = None
        self._reauth.cancel()
        error = exc or Errors.KafkaConnectionError()
        # An unfinished init always fails, even on a clean close.
        if not self._init_future.is_done:
            self._init_future.failure(error)
        if not self._close_future.is_done:
            if exc is None:
                self._close_future.success(None)
            else:
                self._close_future.failure(exc)

    def fail_in_flight_requests(self, error):
        """Fail every buffered and in-flight request future with ``error``.

        Registered as a callback on the init and close futures. A falsy
        ``error`` is replaced by Errors.Cancelled().

        Raises:
            RuntimeError: if the connection is not closed.
        """
        if not self.closed:
            raise RuntimeError('Connection must be closed to fail in flight requests')
        error = error or Errors.Cancelled()
        while self._request_buffer:
            _, future, _ = self._request_buffer.popleft()
            future.failure(error)
        while self.in_flight_requests:
            _, future, _, _, timeout_task = self.in_flight_requests.popleft()
            self.net.unschedule(timeout_task)
            future.failure(error)

    def connection_made(self, transport):
        """ Called when a connection is made.
        The argument is the transport representing the pipe connection.
        To receive data, wait for data_received() calls.
        When the connection is closed, connection_lost() is called.
        """
        self.transport = transport
        if self.transport.get_protocol() != self:
            self.transport.set_protocol(self)
        self.initializing = True
        self.transport.resume_reading()
        # NOTE(review): assumes getPeer() yields exactly (host, port) -- confirm.
        log_prefix = 'node=%s[%s:%s]' % (self.node_id, *self.transport.getPeer())
        self.parser = KafkaProtocol(
            client_id=self.config['client_id'],
            receive_message_max_bytes=self.config['receive_message_max_bytes'],
            ident=log_prefix)

    def pause(self, v):
        """Add ``v`` to the set of reasons for which writes are paused."""
        self.paused.add(v)

    def unpause(self, v):
        """Remove pause reason ``v``; flush pending parser bytes once no reason remains.

        Removing a reason that is not set is a no-op.
        """
        try:
            self.paused.remove(v)
        except KeyError:
            pass
        else:
            if not self.paused and self.parser and self.transport:
                to_send = self.parser.send_bytes()
                if to_send:
                    self.transport.write(to_send)

    def pause_writing(self):
        """ Called when the transport's buffer goes over the high-water mark.
        Pause and resume calls are paired -- pause_writing() is called
        once when the buffer goes strictly over the high-water mark
        (even if subsequent writes increases the buffer size even
        more), and eventually resume_writing() is called once when the
        buffer size reaches the low-water mark.
        Note that if the buffer size equals the high-water mark,
        pause_writing() is not called -- it must go strictly over.
        Conversely, resume_writing() is called when the buffer size is
        equal or lower than the low-water mark.  These end conditions
        are important to ensure that things go as expected when either
        mark is zero.
        NOTE: This is the only Protocol callback that is not called
        through EventLoop.call_soon() -- if it were, it would have no
        effect when it's most needed (when the app keeps writing
        without yielding until pause_writing() is called).
        """
        self.pause('buffer')

    def resume_writing(self):
        """ Called when the transport's buffer drains below the low-water mark."""
        self.unpause('buffer')

    def close(self, error=None):
        """Close the connection, optionally with an error.

        With no transport attached, connection_lost() is invoked directly.
        Otherwise the transport is aborted (error given) or closed (no
        error). A close without an explicit error before initialization
        finished is treated as a KafkaConnectionError.
        """
        if error is None and not self._init_future.is_done:
            error = Errors.KafkaConnectionError()
        if not self.transport:
            self.connection_lost(error)
            return
        if error:
            self.transport.abort(error)
        else:
            self.transport.close()

    def _maybe_throttle(self, response):
        """Record and apply the broker throttle hint carried by ``response``."""
        throttle_time_ms = getattr(response, 'throttle_time_ms', 0)
        if self._sensors:
            self._sensors.throttle_time.record(throttle_time_ms)
        if not throttle_time_ms:
            return
        # Client side throttling enabled in v2.0 brokers
        # prior to that throttling (if present) was managed broker-side
        if self.broker_version is not None and self.broker_version >= (2, 0):
            throttle_time = time.monotonic() + throttle_time_ms / 1000
            # Only extend the throttle window; a pending later deadline
            # already has its own timer and the 'throttle' pause in place.
            if throttle_time > self._throttle_time:
                self._throttle_time = throttle_time
                self.net.call_at(throttle_time, self._maybe_unthrottle)
                self.pause('throttle')
        log.warning("%s: %s throttled by broker (%d ms)", self,
                    response.__class__.__name__, throttle_time_ms)

    def _maybe_unthrottle(self):
        """Timer callback: lift the 'throttle' pause once the deadline has passed."""
        if time.monotonic() >= self._throttle_time:
            self._throttle_time = 0
            self.unpause('throttle')

    async def initialize(self, timeout_at=None):
        """Run the post-connect handshake (ApiVersions, then SASL if enabled).

        Any exception closes the connection with that error; on success the
        connection is marked connected via _init_complete().
        """
        if timeout_at is None:
            timeout_at = self._timeout_at()
        try:
            await self._get_api_versions(timeout_at)
            if self.sasl_enabled:
                await self._sasl_authenticate(timeout_at)
        except Exception as error:
            self.close(error)
        else:
            self._init_complete()

    async def _get_api_versions(self, timeout_at=None):
        """Negotiate ApiVersions and populate broker_version_data.

        Retries with a lower request version when the broker answers
        UnsupportedVersionError, until ``timeout_at``.

        Raises:
            Errors.KafkaTimeoutError: the deadline passed without success.
            Other Kafka errors: as mapped from the response error code.
        """
        if timeout_at is None:
            timeout_at = self._timeout_at()
        if self.broker_version_data is not None:
            try:
                self._api_versions_idx = self.broker_version_data.api_version(ApiVersionsRequest)
            except Errors.IncompatibleBrokerVersion:
                # The pre-configured version cannot serve ApiVersions at all;
                # keep the configured data and skip the check.
                log.debug('%s: Using pre-configured api_version %s for ApiVersions', self, self.broker_version)
                return
        while timeout_at > time.monotonic():
            version = self._api_versions_idx
            request = ApiVersionsRequest(
                version=version,
                client_software_name=self.config['client_software_name'],
                client_software_version=self.config['client_software_version'],
            )
            response = await self._send_request(request, timeout_at=timeout_at)
            error_type = Errors.for_code(response.error_code)
            if error_type is Errors.NoError:
                break
            elif error_type is Errors.UnsupportedVersionError:
                # Use the max version the broker advertises for this API, or
                # fall back to v0 when it does not advertise one.
                for api_version in response.api_keys:
                    if api_version.api_key == response.API_KEY:
                        self._api_versions_idx = min(self._api_versions_idx, api_version.max_version)
                        break
                else:
                    self._api_versions_idx = 0
                continue
            else:
                raise error_type()
        else:
            raise Errors.KafkaTimeoutError('Timeout during ApiVersions check')
        api_versions = {api_version.api_key: (api_version.min_version, api_version.max_version)
                        for api_version in response.api_keys}
        self.broker_version_data = BrokerVersionData(api_versions=api_versions)
        log.info('%s: Broker version identified as %s', self, '.'.join(map(str, self.broker_version)))

    @property
    def sasl_enabled(self):
        """True when security_protocol is SASL_PLAINTEXT or SASL_SSL."""
        return self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL')

    async def _sasl_authenticate(self, timeout_at=None):
        """Perform the SASL handshake and authentication exchange.

        Raises:
            Errors.UnsupportedSaslMechanismError: broker does not list the
                configured mechanism.
            Errors.SaslAuthenticationFailedError: an authenticate response
                carried an error, or the mechanism did not authenticate.
            Errors.KafkaTimeoutError: the deadline passed.
            Other Kafka errors: as mapped from the handshake error code.
        """
        if timeout_at is None:
            timeout_at = self._timeout_at()
        # Step 1: SaslHandshake to negotiate mechanism
        request = SaslHandshakeRequest(
            mechanism=self.config['sasl_mechanism'],
            max_version=1)
        response = await self._send_request(request, timeout_at=timeout_at)
        error_type = Errors.for_code(response.error_code)
        if error_type is not Errors.NoError:
            log.error('%s: SaslHandshake failed: %s', self, error_type.__name__)
            raise error_type()
        if self.config['sasl_mechanism'] not in response.mechanisms:
            raise Errors.UnsupportedSaslMechanismError(
                'Kafka broker does not support %s sasl mechanism. Enabled mechanisms: %s'
                % (self.config['sasl_mechanism'], response.mechanisms))
        # Step 2: SASL authentication exchange
        version = response.API_VERSION
        # Prefer the configured hostname (stored on the transport) so that
        # mechanisms like GSSAPI construct service principals against the
        # user-supplied name, not whichever IP getaddrinfo handed us.
        sasl_host = self.transport.host if self.transport.host else self.transport.getPeer()[0]
        mechanism = get_sasl_mechanism(self.config['sasl_mechanism'])(
            host=sasl_host, **self.config)
        auth_response = None
        while not mechanism.is_done() and timeout_at > time.monotonic():
            token = mechanism.auth_bytes()
            if version == 1:
                auth_request = SaslAuthenticateRequest(token)
            else:
                auth_request = SaslBytesRequest(token)
            auth_response = await self._send_request(auth_request, timeout_at=timeout_at)
            error_type = Errors.for_code(auth_response.error_code)
            if error_type is not Errors.NoError:
                raise Errors.SaslAuthenticationFailedError(
                    '%s: %s' % (error_type.__name__, auth_response.error_message))
            # GSSAPI does not get a final recv in v0 unframed mode
            if version == 0 and mechanism.is_done():
                break
            mechanism.receive(auth_response.auth_bytes)
        if time.monotonic() > timeout_at:
            raise Errors.KafkaTimeoutError('SASL Authentication timed out')
        elif not mechanism.is_authenticated():
            raise Errors.SaslAuthenticationFailedError(
                'Failed to authenticate via SASL %s' % self.config['sasl_mechanism'])
        # KIP-368: SessionLifetimeMs is only present on SaslAuthenticateResponse v1+.
        # auth_response stays None if the mechanism was already done before
        # any exchange took place; there is no lifetime to record then.
        if version == 1 and auth_response is not None:
            self._reauth.session_updated(auth_response.session_lifetime_ms)
        log.info('%s: %s', self, mechanism.auth_details())

    def _init_complete(self):
        """Mark the connection usable, flush buffered requests, schedule re-auth.

        No-op unless the connection is still initializing.
        """
        if self.initializing:
            self.initializing = False
            self.connected = True
            self.send_buffered()
            self._init_future.success(True)
            self._reauth.schedule()
class SaslReauthenticator:
    """Per-connection KIP-368 SASL re-authentication bookkeeping.

    Keeps the re-auth timer, the "re-authenticating" flag and the drain
    future out of KafkaConnection itself. The owning connection calls in at
    five points:

    - session_updated(): after each successful SASL authentication
    - schedule(): once initialization has completed
    - is_reauthenticating: to decide whether send_request must buffer
    - on_response_processed(): after every response taken off
      in_flight_requests
    - cancel(): from connection_lost
    """

    def __init__(self, conn):
        """Bind to the owning connection; nothing is scheduled yet."""
        self._conn = conn
        self.session_lifetime_ms = 0
        self.authenticated_at = None
        self._task = None
        self._reauthenticating = False
        self._drain_future = None

    @property
    def is_reauthenticating(self):
        """True while a re-authentication round is in progress."""
        return self._reauthenticating

    @property
    def task(self):
        """The scheduled re-auth task, or None. Exposed for tests/observability."""
        return self._task

    def session_updated(self, session_lifetime_ms):
        """Record the session lifetime reported by the latest successful auth.

        Falsy and negative values become 0; positive values up to one second
        are raised to 1000 ms. Also stamps ``authenticated_at`` with the
        current monotonic time.
        """
        lifetime_ms = session_lifetime_ms or 0
        if lifetime_ms <= 0:
            lifetime_ms = 0
        elif lifetime_ms <= 1000:
            lifetime_ms = 1000
        self.session_lifetime_ms = lifetime_ms
        self.authenticated_at = time.monotonic()

    def schedule(self):
        """Arm the timer for the next re-authentication.

        The delay is a random 85-95% of the session lifetime, so that many
        connections do not re-authenticate in lock-step (Apache Java
        semantics). Does nothing when SASL is disabled or the recorded
        lifetime is 0.
        """
        conn = self._conn
        if not (conn.sasl_enabled and self.session_lifetime_ms):
            return
        fraction = random.uniform(0.85, 0.95)
        delay = (self.session_lifetime_ms * fraction) / 1000
        log.debug('%s: Scheduling SASL re-authentication in %.3fs (session_lifetime_ms=%d)',
                  conn, delay, self.session_lifetime_ms)
        self._task = conn.net.call_later(delay, self._run)

    def cancel(self):
        """Drop any pending re-auth timer and fail a waiting drain future.

        Invoked by KafkaConnection.connection_lost.
        """
        pending_task = self._task
        if pending_task is not None:
            try:
                self._conn.net.unschedule(pending_task)
            except (ValueError, KeyError):
                # Task is no longer known to the scheduler; nothing to undo.
                pass
            self._task = None
        drain = self._drain_future
        if drain is not None and not drain.is_done:
            drain.failure(Errors.KafkaConnectionError())
        self._drain_future = None
        self._reauthenticating = False

    def on_response_processed(self):
        """Release the drain waiter once no requests remain in flight.

        Invoked by KafkaConnection.data_received after each response is
        popped; only has an effect during a re-authentication round.
        """
        if not self._reauthenticating:
            return
        drain = self._drain_future
        if drain is None or self._conn.in_flight_requests or drain.is_done:
            return
        drain.success(None)

    async def _run(self):
        """Timer entry point: run one re-auth round, closing the connection on failure."""
        self._task = None
        conn = self._conn
        if conn.closed:
            return
        try:
            await self._do_reauth()
        except BaseException as exc:  # pylint: disable=W0718
            # Re-auth failure is transient (KIP-368: not cached like initial
            # auth failure); close the connection so the manager reconnects on
            # next demand.
            log.warning('%s: SASL re-authentication failed: %s', conn, exc)
            if isinstance(exc, Exception):
                err = exc
            else:
                err = Errors.SaslAuthenticationFailedError(str(exc))
            conn.close(err)

    async def _do_reauth(self):
        """Drain in-flight requests, re-authenticate, then resume normal traffic."""
        conn = self._conn
        self._reauthenticating = True
        try:
            # Drain in-flight so the SaslHandshake/Authenticate frames are the
            # next bytes on the wire (Apache Java does the same; avoids
            # reasoning about FIFO interleaving with the broker's reauth
            # validation).
            while conn.in_flight_requests and not conn.closed:
                drain = self._drain_future = Future()
                if not conn.in_flight_requests:
                    break
                await drain
                self._drain_future = None
            if conn.closed:
                return
            log.debug('%s: Beginning SASL re-authentication', conn)
            await conn._sasl_authenticate()  # pylint: disable=W0212
        finally:
            self._reauthenticating = False
            self._drain_future = None
        # Reached only when authentication succeeded: release the requests
        # that were buffered during the round and arm the next re-auth.
        if not conn.closed:
            conn.send_buffered()
            self.schedule()