|
| 1 | +import asyncio |
| 2 | +import ssl |
| 3 | +import warnings |
| 4 | + |
| 5 | +import aiohttp |
| 6 | +from aiohttp.client_exceptions import ServerFingerprintMismatch |
| 7 | +import async_timeout |
| 8 | + |
| 9 | +from elasticsearch.exceptions import ConnectionError, ConnectionTimeout, ImproperlyConfigured, SSLError |
| 10 | +from elasticsearch.connection import Connection |
| 11 | +from elasticsearch.compat import urlencode |
| 12 | +from elasticsearch.connection.http_urllib3 import create_ssl_context |
| 13 | + |
| 14 | + |
| 15 | +# This is only needed because https://github.com/elastic/elasticsearch-py-async/pull/68 is not merged yet |
| 16 | +# In addition we have raised the connection limit in TCPConnector from 100 to 10000. |
| 17 | + |
| 18 | +# We want to keep the diff as small as possible thus suppressing pylint warnings that we would not allow in Rally |
| 19 | +# pylint: disable=W0706 |
| 20 | +class AIOHttpConnection(Connection): |
| 21 | + def __init__(self, host='localhost', port=9200, http_auth=None, |
| 22 | + use_ssl=False, verify_certs=False, ca_certs=None, client_cert=None, |
| 23 | + client_key=None, loop=None, use_dns_cache=True, headers=None, |
| 24 | + ssl_context=None, trace_config=None, **kwargs): |
| 25 | + super().__init__(host=host, port=port, **kwargs) |
| 26 | + |
| 27 | + self.loop = asyncio.get_event_loop() if loop is None else loop |
| 28 | + |
| 29 | + if http_auth is not None: |
| 30 | + if isinstance(http_auth, str): |
| 31 | + http_auth = tuple(http_auth.split(':', 1)) |
| 32 | + |
| 33 | + if isinstance(http_auth, (tuple, list)): |
| 34 | + http_auth = aiohttp.BasicAuth(*http_auth) |
| 35 | + |
| 36 | + headers = headers or {} |
| 37 | + headers.setdefault('content-type', 'application/json') |
| 38 | + |
| 39 | + # if providing an SSL context, raise error if any other SSL related flag is used |
| 40 | + if ssl_context and (verify_certs or ca_certs): |
| 41 | + raise ImproperlyConfigured("When using `ssl_context`, `use_ssl`, `verify_certs`, `ca_certs` are not permitted") |
| 42 | + |
| 43 | + if use_ssl or ssl_context: |
| 44 | + cafile = ca_certs |
| 45 | + if not cafile and not ssl_context and verify_certs: |
| 46 | + # If no ca_certs and no sslcontext passed and asking to verify certs |
| 47 | + # raise error |
| 48 | + raise ImproperlyConfigured("Root certificates are missing for certificate " |
| 49 | + "validation. Either pass them in using the ca_certs parameter or " |
| 50 | + "install certifi to use it automatically.") |
| 51 | + if verify_certs or ca_certs: |
| 52 | + warnings.warn('Use of `verify_certs`, `ca_certs` have been deprecated in favor of using SSLContext`', DeprecationWarning) |
| 53 | + |
| 54 | + if not ssl_context: |
| 55 | + # if SSLContext hasn't been passed in, create one. |
| 56 | + # need to skip if sslContext isn't avail |
| 57 | + try: |
| 58 | + ssl_context = create_ssl_context(cafile=cafile) |
| 59 | + except AttributeError: |
| 60 | + ssl_context = None |
| 61 | + |
| 62 | + if not verify_certs and ssl_context is not None: |
| 63 | + ssl_context.check_hostname = False |
| 64 | + ssl_context.verify_mode = ssl.CERT_NONE |
| 65 | + warnings.warn( |
| 66 | + 'Connecting to %s using SSL with verify_certs=False is insecure.' % host) |
| 67 | + if ssl_context: |
| 68 | + verify_certs = True |
| 69 | + use_ssl = True |
| 70 | + |
| 71 | + trace_configs = [trace_config] if trace_config else None |
| 72 | + |
| 73 | + self.session = aiohttp.ClientSession( |
| 74 | + auth=http_auth, |
| 75 | + timeout=self.timeout, |
| 76 | + connector=aiohttp.TCPConnector( |
| 77 | + loop=self.loop, |
| 78 | + verify_ssl=verify_certs, |
| 79 | + use_dns_cache=use_dns_cache, |
| 80 | + ssl_context=ssl_context, |
| 81 | + # this has been changed from the default (100) |
| 82 | + limit=100000 |
| 83 | + ), |
| 84 | + headers=headers, |
| 85 | + trace_configs=trace_configs |
| 86 | + ) |
| 87 | + |
| 88 | + self.base_url = 'http%s://%s:%d%s' % ( |
| 89 | + 's' if use_ssl else '', |
| 90 | + host, port, self.url_prefix |
| 91 | + ) |
| 92 | + |
| 93 | + @asyncio.coroutine |
| 94 | + def close(self): |
| 95 | + yield from self.session.close() |
| 96 | + |
| 97 | + @asyncio.coroutine |
| 98 | + def perform_request(self, method, url, params=None, body=None, timeout=None, ignore=(), headers=None): |
| 99 | + url_path = url |
| 100 | + if params: |
| 101 | + url_path = '%s?%s' % (url, urlencode(params or {})) |
| 102 | + url = self.base_url + url_path |
| 103 | + |
| 104 | + start = self.loop.time() |
| 105 | + response = None |
| 106 | + try: |
| 107 | + with async_timeout.timeout(timeout or self.timeout.total, loop=self.loop): |
| 108 | + response = yield from self.session.request(method, url, data=body, headers=headers) |
| 109 | + raw_data = yield from response.text() |
| 110 | + duration = self.loop.time() - start |
| 111 | + |
| 112 | + except asyncio.CancelledError: |
| 113 | + raise |
| 114 | + |
| 115 | + except Exception as e: |
| 116 | + self.log_request_fail(method, url, url_path, body, self.loop.time() - start, exception=e) |
| 117 | + if isinstance(e, ServerFingerprintMismatch): |
| 118 | + raise SSLError('N/A', str(e), e) |
| 119 | + if isinstance(e, asyncio.TimeoutError): |
| 120 | + raise ConnectionTimeout('TIMEOUT', str(e), e) |
| 121 | + raise ConnectionError('N/A', str(e), e) |
| 122 | + |
| 123 | + finally: |
| 124 | + if response is not None: |
| 125 | + yield from response.release() |
| 126 | + |
| 127 | + # raise errors based on http status codes, let the client handle those if needed |
| 128 | + if not (200 <= response.status < 300) and response.status not in ignore: |
| 129 | + self.log_request_fail(method, url, url_path, body, duration, status_code=response.status, response=raw_data) |
| 130 | + self._raise_error(response.status, raw_data) |
| 131 | + |
| 132 | + self.log_request_success(method, url, url_path, body, response.status, raw_data, duration) |
| 133 | + |
| 134 | + return response.status, response.headers, raw_data |
0 commit comments