Jetstream by default (#118)
* test(slots): add unit tests for slots for jetstream too

Implementation is slightly different, so a separate test is added.

* test(truncate): adapt test for jetstream too

* refactor(test): make tinyllama test work for Jetstream and Torch/XLA

Most tests work for both backends, except for the continuous batching one.
This allows removing the old GPT2-based tests, which are quite slow and
do not use any sharding or KV cache, so they are not really
representative of the most relevant models on TGI.

* test(gpt2): remove old test

Equivalent tests now exist on the TinyLlama model; they run faster and
use the KV cache and sharding.
The only test without an equivalent is the continuous batching one, but
it was not working for most other models either, so it is removed as
well: having it pass was not representative of the current state.

* feat(tgi): Jetstream/Pytorch is now the default engine

Now that the engine is stable and tested, it is set as the default one
for TGI (see the sketch below).
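
A minimal sketch of how the new default can be overridden (only the
`jetstream_pt_available` helper and the `JETSTREAM_PT_DISABLE` variable from
this commit are assumed; everything else is illustrative):

```python
# Sketch: Jetstream PyTorch is now picked automatically when available.
# Setting JETSTREAM_PT_DISABLE=1 *before* any TPU-related import opts out
# and falls back to the Torch/XLA backend instead.
import os

os.environ["JETSTREAM_PT_DISABLE"] = "1"  # opt out of the new default

from optimum.tpu import jetstream_pt_available

# With the variable set above, the helper reports that Jetstream is disabled,
# so the Torch/XLA code path is used.
assert jetstream_pt_available() is False
```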

* review(test): refactor slot test to avoid repeating code

* feat(tests): use pytest markers to filter jetstream and torch xla tests

So far, filtering was done using the name of the test. Now the selection
is done using custom markers, which allow for clearer filtering (see the
sketch below).
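
A minimal illustration of the marker-based selection, based on the
`pytest.ini` and `pytestmark` changes in this commit (the test body is a
placeholder):

```python
# pytest.ini (added in this commit) declares the two markers:
#
#   [pytest]
#   markers =
#       jetstream: mark a test as a test that uses jetstream backend
#       torch_xla: mark a test as a test that uses torch_xla backend
#
# A test module then tags all of its tests with a single line:
import pytest

pytestmark = pytest.mark.jetstream  # every test in this module needs Jetstream


def test_placeholder():
    # Placeholder body; the real tests call helpers such as decode_single_test.
    assert True
```

Selection then becomes `pytest -m jetstream` (or `-m torch_xla`) instead of
name-based `-k "jetstream and ..."` expressions.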

* review(tests): clarify the test skip messages

* ci(torch xla): use JETSTREAM_PT_DISABLE env var in command line

For some reason the env var was not carried over (though Jetstream was
disabled anyway). Moving the variable to the command-line invocation
removes a warning in the logs.

* review(ci): fix JETSTREAM_PT_DISABLE env var usage again

* fix(tests): remove expected results from tests with do_sample

Some test results change when operations are done in a slightly
different way; this has now happened with the torch xla tests, producing
different results on the CI.
To avoid this, the tests now only check that the obtained token and text
differ from the ones produced by greedy search (see the sketch below).
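
A hedged sketch of the new assertion style (the helper and dataclass below are
illustrative stand-ins, not the actual code in the test utilities):

```python
# Instead of comparing sampled output against a hard-coded expected string,
# the test only requires it to differ from the greedy output, which is robust
# to small numerical differences between backends and CI runs.
from dataclasses import dataclass


@dataclass
class DecodeResult:
    token_id: int
    text: str


def decode(do_sample: bool) -> DecodeResult:
    # Illustrative stand-in for the real decoding helper.
    return DecodeResult(
        token_id=42 if do_sample else 7,
        text="sampled output" if do_sample else "greedy output",
    )


def test_sample_differs_from_greedy():
    greedy = decode(do_sample=False)
    sampled = decode(do_sample=True)
    assert sampled.token_id != greedy.token_id or sampled.text != greedy.text
```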
tengomucho authored Nov 27, 2024
1 parent e7474e0 commit 8c2c199
Showing 18 changed files with 181 additions and 250 deletions.
14 changes: 7 additions & 7 deletions .github/workflows/test-pytorch-xla-tpu-tgi-nightly-jetstream.yml
@@ -34,28 +34,28 @@ jobs:
- name: Run TGI Jetstream Pytorch - Llama
run: |
python -m \
pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -k "slow and Llama"
pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -m jetstream -k "slow and Llama"
- name: Run TGI Jetstream Pytorch - Gemma
run: |
python -m \
pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -k "slow and gemma"
pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -m jetstream -k "slow and gemma"
- name: Run TGI Jetstream Pytorch - Mixtral greedy
run: |
python -m \
pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -k "slow and Mixtral and greedy"
pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -m jetstream -k "slow and Mixtral and greedy"
- name: Run TGI Jetstream Pytorch - Quantization Mixtral
run: |
python -m \
pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -k "slow and Mixtral"
pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -m jetstream -k "slow and Mixtral"
- name: Run TGI Jetstream Pytorch - Quantization Llama-3 8B
run: |
python -m \
pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -k "slow and Llama-3-8B"
pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -m jetstream -k "slow and Llama-3-8B"
- name: Run TGI Jetstream Pytorch - Quantization Llama 3 70B
run: |
python -m \
pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -k "slow and Llama-3-70B"
pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -m jetstream -k "slow and Llama-3-70B"
- name: Run TGI Jetstream Pytorch - Other tests
run: |
python -m \
pytest -sv text-generation-inference/tests --runslow -k "jetstream and not decode and not quant"
pytest -sv text-generation-inference/tests --runslow -m jetstream -k "not decode"
1 change: 1 addition & 0 deletions .github/workflows/test-pytorch-xla-tpu-tgi-nightly.yml
@@ -23,6 +23,7 @@ jobs:
PJRT_DEVICE: TPU
HF_TOKEN: ${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }}
HF_HUB_CACHE: /mnt/hf_cache/cache_huggingface
JETSTREAM_PT_DISABLE: 1 # Disable Jetstream PyTorch to avoid conflicts with PyTorch XLA
steps:
- name: Checkout
uses: actions/checkout@v4
1 change: 1 addition & 0 deletions .github/workflows/test-pytorch-xla-tpu-tgi.yml
@@ -25,6 +25,7 @@ jobs:
env:
PJRT_DEVICE: TPU
HF_HUB_CACHE: /mnt/hf_cache/cache_huggingface
JETSTREAM_PT_DISABLE: 1 # Disable Jetstream PyTorch to avoid conflicts with PyTorch XLA
steps:
- name: Checkout
uses: actions/checkout@v4
1 change: 1 addition & 0 deletions .github/workflows/test-pytorch-xla-tpu.yml
@@ -25,6 +25,7 @@ jobs:
env:
PJRT_DEVICE: TPU
HF_HUB_CACHE: /mnt/hf_cache/cache_huggingface
JETSTREAM_PT_DISABLE: 1 # Disable Jetstream PyTorch to avoid conflicts with PyTorch XLA
steps:
- name: Checkout
uses: actions/checkout@v4
4 changes: 2 additions & 2 deletions Makefile
@@ -95,12 +95,12 @@ jetstream_requirements: test_installs
tgi_test_jetstream: test_installs jetstream_requirements tgi_server
find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \
-exec python -m pip install --force-reinstall {} \;
JETSTREAM_PT=1 python -m pytest -sv text-generation-inference/tests -k jetstream
python -m pytest -sv text-generation-inference/tests -m jetstream

tgi_test: test_installs tgi_server
find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \
-exec python -m pip install --force-reinstall {} \;
python -m pytest -sv text-generation-inference/tests
python -m pytest -sv text-generation-inference/tests -m torch_xla

tgi_docker_test: tpu-tgi
python -m pip install -r text-generation-inference/integration-tests/requirements.txt
6 changes: 4 additions & 2 deletions docs/source/howto/serving.mdx
@@ -50,9 +50,11 @@ curl localhost/generate_stream \
-H 'Content-Type: application/json'
```

### Using Jetstream Pytorch as backend
### Jetstream Pytorch and Pytorch XLA backends

[Jetstream Pytorch](https://github.com/AI-Hypercomputer/jetstream-pytorch) is a highly optimized Pytorch engine for serving LLMs on Cloud TPU. This engine is selected by default if the dependency is available.
If for some reason you want to use the Pytorch/XLA backend instead, you can set the `JETSTREAM_PT_DISABLE=1` environment variable.

[Jetstream Pytorch](https://github.com/AI-Hypercomputer/jetstream-pytorch) is a highly optimized Pytorch engine for serving LLMs on Cloud TPU. It is possible to use this engine by setting the `JETSTREAM_PT=1` environment variable.

When using Jetstream Pytorch engine, it is possible to enable quantization to reduce the memory footprint and increase the throughput. To enable quantization, set the `QUANTIZATION=1` environment variable.

6 changes: 3 additions & 3 deletions optimum/tpu/jetstream_pt_support.py
@@ -8,9 +8,9 @@ def jetstream_pt_available() -> bool:
"""Check if the necessary imports to use jetstream_pt are available.
"""
try:
# For now Jetstream Pytorch is opt-in, it can be enabled with an ENV variable.
jetstream_pt_enabled = os.environ.get("JETSTREAM_PT", False) == "1"
if not jetstream_pt_enabled:
# Jetstream Pytorch is enabled by default, it can be disabled with an ENV variable.
jetstream_pt_disabled = os.environ.get("JETSTREAM_PT_DISABLE", False) == "1"
if jetstream_pt_disabled:
return False
# Torch XLA should not be imported before torch_xla2 to avoid conflicts.
if 'torch_xla2' not in sys.modules and 'torch_xla.core' in sys.modules:
@@ -171,7 +171,7 @@ def assign(self, batch_id: int, request: Request, generation_config: GenerationC
self._max_new_tokens = self._generation_config.max_new_tokens
# TODO: stop_sequences, ignore_eos_token

def reset(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, selector: TokenSelector):
def reset(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor = None, selector: TokenSelector = None):
"""Reset the slot for the next generation.
Args:
13 changes: 13 additions & 0 deletions text-generation-inference/tests/conftest.py
@@ -2,6 +2,8 @@

import pytest

from optimum.tpu import jetstream_pt_available


# See https://stackoverflow.com/a/61193490/217945 for run_slow
def pytest_addoption(parser):
@@ -33,3 +35,14 @@ def quantization_jetstream_int8():
# Clean up
os.environ.clear()
os.environ.update(old_environ)


def pytest_runtest_setup(item):
marker_names = [marker.name for marker in item.own_markers]
jetstream_pt_enabled = jetstream_pt_available()
# Skip tests that require torch xla but not jetstream
if "torch_xla" in marker_names and "jetstream" not in marker_names:
if jetstream_pt_enabled:
pytest.skip("Jetstream is enabled: xla test will be skipped")
elif "jetstream" in marker_names and not jetstream_pt_enabled:
pytest.skip("Test requires Jetstream PyTorch to be enabled")
4 changes: 4 additions & 0 deletions text-generation-inference/tests/pytest.ini
@@ -0,0 +1,4 @@
[pytest]
markers =
jetstream: mark a test as a test that uses jetstream backend
torch_xla: mark a test as a test that uses torch_xla backend
4 changes: 4 additions & 0 deletions text-generation-inference/tests/test_decode.py
@@ -3,6 +3,9 @@
from decode_tests_utils import DecodeTestParams, decode_single_test


# All tests in this file are for torch xla
pytestmark = pytest.mark.torch_xla

@pytest.mark.parametrize("params",
[
DecodeTestParams(
@@ -21,6 +24,7 @@
def test_decode_single(params):
decode_single_test(params)


@pytest.mark.slow
@pytest.mark.parametrize("params",
[
7 changes: 2 additions & 5 deletions text-generation-inference/tests/test_decode_jetstream.py
@@ -2,8 +2,9 @@
import pytest
from decode_tests_utils import DecodeTestParams, decode_single_test

from optimum.tpu.jetstream_pt_support import jetstream_pt_available

# All tests in this file are for jetstream
pytestmark = pytest.mark.jetstream

@pytest.mark.slow
@pytest.mark.parametrize("do_sample", [False, True], ids=["greedy", "sample"])
@@ -35,8 +36,6 @@
ids=["Llama-2-7b-hf", "Meta-Llama-3-8B", "gemma-7b", "Mixtral-8x7B"],
)
def test_decode_single_jetstream_pytorch_slow(params, do_sample):
if not jetstream_pt_available():
pytest.skip("Jetstream PyTorch is not available")
params.do_sample = do_sample
decode_single_test(params)

@@ -64,7 +63,5 @@ def test_decode_single_jetstream_pytorch_slow(params, do_sample):
ids=["TinyLLama-v0", "gemma-2b", "Mixtral-tiny"],
)
def test_decode_single_jetstream_pytorch(params, do_sample):
if not jetstream_pt_available():
pytest.skip("Jetstream PyTorch is not available")
params.do_sample = do_sample
decode_single_test(params)
@@ -2,8 +2,9 @@
import pytest
from decode_tests_utils import DecodeTestParams, decode_single_test

from optimum.tpu.jetstream_pt_support import jetstream_pt_available

# All tests in this file are for jetstream
pytestmark = pytest.mark.jetstream

@pytest.mark.parametrize("params",
[
@@ -22,8 +23,6 @@
ids=["gemma-2b", "TinyLLama-v0"],
)
def test_decode_jetstream_quantization(quantization_jetstream_int8, params):
if not jetstream_pt_available():
pytest.skip("Jetstream PyTorch is not available")
decode_single_test(params)


@@ -49,6 +48,4 @@ def test_decode_jetstream_quantization(quantization_jetstream_int8, params):
ids=["Mixtral-8x7B", "Meta-Llama-3-8B" ,"Meta-Llama-3-70B"],
)
def test_decode_jetstream_quantization_slow(quantization_jetstream_int8, params):
if not jetstream_pt_available():
pytest.skip("Jetstream PyTorch is not available")
decode_single_test(params)
53 changes: 33 additions & 20 deletions text-generation-inference/tests/test_generator_slot.py
@@ -1,21 +1,13 @@
import numpy as np
import pytest
import torch
from text_generation_server.pb.generate_pb2 import Request
from transformers import AutoTokenizer, GenerationConfig


TOKENIZERS = ["NousResearch/Llama-2-7b-hf", "openai-community/gpt2"]


@pytest.fixture(params=TOKENIZERS)
def tokenizer(request):
t = AutoTokenizer.from_pretrained(request.param)
t.padding_side = "left"
t.pad_token_id = t.eos_token_id
return t


@pytest.mark.parametrize(
# Defining this global variable will parametrize all tests in this file
pytestmark = pytest.mark.parametrize(
"input_text, generated_text",
[
[
@@ -29,26 +21,31 @@ def tokenizer(request):
],
ids=["spaces", "chinese-utf8", "emojis"],
)
def test_decode_streaming(tokenizer, input_text, generated_text):
from text_generation_server.generator import Slot
# Note: device used is cpu to make it faster
slot = Slot(0, tokenizer, "cpu")


@pytest.fixture(params=TOKENIZERS)
def tokenizer(request):
t = AutoTokenizer.from_pretrained(request.param)
t.padding_side = "left"
t.pad_token_id = t.eos_token_id
return t


def _test_decode_streaming(slot, return_tensors, tokenizer, input_text, generated_text):
request = Request(id=0, inputs=input_text)
slot.assign(0, request, GenerationConfig())
assert slot.cached_text == input_text

inputs = tokenizer(input_text, padding="max_length", max_length=len(input_text) + 1, return_tensors="pt")
inputs = tokenizer(input_text, padding="max_length", max_length=len(input_text) + 1, return_tensors=return_tensors)
input_ids = inputs["input_ids"][0]
attention_mask = inputs["attention_mask"][0]
generated_tokens = tokenizer(generated_text, add_special_tokens=False)["input_ids"]

# We need to regenerate the full text as the tokenizer might change it (extra spaces might be added)
all_input_ids = torch.cat([input_ids, torch.tensor(generated_tokens)])
all_input_ids = np.concatenate([input_ids, generated_tokens])
full_text = tokenizer.decode(all_input_ids, skip_special_tokens=True)
regenerated_text = full_text[len(input_text) :]

# Initialize the slot with the inputs
slot.reset(input_ids, attention_mask, selector=None)
slot.reset(input_ids, selector=None)

assert slot.generated_tokens == 0

@@ -60,3 +57,19 @@ def test_decode_streaming(tokenizer, input_text, generated_text):
decoded_text += text

assert decoded_text == regenerated_text


@pytest.mark.jetstream
def test_decode_streaming_jetstream(tokenizer, input_text, generated_text):
from text_generation_server.jetstream_pt_support.generator import Slot

slot = Slot(0, tokenizer)
_test_decode_streaming(slot, "np", tokenizer, input_text, generated_text)

@pytest.mark.torch_xla
def test_decode_streaming(tokenizer, input_text, generated_text):
from text_generation_server.generator import Slot

# Note: device used is cpu to make it faster
slot = Slot(0, tokenizer, "cpu")
_test_decode_streaming(slot, "pt", tokenizer, input_text, generated_text)