diff --git a/src/lobster/tokenization/_latent_generator_3d_coord_tokenizer.py b/src/lobster/tokenization/_latent_generator_3d_coord_tokenizer.py
index 8bca74e..def466e 100644
--- a/src/lobster/tokenization/_latent_generator_3d_coord_tokenizer.py
+++ b/src/lobster/tokenization/_latent_generator_3d_coord_tokenizer.py
@@ -32,10 +32,9 @@ def _make_latent_generator_3d_coord_tokenizer() -> PreTrainedTokenizerFast:
     # WordLevel tokenizer
     tokenizer_model = WordLevel(LG_VOCAB, unk_token="<unk>")
 
-    #pretokenizers
+    # pretokenizers
     pre_tokenizer = pre_tokenizers.Sequence([WhitespaceSplit()])
 
-
     # bert style post processing
     post_processor = TemplateProcessing(
         single="<cls> $A <eos>",
@@ -46,7 +45,7 @@ def _make_latent_generator_3d_coord_tokenizer() -> PreTrainedTokenizerFast:
     return make_pretrained_tokenizer_fast(
         tokenizer_model=tokenizer_model,
         post_processor=post_processor,
-        pre_tokenizer = pre_tokenizer,
+        pre_tokenizer=pre_tokenizer,
        eos_token="<eos>",
        unk_token="<unk>",
        pad_token="<pad>",
@@ -71,6 +70,3 @@ def __init__(self):
             cls_token="<cls>",
             mask_token="<mask>",
         )
-if __name__ == "__main__":
-    tokenizer = _make_latent_generator_3d_coord_tokenizer()
-    tokenizer.save_pretrained("/Users/lisanzas/Research/Develop/lobster/src/lobster/assets/latent_generator_tokenizer")
diff --git a/tests/lobster/tokenization/test__latent_generator_tokenizer.py b/tests/lobster/tokenization/test__latent_generator_tokenizer.py
index 4deacad..142f3f1 100644
--- a/tests/lobster/tokenization/test__latent_generator_tokenizer.py
+++ b/tests/lobster/tokenization/test__latent_generator_tokenizer.py
@@ -1,7 +1,9 @@
+from lobster.tokenization._latent_generator_3d_coord_tokenizer import (
+    LatentGenerator3DCoordTokenizerFast,
+    _make_latent_generator_3d_coord_tokenizer,
+)
 from transformers import PreTrainedTokenizerFast
 
-from lobster.tokenization._latent_generator_3d_coord_tokenizer import LatentGenerator3DCoordTokenizerFast, _make_latent_generator_3d_coord_tokenizer
-
 
 def test__make_latent_generator_3d_coord_tokenizer():
     tokenizer = _make_latent_generator_3d_coord_tokenizer()
@@ -32,7 +34,10 @@ def test__make_latent_generator_3d_coord_tokenizer():
 
     tokenized_output = tokenizer("GD FH DS FH AD GF FE CZ EK DS CQ")
     assert tokenized_output.input_ids == [0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2]
-    assert tokenizer.decode(tokenized_output.input_ids) == "<cls> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <eos>"
+    assert (
+        tokenizer.decode(tokenized_output.input_ids)
+        == "<cls> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <eos>"
+    )
 
     tokenized_output = tokenizer("R A gd fh ds")
     assert tokenized_output.input_ids == [0, 3, 3, 191, 169, 128, 2]
@@ -59,7 +64,10 @@ def test__init__(self):
 
         tokenized_output = tokenizer("GD FH DS FH AD GF FE CZ EK DS CQ")
         assert tokenized_output.input_ids == [0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2]
-        assert tokenizer.decode(tokenized_output.input_ids) == "<cls> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <eos>"
+        assert (
+            tokenizer.decode(tokenized_output.input_ids)
+            == "<cls> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <eos>"
+        )
 
         tokenized_output = tokenizer("R A gd fh ds")
         assert tokenized_output.input_ids == [0, 3, 3, 191, 169, 128, 2]
@@ -73,6 +81,7 @@ def test__init__(self):
             "mask_token": "<mask>",
         }
 
+
 if __name__ == "__main__":
     test__make_latent_generator_3d_coord_tokenizer()
-    TestLatentGenerator3DCoordTokenizerFast().test__init__()
\ No newline at end of file
+    TestLatentGenerator3DCoordTokenizerFast().test__init__()