@@ -507,6 +507,43 @@ class DummyDecoderTextInputGenerator(DummyTextInputGenerator):
507
507
)
508
508
509
509
510
class DummyDecisionTransformerInputGenerator(DummyTextInputGenerator):
    """
    Generates dummy decision transformer inputs.

    The per-step feature sizes (`state_dim`, `act_dim`) and the upper bound used for
    timestep indices (`max_ep_len`) are read from `self.normalized_config.config`.
    `self.batch_size` and `self.sequence_length` are not set in this class; they are
    expected to be provided by the parent initializer.
    """

    SUPPORTED_INPUT_NAMES = (
        "states",
        "actions",
        "timesteps",
        "returns_to_go",
        "attention_mask",
    )

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent generator, then cache the model dimensions."""
        super().__init__(*args, **kwargs)
        config = self.normalized_config.config
        self.act_dim = config.act_dim
        self.state_dim = config.state_dim
        self.max_ep_len = config.max_ep_len

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        """
        Generate a dummy tensor for `input_name`.

        Args:
            input_name: Name of the input to generate. Besides the names listed in
                `SUPPORTED_INPUT_NAMES`, `"rewards"` is also handled here.
            framework: Framework identifier forwarded to the random tensor helpers.
            int_dtype: Integer dtype name, used for `"timesteps"`.
            float_dtype: Float dtype name, used for every other input.

        Returns:
            The tensor produced by `random_int_tensor` (for `"timesteps"`) or by
            `random_float_tensor` (for all other handled inputs).

        Raises:
            ValueError: If `input_name` is not one of the handled input names.
        """
        batch_and_seq = [self.batch_size, self.sequence_length]

        # Timesteps are the only integer-valued input; their values are bounded by `max_ep_len`.
        if input_name == "timesteps":
            return self.random_int_tensor(
                shape=batch_and_seq, max_value=self.max_ep_len, framework=framework, dtype=int_dtype
            )

        if input_name == "states":
            shape = batch_and_seq + [self.state_dim]
        elif input_name == "actions":
            shape = batch_and_seq + [self.act_dim]
        elif input_name in ("rewards", "returns_to_go"):
            # One scalar per step.
            shape = batch_and_seq + [1]
        elif input_name == "attention_mask":
            # NOTE(review): the mask is generated as floats in [-2, 2] like the other inputs,
            # not as a 0/1 mask -- kept as-is; confirm this is what the export expects.
            shape = batch_and_seq
        else:
            # Previously an unknown name fell through with `shape` unbound and failed with an
            # opaque UnboundLocalError; fail explicitly instead.
            raise ValueError(
                f"Unsupported input name {input_name!r} for {type(self).__name__}; "
                f"expected one of {self.SUPPORTED_INPUT_NAMES + ('rewards',)}."
            )

        return self.random_float_tensor(shape, min_value=-2.0, max_value=2.0, framework=framework, dtype=float_dtype)
545
+
546
+
510
547
class DummySeq2SeqDecoderTextInputGenerator (DummyDecoderTextInputGenerator ):
511
548
SUPPORTED_INPUT_NAMES = (
512
549
"decoder_input_ids" ,
0 commit comments