|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# |
| 3 | +# The OpenSearch Contributors require contributions made to |
| 4 | +# this file be licensed under the Apache-2.0 license or a |
| 5 | +# compatible open source license. |
| 6 | +# |
| 7 | +# Modifications Copyright OpenSearch Contributors. See |
| 8 | +# GitHub history for details. |
| 9 | + |
| 10 | +import sys |
| 11 | + |
| 12 | +import requests |
| 13 | + |
| 14 | +OPENSEARCH_SERVICE = "es" |
| 15 | + |
| 16 | +PY3 = sys.version_info[0] == 3 |
| 17 | + |
| 18 | +if PY3: |
| 19 | + from urllib.parse import parse_qs, urlencode, urlparse |
| 20 | + |
| 21 | + |
| 22 | +def fetch_url(prepared_request): # type: ignore |
| 23 | + """ |
| 24 | + This is a util method that helps in reconstructing the request url. |
| 25 | + :param prepared_request: unsigned request |
| 26 | + :return: reconstructed url |
| 27 | + """ |
| 28 | + url = urlparse(prepared_request.url) |
| 29 | + path = url.path or "/" |
| 30 | + |
| 31 | + # fetch the query string if present in the request |
| 32 | + querystring = "" |
| 33 | + if url.query: |
| 34 | + querystring = "?" + urlencode( |
| 35 | + parse_qs(url.query, keep_blank_values=True), doseq=True |
| 36 | + ) |
| 37 | + |
| 38 | + # fetch the host information from headers |
| 39 | + headers = dict( |
| 40 | + (key.lower(), value) for key, value in prepared_request.headers.items() |
| 41 | + ) |
| 42 | + location = headers.get("host") or url.netloc |
| 43 | + |
| 44 | + # construct the url and return |
| 45 | + return url.scheme + "://" + location + path + querystring |
| 46 | + |
| 47 | + |
| 48 | +class AWSV4SignerAuth(requests.auth.AuthBase): |
| 49 | + """ |
| 50 | + AWS V4 Request Signer for Requests. |
| 51 | + """ |
| 52 | + |
| 53 | + def __init__(self, credentials, region): # type: ignore |
| 54 | + if not credentials: |
| 55 | + raise ValueError("Credentials cannot be empty") |
| 56 | + self.credentials = credentials |
| 57 | + |
| 58 | + if not region: |
| 59 | + raise ValueError("Region cannot be empty") |
| 60 | + self.region = region |
| 61 | + |
| 62 | + def __call__(self, request): # type: ignore |
| 63 | + return self._sign_request(request) # type: ignore |
| 64 | + |
| 65 | + def _sign_request(self, prepared_request): # type: ignore |
| 66 | + """ |
| 67 | + This method helps in signing the request by injecting the required headers. |
| 68 | + :param prepared_request: unsigned request |
| 69 | + :return: signed request |
| 70 | + """ |
| 71 | + |
| 72 | + from botocore.auth import SigV4Auth |
| 73 | + from botocore.awsrequest import AWSRequest |
| 74 | + |
| 75 | + url = fetch_url(prepared_request) # type: ignore |
| 76 | + |
| 77 | + # create an AWS request object and sign it using SigV4Auth |
| 78 | + aws_request = AWSRequest( |
| 79 | + method=prepared_request.method.upper(), |
| 80 | + url=url, |
| 81 | + data=prepared_request.body, |
| 82 | + ) |
| 83 | + sig_v4_auth = SigV4Auth(self.credentials, OPENSEARCH_SERVICE, self.region) |
| 84 | + sig_v4_auth.add_auth(aws_request) |
| 85 | + |
| 86 | + # copy the headers from AWS request object into the prepared_request |
| 87 | + prepared_request.headers.update(dict(aws_request.headers.items())) |
| 88 | + |
| 89 | + return prepared_request |
0 commit comments