Skip to content

Commit

Permalink
Merge branch 'structure_tokens' of https://github.com/prescient-design/lobster into structure_tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
Sidney Lisanza committed Mar 10, 2025
2 parents 28b8eed + 2d20c8f commit 06d87ce
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,9 @@ def _make_latent_generator_3d_coord_tokenizer() -> PreTrainedTokenizerFast:
# WordLevel tokenizer
tokenizer_model = WordLevel(LG_VOCAB, unk_token="<unk>")

#pretokenizers
# pretokenizers
pre_tokenizer = pre_tokenizers.Sequence([WhitespaceSplit()])


# bert style post processing
post_processor = TemplateProcessing(
single="<cls> $A <eos>",
Expand All @@ -46,7 +45,7 @@ def _make_latent_generator_3d_coord_tokenizer() -> PreTrainedTokenizerFast:
return make_pretrained_tokenizer_fast(
tokenizer_model=tokenizer_model,
post_processor=post_processor,
pre_tokenizer = pre_tokenizer,
pre_tokenizer=pre_tokenizer,
eos_token="<eos>",
unk_token="<unk>",
pad_token="<pad>",
Expand All @@ -71,6 +70,3 @@ def __init__(self):
cls_token="<cls>",
mask_token="<mask>",
)
if __name__ == "__main__":
tokenizer = _make_latent_generator_3d_coord_tokenizer()
tokenizer.save_pretrained("/Users/lisanzas/Research/Develop/lobster/src/lobster/assets/latent_generator_tokenizer")
19 changes: 14 additions & 5 deletions tests/lobster/tokenization/test__latent_generator_tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from lobster.tokenization._latent_generator_3d_coord_tokenizer import (
LatentGenerator3DCoordTokenizerFast,
_make_latent_generator_3d_coord_tokenizer,
)
from transformers import PreTrainedTokenizerFast

from lobster.tokenization._latent_generator_3d_coord_tokenizer import LatentGenerator3DCoordTokenizerFast, _make_latent_generator_3d_coord_tokenizer


def test__make_latent_generator_3d_coord_tokenizer():
tokenizer = _make_latent_generator_3d_coord_tokenizer()
Expand Down Expand Up @@ -32,7 +34,10 @@ def test__make_latent_generator_3d_coord_tokenizer():
tokenized_output = tokenizer("GD FH DS FH AD GF FE CZ EK DS CQ")

assert tokenized_output.input_ids == [0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2]
assert tokenizer.decode(tokenized_output.input_ids) == "<cls> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <eos>"
assert (
tokenizer.decode(tokenized_output.input_ids)
== "<cls> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <eos>"
)

tokenized_output = tokenizer("R A gd fh ds")
assert tokenized_output.input_ids == [0, 3, 3, 191, 169, 128, 2]
Expand All @@ -59,7 +64,10 @@ def test__init__(self):

tokenized_output = tokenizer("GD FH DS FH AD GF FE CZ EK DS CQ")
assert tokenized_output.input_ids == [0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2]
assert tokenizer.decode(tokenized_output.input_ids) == "<cls> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <eos>"
assert (
tokenizer.decode(tokenized_output.input_ids)
== "<cls> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <eos>"
)

tokenized_output = tokenizer("R A gd fh ds")
assert tokenized_output.input_ids == [0, 3, 3, 191, 169, 128, 2]
Expand All @@ -73,6 +81,7 @@ def test__init__(self):
"mask_token": "<mask>",
}


if __name__ == "__main__":
test__make_latent_generator_3d_coord_tokenizer()
TestLatentGenerator3DCoordTokenizerFast().test__init__()
TestLatentGenerator3DCoordTokenizerFast().test__init__()

0 comments on commit 06d87ce

Please sign in to comment.