@@ -9,7 +9,7 @@ use std::marker::PhantomData;
9
9
use std:: sync:: Arc ;
10
10
11
11
thread_local ! {
12
- static CURRENT_CONTEXT : RefCell <Context > = RefCell :: new( Context :: default ( ) ) ;
12
+ static CURRENT_CONTEXT : RefCell <ContextStack > = RefCell :: new( ContextStack :: default ( ) ) ;
13
13
}
14
14
15
15
/// An execution-scoped collection of values.
@@ -122,7 +122,7 @@ impl Context {
122
122
/// Note: This function will panic if you attempt to attach another context
123
123
/// while the current one is still borrowed.
124
124
pub fn map_current < T > ( f : impl FnOnce ( & Context ) -> T ) -> T {
125
- CURRENT_CONTEXT . with ( |cx| f ( & cx. borrow ( ) ) )
125
+ CURRENT_CONTEXT . with ( |cx| cx. borrow ( ) . map_current_cx ( f ) )
126
126
}
127
127
128
128
/// Returns a clone of the current thread's context with the given value.
@@ -298,12 +298,10 @@ impl Context {
298
298
/// assert_eq!(Context::current().get::<ValueA>(), None);
299
299
/// ```
300
300
pub fn attach ( self ) -> ContextGuard {
301
- let previous_cx = CURRENT_CONTEXT
302
- . try_with ( |current| current. replace ( self ) )
303
- . ok ( ) ;
301
+ let cx_id = CURRENT_CONTEXT . with ( |cx| cx. borrow_mut ( ) . push ( self ) ) ;
304
302
305
303
ContextGuard {
306
- previous_cx ,
304
+ cx_id ,
307
305
_marker : PhantomData ,
308
306
}
309
307
}
@@ -336,15 +334,16 @@ impl fmt::Debug for Context {
336
334
/// A guard that resets the current context to the prior context when dropped.
337
335
#[ allow( missing_debug_implementations) ]
338
336
pub struct ContextGuard {
339
- previous_cx : Option < Context > ,
337
+ cx_id : usize ,
340
338
// ensure this type is !Send as it relies on thread locals
341
339
_marker : PhantomData < * const ( ) > ,
342
340
}
343
341
344
342
impl Drop for ContextGuard {
345
343
fn drop ( & mut self ) {
346
- if let Some ( previous_cx) = self . previous_cx . take ( ) {
347
- let _ = CURRENT_CONTEXT . try_with ( |current| current. replace ( previous_cx) ) ;
344
+ let id = self . cx_id ;
345
+ if id > 0 {
346
+ CURRENT_CONTEXT . with ( |context_stack| context_stack. borrow_mut ( ) . pop_id ( id) ) ;
348
347
}
349
348
}
350
349
}
@@ -371,6 +370,75 @@ impl Hasher for IdHasher {
371
370
}
372
371
}
373
372
373
+ struct ContextStack {
374
+ current_cx : Context ,
375
+ current_id : usize ,
376
+ // TODO:ban wrap the whole id thing in its own type
377
+ id_count : usize ,
378
+ // TODO:ban wrap the the tuple in its own type
379
+ stack : Vec < Option < ( usize , Context ) > > ,
380
+ }
381
+
382
+ impl ContextStack {
383
+ #[ inline( always) ]
384
+ fn push ( & mut self , cx : Context ) -> usize {
385
+ self . id_count += 512 ; // TODO:ban clean up this
386
+ let next_id = self . stack . len ( ) + 1 + self . id_count ;
387
+ let current_cx = std:: mem:: replace ( & mut self . current_cx , cx) ;
388
+ self . stack . push ( Some ( ( self . current_id , current_cx) ) ) ;
389
+ self . current_id = next_id;
390
+ next_id
391
+ }
392
+
393
+ #[ inline( always) ]
394
+ fn pop_id ( & mut self , id : usize ) {
395
+ if id == 0 {
396
+ return ;
397
+ }
398
+ // Are we at the top of the stack?
399
+ if id == self . current_id {
400
+ // Shrink the stack if possible
401
+ while let Some ( None ) = self . stack . last ( ) {
402
+ self . stack . pop ( ) ;
403
+ }
404
+ // There is always the initial context at the bottom of the stack
405
+ if let Some ( Some ( ( next_id, next_cx) ) ) = self . stack . pop ( ) {
406
+ self . current_cx = next_cx;
407
+ self . current_id = next_id;
408
+ }
409
+ } else {
410
+ let pos = id & 511 ; // TODO:ban clean up this
411
+ if pos >= self . stack . len ( ) {
412
+ // This is an invalid id, ignore it
413
+ return ;
414
+ }
415
+ if let Some ( ( pos_id, _) ) = self . stack [ pos] {
416
+ // Is the correct id at this position?
417
+ if pos_id == id {
418
+ // Clear out this entry
419
+ self . stack [ pos] = None ;
420
+ }
421
+ }
422
+ }
423
+ }
424
+
425
+ #[ inline( always) ]
426
+ fn map_current_cx < T > ( & self , f : impl FnOnce ( & Context ) -> T ) -> T {
427
+ f ( & self . current_cx )
428
+ }
429
+ }
430
+
431
+ impl Default for ContextStack {
432
+ fn default ( ) -> Self {
433
+ ContextStack {
434
+ current_id : 0 ,
435
+ current_cx : Context :: default ( ) ,
436
+ id_count : 0 ,
437
+ stack : Vec :: with_capacity ( 64 ) ,
438
+ }
439
+ }
440
+ }
441
+
374
442
#[ cfg( test) ]
375
443
mod tests {
376
444
use super :: * ;
@@ -415,7 +483,6 @@ mod tests {
415
483
}
416
484
417
485
#[ test]
418
- #[ ignore = "overlapping contexts are not supported yet" ]
419
486
fn overlapping_contexts ( ) {
420
487
#[ derive( Debug , PartialEq ) ]
421
488
struct ValueA ( & ' static str ) ;
0 commit comments