 from .logger import RoundTrip
 from .retries import retried

-logger = logging.getLogger('databricks.sdk')
+logger = logging.getLogger("databricks.sdk")


 def _fix_host_if_needed(host: Optional[str]) -> Optional[str]:
     if not host:
         return host

     # Add a default scheme if it's missing
-    if '://' not in host:
-        host = 'https://' + host
+    if "://" not in host:
+        host = "https://" + host

     o = urllib.parse.urlparse(host)
     # remove trailing slash
-    path = o.path.rstrip('/')
+    path = o.path.rstrip("/")
     # remove port if 443
     netloc = o.netloc
     if o.port == 443:
-        netloc = netloc.split(':')[0]
+        netloc = netloc.split(":")[0]

     return urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment))
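
To make the normalization concrete, here is a small sketch of how `_fix_host_if_needed` should behave on a few representative inputs (the hostnames are hypothetical; the expected outputs follow directly from the logic above):

# Sketch: expected behavior of _fix_host_if_needed on hypothetical hosts.
assert _fix_host_if_needed(None) is None
assert _fix_host_if_needed("abc.cloud.databricks.com") == "https://abc.cloud.databricks.com"
assert _fix_host_if_needed("https://abc.cloud.databricks.com/") == "https://abc.cloud.databricks.com"
assert _fix_host_if_needed("https://abc.cloud.databricks.com:443") == "https://abc.cloud.databricks.com"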


 class _BaseClient:

-    def __init__(self,
-                 debug_truncate_bytes: int = None,
-                 retry_timeout_seconds: int = None,
-                 user_agent_base: str = None,
-                 header_factory: Callable[[], dict] = None,
-                 max_connection_pools: int = None,
-                 max_connections_per_pool: int = None,
-                 pool_block: bool = True,
-                 http_timeout_seconds: float = None,
-                 extra_error_customizers: List[_ErrorCustomizer] = None,
-                 debug_headers: bool = False,
-                 clock: Clock = None,
-                 streaming_buffer_size: int = 1024 * 1024):  # 1MB
+    def __init__(
+        self,
+        debug_truncate_bytes: int = None,
+        retry_timeout_seconds: int = None,
+        user_agent_base: str = None,
+        header_factory: Callable[[], dict] = None,
+        max_connection_pools: int = None,
+        max_connections_per_pool: int = None,
+        pool_block: bool = True,
+        http_timeout_seconds: float = None,
+        extra_error_customizers: List[_ErrorCustomizer] = None,
+        debug_headers: bool = False,
+        clock: Clock = None,
+        streaming_buffer_size: int = 1024 * 1024,
+    ):  # 1MB
         """
         :param debug_truncate_bytes:
         :param retry_timeout_seconds:
@@ -87,9 +89,11 @@ def __init__(self,
         # We don't use `max_retries` from HTTPAdapter to align with a more production-ready
         # retry strategy established in the Databricks SDK for Go. See _is_retryable and
         # @retried for more details.
-        http_adapter = requests.adapters.HTTPAdapter(pool_connections=max_connections_per_pool or 20,
-                                                     pool_maxsize=max_connection_pools or 20,
-                                                     pool_block=pool_block)
+        http_adapter = requests.adapters.HTTPAdapter(
+            pool_connections=max_connections_per_pool or 20,
+            pool_maxsize=max_connection_pools or 20,
+            pool_block=pool_block,
+        )
         self._session.mount("https://", http_adapter)

         # Default to 60 seconds
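
For context on the adapter settings: in `requests`, `pool_connections` is the number of per-host connection pools the adapter caches and `pool_maxsize` is the number of connections kept per pool, while `pool_block=True` makes a saturated pool block rather than open extra connections. A minimal standalone sketch of the same mount pattern:

# Standalone sketch of the mounting pattern above, with literal pool sizes.
import requests

session = requests.Session()
adapter = requests.adapters.HTTPAdapter(pool_connections=20, pool_maxsize=20, pool_block=True)
session.mount("https://", adapter)  # all https:// requests on this session reuse the pooled adapter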
@@ -110,7 +114,7 @@ def _fix_query_string(query: Optional[dict] = None) -> Optional[dict]:
         # See: https://github.com/databricks/databricks-sdk-py/issues/142
         if query is None:
             return None
-        with_fixed_bools = {k: v if type(v) != bool else ('true' if v else 'false') for k, v in query.items()}
+        with_fixed_bools = {k: v if type(v) != bool else ("true" if v else "false") for k, v in query.items()}

         # Query parameters may be nested, e.g.
         # {'filter_by': {'user_ids': [123, 456]}}
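
A standalone illustration of the boolean fix above (the query keys here are hypothetical):

query = {"expand_tasks": True, "limit": 25}
fixed = {k: v if type(v) != bool else ("true" if v else "false") for k, v in query.items()}
assert fixed == {"expand_tasks": "true", "limit": 25}  # avoids sending Python's str(True) == "True"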
@@ -140,30 +144,34 @@ def _is_seekable_stream(data) -> bool:
             return False
         return data.seekable()

-    def do(self,
-           method: str,
-           url: str,
-           query: dict = None,
-           headers: dict = None,
-           body: dict = None,
-           raw: bool = False,
-           files=None,
-           data=None,
-           auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None,
-           response_headers: List[str] = None) -> Union[dict, list, BinaryIO]:
+    def do(
+        self,
+        method: str,
+        url: str,
+        query: dict = None,
+        headers: dict = None,
+        body: dict = None,
+        raw: bool = False,
+        files=None,
+        data=None,
+        auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None,
+        response_headers: List[str] = None,
+    ) -> Union[dict, list, BinaryIO]:
         if headers is None:
             headers = {}
-        headers['User-Agent'] = self._user_agent_base
+        headers["User-Agent"] = self._user_agent_base

         # Wrap strings and bytes in a seekable stream so that we can rewind them.
         if isinstance(data, (str, bytes)):
-            data = io.BytesIO(data.encode('utf-8') if isinstance(data, str) else data)
+            data = io.BytesIO(data.encode("utf-8") if isinstance(data, str) else data)

         if not data:
             # The request is not a stream.
-            call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
-                           is_retryable=self._is_retryable,
-                           clock=self._clock)(self._perform)
+            call = retried(
+                timeout=timedelta(seconds=self._retry_timeout_seconds),
+                is_retryable=self._is_retryable,
+                clock=self._clock,
+            )(self._perform)
         elif self._is_seekable_stream(data):
             # Keep track of the initial position of the stream so that we can rewind to it
             # if we need to retry the request.
@@ -173,25 +181,29 @@ def rewind():
                 logger.debug(f"Rewinding input data to offset {initial_data_position} before retry")
                 data.seek(initial_data_position)

-            call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
-                           is_retryable=self._is_retryable,
-                           clock=self._clock,
-                           before_retry=rewind)(self._perform)
+            call = retried(
+                timeout=timedelta(seconds=self._retry_timeout_seconds),
+                is_retryable=self._is_retryable,
+                clock=self._clock,
+                before_retry=rewind,
+            )(self._perform)
         else:
             # Do not retry if the stream is not seekable. This is necessary to avoid bugs
             # where the retry doesn't re-read already read data from the stream.
             logger.debug(f"Retry disabled for non-seekable stream: type={type(data)}")
             call = self._perform

-        response = call(method,
-                        url,
-                        query=query,
-                        headers=headers,
-                        body=body,
-                        raw=raw,
-                        files=files,
-                        data=data,
-                        auth=auth)
+        response = call(
+            method,
+            url,
+            query=query,
+            headers=headers,
+            body=body,
+            raw=raw,
+            files=files,
+            data=data,
+            auth=auth,
+        )

         resp = dict()
         for header in response_headers if response_headers else []:
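
The seekable-stream branch above is easiest to see in isolation: capture the stream's offset once, then rewind to it before every retry so that a failed attempt cannot permanently consume the request body. A minimal sketch of that idea with `io.BytesIO`:

# Sketch of the rewind-before-retry idea (standalone, not the SDK's retried() wrapper).
import io

data = io.BytesIO(b"payload")
initial_data_position = data.tell()  # remember where the body starts
data.read()                          # a failed attempt may drain the stream...
data.seek(initial_data_position)     # ...rewinding restores it for the next attempt
assert data.read() == b"payload"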
@@ -220,6 +232,7 @@ def _is_retryable(err: BaseException) -> Optional[str]:
         # and Databricks SDK for Go retries
         # (see https://github.com/databricks/databricks-sdk-go/blob/main/apierr/errors.go)
         from urllib3.exceptions import ProxyError
+
         if isinstance(err, ProxyError):
             err = err.original_error
         if isinstance(err, requests.ConnectionError):
@@ -230,48 +243,55 @@ def _is_retryable(err: BaseException) -> Optional[str]:
             #
             # return a simple string for debug log readability, as `raise TimeoutError(...) from err`
             # will bubble up the original exception in case we reach max retries.
-            return f'cannot connect'
+            return f"cannot connect"
         if isinstance(err, requests.Timeout):
             # corresponds to `TLS handshake timeout` and `i/o timeout` in Go.
             #
             # return a simple string for debug log readability, as `raise TimeoutError(...) from err`
             # will bubble up the original exception in case we reach max retries.
-            return f'timeout'
+            return f"timeout"
         if isinstance(err, DatabricksError):
             message = str(err)
             transient_error_string_matches = [
                 "com.databricks.backend.manager.util.UnknownWorkerEnvironmentException",
-                "does not have any associated worker environments", "There is no worker environment with id",
-                "Unknown worker environment", "ClusterNotReadyException", "Unexpected error",
+                "does not have any associated worker environments",
+                "There is no worker environment with id",
+                "Unknown worker environment",
+                "ClusterNotReadyException",
+                "Unexpected error",
                 "Please try again later or try a faster operation.",
                 "RPC token bucket limit has been exceeded",
             ]
             for substring in transient_error_string_matches:
                 if substring not in message:
                     continue
-                return f'matched {substring}'
+                return f"matched {substring}"
         return None
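
As a rough, hedged check of the classifier (assuming `_is_retryable` is a staticmethod as its signature suggests, and that the lines elided between hunks add no earlier returns for these types):

assert _BaseClient._is_retryable(requests.Timeout()) == "timeout"
assert _BaseClient._is_retryable(ValueError("boom")) is None  # non-transient errors are not retried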

-    def _perform(self,
-                 method: str,
-                 url: str,
-                 query: dict = None,
-                 headers: dict = None,
-                 body: dict = None,
-                 raw: bool = False,
-                 files=None,
-                 data=None,
-                 auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None):
-        response = self._session.request(method,
-                                         url,
-                                         params=self._fix_query_string(query),
-                                         json=body,
-                                         headers=headers,
-                                         files=files,
-                                         data=data,
-                                         auth=auth,
-                                         stream=raw,
-                                         timeout=self._http_timeout_seconds)
+    def _perform(
+        self,
+        method: str,
+        url: str,
+        query: dict = None,
+        headers: dict = None,
+        body: dict = None,
+        raw: bool = False,
+        files=None,
+        data=None,
+        auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None,
+    ):
+        response = self._session.request(
+            method,
+            url,
+            params=self._fix_query_string(query),
+            json=body,
+            headers=headers,
+            files=files,
+            data=data,
+            auth=auth,
+            stream=raw,
+            timeout=self._http_timeout_seconds,
+        )
         self._record_request_log(response, raw=raw or data is not None or files is not None)
         error = self._error_parser.get_api_error(response)
         if error is not None:
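
`_perform` is a thin mapping onto `requests.Session.request`; a standalone sketch of the same call shape, shown for structure only (the URL and query values are hypothetical, not a real endpoint):

# Sketch of the underlying requests call; values are illustrative placeholders.
import requests

session = requests.Session()
response = session.request(
    "GET",
    "https://example.com/api/2.0/clusters/list",
    params={"can_use_client": "true"},  # booleans already stringified, see _fix_query_string
    timeout=60.0,                       # seconds, mirroring the 60-second default above
)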
@@ -312,7 +332,7 @@ def flush(self) -> int:

     def __init__(self, response: _RawResponse, chunk_size: Union[int, None] = None):
         self._response = response
-        self._buffer = b''
+        self._buffer = b""
         self._content = None
         self._chunk_size = chunk_size

@@ -338,14 +358,14 @@ def isatty(self) -> bool:

     def read(self, n: int = -1) -> bytes:
         """
-        Read up to n bytes from the response stream. If n is negative, read
-        until the end of the stream.
+        Read up to n bytes from the response stream. If n is negative, read
+        until the end of the stream.
         """

         self._open()
         read_everything = n < 0
         remaining_bytes = n
-        res = b''
+        res = b""
         while remaining_bytes > 0 or read_everything:
             if len(self._buffer) == 0:
                 try:
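
The docstring's contract matches Python's standard file-object semantics, illustrated here with `io.BytesIO` standing in for `_StreamingResponse`:

# Sketch of the read() contract using io.BytesIO as a stand-in.
import io

stream = io.BytesIO(b"abcdef")
assert stream.read(2) == b"ab"     # read up to n bytes
assert stream.read(-1) == b"cdef"  # a negative n reads to the end of the stream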
@@ -395,8 +415,12 @@ def __next__(self) -> bytes:
     def __iter__(self) -> Iterator[bytes]:
         return self._content

-    def __exit__(self, t: Union[Type[BaseException], None], value: Union[BaseException, None],
-                 traceback: Union[TracebackType, None]) -> None:
+    def __exit__(
+        self,
+        t: Union[Type[BaseException], None],
+        value: Union[BaseException, None],
+        traceback: Union[TracebackType, None],
+    ) -> None:
         self._content = None
-        self._buffer = b''
+        self._buffer = b""
         self.close()