Code for the paper "Efficient LLM Inference using Dynamic Input Pruning and Cache-Aware Masking".
While mobile devices provide ever more compute power, improvements in DRAM bandwidth are much slower. This is unfortunate for large language model (LLM) token generation, which is heavily memory-bound. Previous work has proposed to leverage natural dynamic activation sparsity in ReLU-activated LLMs to reduce effective DRAM bandwidth per token. However, more recent LLMs use SwiGLU instead of ReLU, which results in little inherent sparsity. While SwiGLU activations can be pruned based on magnitude, the resulting sparsity patterns are difficult to predict, rendering previous approaches ineffective. To circumvent this issue, our work introduces Dynamic Input Pruning (DIP): a predictor-free dynamic sparsification approach, which preserves accuracy with minimal fine-tuning. DIP can further use lightweight LoRA adapters to regain some of the performance lost during sparsification. Lastly, we describe a novel cache-aware masking strategy, which considers the cache state and activation magnitude to further increase the cache hit rate, improving the LLM token rate on mobile devices. DIP outperforms other methods in terms of accuracy, memory, and throughput trade-offs across simulated hardware settings. On Phi-3-Medium, DIP achieves a 46% reduction in memory and a 40% increase in throughput with < 0.1 loss in perplexity. This repository contains the code for the HW simulator, methods, and experiments described in the paper.
First, let's define the path where we want the repository to be downloaded:

```bash
REPO_PATH=<path/to/repo>
```

Now we can clone the repository:

```bash
git clone git@github.com:Qualcomm-AI-research/<...>.git $REPO_PATH
cd $REPO_PATH
```

Next, create and activate a virtual environment:

```bash
python3 -m venv env
source env/bin/activate
```

Make sure you are using Python 3.8 and upgrade pip to the latest version (tested with pip 24.3.1):

```bash
pip install --upgrade --no-deps pip
```

Finally, install the dependencies using pip:

```bash
pip install -r requirements.txt
```

Add the repository directory to the `PYTHONPATH`:

```bash
export PYTHONPATH=$PYTHONPATH:$REPO_PATH
```
- Download the required models and datasets from Hugging Face:

  - `phi-3-medium`: https://huggingface.co/microsoft/Phi-3-medium-4k-instruct
  - `phi-3-mini`: https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
  - `llama-v3-8B`: https://huggingface.co/meta-llama/Meta-Llama-3-8B
  - `mistral-v01-7B`: https://huggingface.co/mistralai/Mistral-7B-v0.1

  Download and store each model and tokenizer in the same folder using:

  ```python
  from transformers import AutoTokenizer, AutoModelForCausalLM

  tokenizer = AutoTokenizer.from_pretrained("<MODEL_ID>")
  model = AutoModelForCausalLM.from_pretrained("<MODEL_ID>")
  tokenizer.save_pretrained("<MODEL_PATH>")
  model.save_pretrained("<MODEL_PATH>")
  ```

  in which `<MODEL_ID>` refers to the Hugging Face identifier (e.g. `mistralai/Mistral-7B-v0.1`) and `<MODEL_PATH>` to a local directory.
- Similarly, datasets can be downloaded from Hugging Face:

  ```python
  from datasets import load_dataset

  dataset = load_dataset("wikitext", name="wikitext-2-raw-v1")
  dataset.save_to_disk("<WIKITEXT_PATH>")
  ```
- Edit the configuration file by replacing all the `???` entries in the `hardware` section with the paths to the models and datasets described above. Modify the other paths and device specifications to customize the experiment setup. Alternatively, each path can be specified directly from the CLI, e.g. the path to `phi-3-mini` can be specified by adding `hardware.paths.models.phi-3-mini=<PATH_TO_PHI_3_MINI>` to each command.
- (Optional) Predictive models and LoRA adapters require a preprocessed version of the SlimPajama dataset (an example preprocessing sketch is given after this list):
  - First, pre-process the dataset with the model tokenizer to create sequences of length > 1024 (see https://huggingface.co/docs/transformers/v4.17.0/en/tasks/language_modeling).
  - Store the preprocessed dataset as `.arrow` fragments using the `save_to_disk()` method to a location `<SLIMPAJAMA_PATH>`.
- (Optional) Set the environment variable `TOKENIZERS_PARALLELISM=false` to suppress warnings:

  ```bash
  export TOKENIZERS_PARALLELISM=false
  ```
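A possible recipe for the optional SlimPajama preprocessing step, loosely following the linked language-modeling guide, is sketched below. The dataset identifier, streaming and filtering choices, and subset size are assumptions; adapt them to your setup and to the format expected by the training scripts.

```python
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer

# Illustrative SlimPajama preprocessing sketch (dataset id, split, and subset size are assumptions).
tokenizer = AutoTokenizer.from_pretrained("<MODEL_PATH>")

raw = load_dataset("cerebras/SlimPajama-627B", split="train", streaming=True)

def tokenize(batch):
    return tokenizer(batch["text"])

tokenized = raw.map(tokenize, batched=True, remove_columns=["text"])

# Keep only sequences longer than 1024 tokens, as required above.
long_seqs = tokenized.filter(lambda example: len(example["input_ids"]) > 1024)

# Materialize a subset and store it as .arrow fragments at <SLIMPAJAMA_PATH>.
subset = Dataset.from_list(list(long_seqs.take(100_000)))  # subset size is illustrative
subset.save_to_disk("<SLIMPAJAMA_PATH>")
```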
The experiments are managed with Hydra through the `scripts/run_experiment.py` script. Each experiment can be run by specifying the evaluation procedure, model, and desired sparsity. E.g., to run DIP on `phi-3-mini` with 0.5 MLP density on WikiText, use:
```bash
python scripts/run_experiment.py \
    dense_model=phi-3-mini \
    masking_hooks=dip \
    masking_hooks.keep=0.5 \
    evaluation=perplexity
```
The sparsification strategies (`masking_hooks`) include:

- `glu_pruning`: Pruning the MLP down layers based on the magnitude of the GLU activations.
- `gate_pruning`/`up_pruning`: Pruning the MLP down and up (or gate) layers based on the partial GLU activations.
- `predictor`: Pruning the MLP up, down, and gate layers based on the sparsity prediction of a smaller network (as in DejaVu).
- `cats`: Contextually-Aware Thresholding for Sparsity in Large Language Models, a per-layer threshold version of `gate_pruning`.
- `dip`: Our Dynamic Input Pruning implementation (a conceptual sketch follows this list).
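The magnitude-based strategies above share one core idea: compute the GLU intermediate activation, keep only its largest entries, and skip the weights of the projections that the masked channels would have touched. The snippet below is a conceptual sketch of that idea only; it is not the repository's DIP implementation, which additionally handles caching, LoRA adapters, and throughput simulation, and the function and variable names are hypothetical.

```python
import torch
import torch.nn.functional as F

def topk_glu_mlp(x, w_gate, w_up, w_down, keep=0.5):
    """Conceptual sketch of magnitude-based pruning in a SwiGLU MLP block.

    Keeps only the `keep` fraction of intermediate channels with the largest
    |silu(x W_gate^T) * (x W_up^T)| and zeroes the rest, so the corresponding
    input channels of the down projection never need to be loaded.
    Not the repository's DIP implementation.
    """
    h = F.silu(x @ w_gate.T) * (x @ w_up.T)      # GLU intermediate activation
    k = max(1, int(keep * h.shape[-1]))
    idx = h.abs().topk(k, dim=-1).indices        # per-token top-k channels
    mask = torch.zeros_like(h).scatter_(-1, idx, 1.0)
    return (h * mask) @ w_down.T                 # masked channels contribute nothing

# Toy usage with random weights.
d_model, d_ff = 8, 32
x = torch.randn(2, d_model)
w_gate, w_up = torch.randn(d_ff, d_model), torch.randn(d_ff, d_model)
w_down = torch.randn(d_model, d_ff)
print(topk_glu_mlp(x, w_gate, w_up, w_down, keep=0.25).shape)  # torch.Size([2, 8])
```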
The `evaluation` options include:

- `perplexity`: WikiText perplexity evaluation (a generic sketch of this kind of evaluation follows this list).
- `mmlu`: 5-shot MMLU evaluation.
- `ppl_mmlu`: A sequential combination of both WikiText perplexity and MMLU evaluation.
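For reference, a generic, self-contained WikiText-2 perplexity loop of the kind the `perplexity` option implements is sketched below. It is an illustration using standard Hugging Face APIs, not the repository's evaluator; the chunk length and dtype are assumptions.

```python
import math
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Generic chunked WikiText-2 perplexity evaluation (illustrative only).
model_path = "<MODEL_PATH>"  # local folder created during setup
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16).cuda().eval()

test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt")

chunk_len = 2048  # assumed context length
nll_sum, n_tokens = 0.0, 0
for begin in range(0, encodings.input_ids.size(1), chunk_len):
    input_ids = encodings.input_ids[:, begin:begin + chunk_len].cuda()
    if input_ids.size(1) < 2:
        break
    with torch.no_grad():
        # labels == input_ids: the model shifts internally and returns the mean NLL
        loss = model(input_ids, labels=input_ids).loss
    n = input_ids.size(1) - 1  # number of predicted tokens in this chunk
    nll_sum += loss.item() * n
    n_tokens += n

print("perplexity:", math.exp(nll_sum / n_tokens))
```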
To estimate the overall system throughput, a user can enable the hardware simulator with `+hw_simulator=default`. Its configuration can be changed with the following options (an example command combining them follows this list):

- `hardware.dram.capacity`: DRAM capacity in GB.
- `hardware.io_speed.flash`/`hardware.io_speed.dram`: Read and write speed in GB/s.
- `hardware.processor`: Directly loads the specifications for a given Apple processor (e.g. `A18_Pro_APL1V07`), as reported on Wikipedia.
- `hardware.cache_strategy`: Cache eviction strategy to be used in DRAM (e.g. `no_cache`, `lfu`, `lru`, `belady`).
- `cache_hooks`: Cache-aware sparsity method (optional, e.g. `weighting_current_cache` for DIP-CA).
- `precision`: Defines the precision in bits per value for each layer type. This allows simulating throughput with arbitrarily quantized models.
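As an example, the basic DIP run from above can be combined with the simulator and a few of these overrides. The command below is purely illustrative: it reuses option names and example values exactly as documented in this README, and they should be double-checked against the configuration files under `scripts/config` before use.

```bash
python scripts/run_experiment.py \
    dense_model=phi-3-mini \
    masking_hooks=dip \
    masking_hooks.keep=0.5 \
    evaluation=perplexity \
    +hw_simulator=default \
    hardware.processor=A18_Pro_APL1V07 \
    hardware.cache_strategy=lru \
    precision.attention=4.125 precision.mlp=4.125
```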
To obtain a full list of the experiment parameters and available configurations, use:

```bash
python scripts/run_experiment.py --help
```

Refer to the Hydra documentation for more details on Hydra usage.
- `contextual_sparsity`: The `contextual_sparsity` package contains all the code and logic, specifically:
  - `adapters`: Code to define and train LoRA adapters.
  - `data`: Definition of data-loading functions and utilities.
  - `dense_models`: Definition of dense-model loading functions and utilities.
  - `evaluation`: Definition of the evaluation functions.
  - `hw_simulator`: Definition of the hardware simulator used to estimate throughput and memory usage.
  - `mask`: Definition of the logic used to mask activations during the forward pass.
  - `masking_hooks`: Mask instantiation and definition of the various masking logics, including DIP, CATS, and DejaVu (a generic sketch of this hooking pattern follows this overview).
  - `nn`: Definition of (simulated) sparse `nn.Linear` layers.
  - `scripts`: Scripts used to instantiate the components and perform various tasks.
  - `utils`: Utility functions.
- `scripts`: Entry point to run experiments and definition of all the Hydra configurations.
  - `config`: `.yaml` configuration files, which define all the components used in the experiments.
  - `run_experiment.py`: Python script used to run all experiments and tasks.
- `tests`: Test scripts.
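The `masking_hooks` components are applied to the dense model at run time; the exact classes live in `contextual_sparsity/masking_hooks`, but the general PyTorch pattern they rely on can be sketched as a forward pre-hook that masks a layer's input. The snippet below is a generic illustration of that mechanism, not the repository's code, and the hook logic and layer it attaches to are assumptions.

```python
import torch
import torch.nn as nn

def make_topk_prehook(keep: float):
    """Forward pre-hook that zeroes all but the top-`keep` fraction of input
    channels (by magnitude) before the wrapped linear layer runs. Illustrative only."""
    def hook(module, inputs):
        (x,) = inputs
        k = max(1, int(keep * x.shape[-1]))
        idx = x.abs().topk(k, dim=-1).indices
        mask = torch.zeros_like(x).scatter_(-1, idx, 1.0)
        return (x * mask,)  # returning a tuple replaces the layer's inputs
    return hook

# Toy usage: sparsify the input of a hypothetical down projection.
down_proj = nn.Linear(32, 8, bias=False)
handle = down_proj.register_forward_pre_hook(make_topk_prehook(keep=0.25))
y = down_proj(torch.randn(2, 32))
handle.remove()  # detach the hook to restore dense behaviour
```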
Hereafter we report the commands used to reproduce the results in Table 1 of the paper for the `phi-3-medium` model. The commands for the other models are obtained by replacing `dense_model=phi-3-medium` with the respective model name. Running the experiments for most models requires a GPU with at least 80GB of VRAM. Experiments on `phi-3-mini` that do not require training can run on a single GPU with 40GB of VRAM.
```bash
# Gate pruning
python scripts/run_experiment.py dense_model=phi-3-medium masking_hooks=gate_pruning masking_hooks.keep=0.25 evaluation=ppl_mmlu

# Up pruning
python scripts/run_experiment.py dense_model=phi-3-medium masking_hooks=up_pruning masking_hooks.keep=0.25 evaluation=ppl_mmlu

# CATS
python scripts/run_experiment.py dense_model=phi-3-medium masking_hooks=cats masking_hooks.keep=0.25 evaluation=ppl_mmlu

# CATS+LoRA (requires SlimPajama setup)
python scripts/run_experiment.py dense_model=phi-3-medium masking_hooks=cats masking_hooks.keep=0.25 evaluation=ppl_mmlu +adapter=lora

# DejaVu (requires SlimPajama setup)
python scripts/run_experiment.py dense_model=phi-3-medium masking_hooks=predictor masking_hooks.keep=0.5 evaluation=ppl_mmlu

# DIP
python scripts/run_experiment.py dense_model=phi-3-medium masking_hooks=dip masking_hooks.keep=0.5 evaluation=ppl_mmlu

# DIP+LoRA (requires SlimPajama setup)
python scripts/run_experiment.py dense_model=phi-3-medium masking_hooks=dip masking_hooks.keep=0.5 evaluation=ppl_mmlu +adapters=lora
```
Results reported in Table 2 can be reproduced with the following commands. All models in this study must first be quantized to INT4 using Blockwise Quantization [1] and their precision set accordingly in the simulator: `precision.attention=4.125 precision.mlp=4.125` (the extra 0.125 bits per value accounts for quantization metadata such as per-block scales). For models other than Phi-3-Medium, different DRAM sizes should be set, as specified in Table 2, by modifying `hw_simulator.dram.capacity`. For each experiment, we run a sweep over the sparsity parameter `masking_hooks.keep` at 0.1 intervals between 0 and 1, and then report the throughput at the selected perplexity operating point; a shell sketch of such a sweep is shown after the commands below.
```bash
# GLU pruning
python scripts/run_experiment.py dense_model=phi-3-medium masking_hooks=glu_pruning masking_hooks.keep=0.5 +hw_simulator=default hw_simulator.simulate_glu_pruning=True precision.attention=4.125 precision.mlp=4.125 hw_simulator.dram.capacity=4e9

# Up pruning
python scripts/run_experiment.py dense_model=phi-3-medium masking_hooks=up_pruning masking_hooks.keep=0.5 +hw_simulator=default precision.attention=4.125 precision.mlp=4.125 hw_simulator.dram.capacity=4e9

# CATS
python scripts/run_experiment.py dense_model=phi-3-medium masking_hooks=cats masking_hooks.keep=0.5 +hw_simulator=default precision.attention=4.125 precision.mlp=4.125 hw_simulator.dram.capacity=4e9

# DIP
python scripts/run_experiment.py dense_model=phi-3-medium masking_hooks=dip masking_hooks.keep=0.5 +hw_simulator=default precision.attention=4.125 precision.mlp=4.125 hw_simulator.dram.capacity=4e9

# DIP-CA
python scripts/run_experiment.py dense_model=phi-3-medium masking_hooks=dip masking_hooks.keep=0.5 +hw_simulator=default cache_hooks=weighting_current_cache cache_hooks.kwargs.gamma=0.2 precision.attention=4.125 precision.mlp=4.125 hw_simulator.dram.capacity=4e9
```
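The sweep over `masking_hooks.keep` can be scripted with a plain shell loop around the documented override. The grid and the evaluated metric below are illustrative; the operating point then has to be selected from the collected results.

```bash
for keep in 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9; do
    python scripts/run_experiment.py \
        dense_model=phi-3-medium \
        masking_hooks=dip \
        masking_hooks.keep=$keep \
        evaluation=perplexity \
        +hw_simulator=default \
        precision.attention=4.125 precision.mlp=4.125 \
        hw_simulator.dram.capacity=4e9
done
```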
Results for SparseGPT [2] are obtained by following the instructions in the SparseGPT repository.
[1] Frantar, E., Ashkboos, S., Hoefler, T., & Alistarh, D. (2022). GPTQ: Accurate post-training quantization for generative pre-trained transformers. arXiv preprint arXiv:2210.17323.
[2] Frantar, E., & Alistarh, D. (2023). SparseGPT: Massive language models can be accurately pruned in one-shot. In International Conference on Machine Learning (pp. 10323-10337). PMLR.
```bibtex
@article{2024dip,
  title={Efficient LLM Inference using Dynamic Input Pruning and Cache-Aware Masking},
  author={Marco Federici and Davide Belli and Mart van Baalen and Amir Jalalirad and Andrii Skliar and Bence Major and Markus Nagel and Paul Whatmough},
  journal={arXiv preprint arXiv:2412.01380},
  year={2024},
  url={https://arxiv.org/abs/2412.01380}
}
```