Keras 3, Kaggle, CLI, Streaming #295

Draft · wants to merge 119 commits into base: main
Changes from 1 commit

Commits (119)
232a80c
feat(streaming): training, inference, gradient accumulation
nglehuy Jun 15, 2024
9ea9b87
fix: remove ds2 dropout on conv module
nglehuy Jun 18, 2024
7c21c1d
fix: add sync batch norm, remove wrong bn in ds2
nglehuy Jun 22, 2024
4e0e8f5
fix: only wrap tf.function in jit compile
nglehuy Jun 24, 2024
de38407
fix: use autograph do_not_convert for batchnorm sync to work
nglehuy Jun 24, 2024
43d6054
chore: config
nglehuy Jul 2, 2024
d55fd40
fix: update train/test step
nglehuy Jul 2, 2024
2f502bb
fix: nan to num
nglehuy Jul 2, 2024
7441c95
fix: update compute mask ds2
nglehuy Jul 2, 2024
f2241cf
fix: nan to num
nglehuy Jul 2, 2024
ebb6930
fix: ctc loss
nglehuy Jul 2, 2024
305ddab
fix: update train step
nglehuy Jul 3, 2024
8268afe
chore: config
nglehuy Jul 3, 2024
c9e4d38
chore: config
nglehuy Jul 3, 2024
9b46cbe
fix: ctc
nglehuy Jul 3, 2024
b5bfe92
fix: add custom batch norm to avoid tf.cond
nglehuy Jul 4, 2024
dc08e6a
fix: env utils
nglehuy Jul 4, 2024
cf37206
fix: log batch that cause invalid loss
nglehuy Jul 4, 2024
4596609
chore: buffer size
nglehuy Jul 4, 2024
ba9d6b2
fix: handle unknown dataset size with no metadata provided
nglehuy Jul 6, 2024
fe594ad
chore: add option use loss scale
nglehuy Jul 6, 2024
7aed458
fix: support log debug
nglehuy Jul 6, 2024
4330dec
fix: update gradient accumulation
nglehuy Jul 10, 2024
522f080
fix: ga
nglehuy Jul 14, 2024
90cabc2
feat: tf2.16 with keras 3
nglehuy Jul 14, 2024
8285f6d
fix: ga
nglehuy Jul 14, 2024
fd446d6
Merge branch 'tf2.16' into feat-streaming
nglehuy Jul 15, 2024
5037896
feat: fix layers, models to tf2.16 with keras 3
nglehuy Jul 15, 2024
a853eba
feat: update models to compatible with keras 3
nglehuy Jul 21, 2024
f8a7b91
fix: loss compute using add_loss, loss tracking
nglehuy Jul 28, 2024
03f0d60
fix: output shapes of models to log to summary
nglehuy Jul 28, 2024
8b0ed02
fix: contextnet
nglehuy Jul 28, 2024
77baaa5
fix: ds2
nglehuy Jul 28, 2024
3768878
fix: jasper
nglehuy Jul 28, 2024
f3cb239
fix: rnnt
nglehuy Jul 28, 2024
1d7e3a6
fix: transformer
nglehuy Jul 28, 2024
7be3eda
fix: update deps
nglehuy Jul 29, 2024
9b03b31
fix: requirements
nglehuy Aug 25, 2024
401180b
fix: super init
nglehuy Aug 25, 2024
786f5d4
fix: update regularizers
nglehuy Aug 25, 2024
a9d1733
fix: update regularizers
nglehuy Aug 25, 2024
689b366
fix: print shapes
nglehuy Nov 24, 2024
4e75c0f
fix: conformer ctc
nglehuy Nov 24, 2024
82d91c8
fix: add ctc tpu impl
nglehuy Nov 25, 2024
c915be3
fix: ctc tpu impl
nglehuy Nov 25, 2024
dda33b7
fix: save weights, tpu connect
nglehuy Nov 28, 2024
c667984
fix: save weights, tpu connect
nglehuy Nov 28, 2024
dc77b84
fix: update req
nglehuy Dec 3, 2024
d455ae1
fix: update req
nglehuy Dec 3, 2024
33394a2
fix: update req
nglehuy Dec 3, 2024
35160ce
fix: update req
nglehuy Dec 3, 2024
6ffb3b8
fix: update req
nglehuy Dec 3, 2024
bb732a7
fix: update req
nglehuy Dec 4, 2024
9179425
fix: strategy scope
nglehuy Dec 4, 2024
ace3887
fix: requirements
nglehuy Dec 4, 2024
67a8470
fix: update savings
nglehuy Dec 7, 2024
9824819
feat: bundle scripts inside package
nglehuy Dec 29, 2024
05b068b
feat: introduce chunk-wise masking for mha layer
nglehuy Dec 31, 2024
1edf16a
feat: introduce chunk-wise masking to conformer & transformer
nglehuy Dec 31, 2024
6338f55
chore: update install script
nglehuy Jan 1, 2025
e844f77
chore: add conformer small streaming
nglehuy Jan 1, 2025
0543c31
chore: add conformer small streaming
nglehuy Jan 1, 2025
a2e2022
fix: use history size instead of memory length
nglehuy Jan 1, 2025
4824929
chore: update logging
nglehuy Jan 1, 2025
eca664c
fix: streaming masking mha
nglehuy Jan 1, 2025
a302962
fix: conformer ctc configs
nglehuy Jan 7, 2025
ab83d87
feat: add kaggle backup and restore callback
nglehuy Jan 11, 2025
52de4c0
fix: support flash attention, update deps
nglehuy Jan 12, 2025
aaa06a5
chore: add conformer-ctc-small-streaming-kaggle
nglehuy Jan 12, 2025
2fd4f2b
fix: restore from kaggle model
nglehuy Jan 12, 2025
6e5c3b9
fix: restore from kaggle model
nglehuy Jan 12, 2025
f62fdf9
fix: ignore backup kaggle when nan loss occurs
nglehuy Jan 12, 2025
2947160
fix: only use tqdm when needed
nglehuy Jan 12, 2025
c532d5c
fix: deps
nglehuy Jan 23, 2025
0e5e826
fix: support static shape
nglehuy Jan 24, 2025
eb55a4b
fix: mha streaming mask
nglehuy Jan 24, 2025
12fbb85
fix: feature extraction mixed precision, configs
nglehuy Jan 25, 2025
2100d75
fix: expose relmha_causal, flash attention
nglehuy Jan 25, 2025
a5d1e84
fix: allow ctc to force use native tf impl
nglehuy Jan 25, 2025
5ff3163
chore: list devices
nglehuy Jan 25, 2025
3dbab33
fix: attention mask
nglehuy Feb 9, 2025
1545543
fix: general layers to show outputshape, invalid loss show outputs
nglehuy Feb 15, 2025
5f784b7
fix: models configs
nglehuy Feb 15, 2025
70ac41e
fix: config streaming
nglehuy Feb 20, 2025
eebc361
fix: configs
nglehuy Feb 23, 2025
6b0bec4
fix: configs
nglehuy Feb 23, 2025
2ac8e7f
Merge branch 'main' into feat-streaming
nglehuy Mar 9, 2025
7dcd145
fix: streaming masking mha
nglehuy Mar 9, 2025
e209040
fix: streaming masking mha
nglehuy Mar 9, 2025
5649fdd
fix: update mha attention mask
nglehuy Mar 13, 2025
26c4a5f
feat: add support for layer norm in conformer conv module
nglehuy Mar 13, 2025
4f77a52
chore: update configs
nglehuy Mar 13, 2025
c3ab865
fix: feature extraction layer dtype tf.float32 to ensure loss converg…
nglehuy Mar 17, 2025
778c1a2
fix: ctc loss tpu - case logits to float32
nglehuy Mar 17, 2025
e68ceee
fix: use auto mask
nglehuy Mar 19, 2025
f1e2a88
fix: pad logits length to label length
nglehuy Mar 20, 2025
f1a0ed6
fix: ctc loss tpu
nglehuy Mar 20, 2025
6f7f246
chore: config
nglehuy Mar 21, 2025
d538e69
fix: disable bias/activity regularizer as not needed
nglehuy Mar 21, 2025
7611ff8
chore: config
nglehuy Mar 23, 2025
0c8e7c1
chore: setup mxp
nglehuy Mar 24, 2025
56d2afa
chore: setup mxp
nglehuy Mar 24, 2025
454163c
fix: small kaggle
nglehuy Mar 25, 2025
a333bfc
chore: transformer-ctc streaming
nglehuy Mar 25, 2025
cf435a3
chore: config
nglehuy Mar 27, 2025
c35af45
fix: ctc-tpu clean label
nglehuy Mar 30, 2025
0556481
chore: configs
nglehuy Mar 30, 2025
3e88f65
chore: configs
nglehuy Mar 30, 2025
111f3ac
chore: configs
nglehuy Mar 30, 2025
8ee9813
fix: train step
nglehuy Mar 30, 2025
dde7760
fix: apply ga loss division before loss scaling
nglehuy Mar 30, 2025
aade071
fix: update train function with ga steps
nglehuy Mar 30, 2025
dc0c304
fix: update train step ga
nglehuy Mar 30, 2025
d541928
chore: configs
nglehuy Mar 30, 2025
91a39a2
chore: configs
nglehuy Mar 30, 2025
2a40da6
chore: configs
nglehuy Mar 30, 2025
8bcf0f3
chore: configs
nglehuy Mar 30, 2025
de58fed
fix: rnn kwargs
nglehuy Mar 30, 2025
a05494a
chore: update
nglehuy Mar 31, 2025
Changes from commit "fix: ga"
nglehuy committed Jul 14, 2024
commit 522f0802a8ea7a73dfcc2bfc0b22c89ef6c44967
6 changes: 3 additions & 3 deletions examples/configs/librispeech/data.yml.j2
@@ -10,7 +10,7 @@ data_config:
tfrecords_shards: 32
shuffle: True
cache: False
buffer_size: 100
buffer_size: 1000
drop_remainder: True
stage: train
metadata: {{metadata}}
@@ -24,9 +24,9 @@ data_config:
- {{datadir}}/dev-other/transcripts.tsv
tfrecords_dir: {{datadir}}/tfrecords
tfrecords_shards: 2
shuffle: True
shuffle: False
cache: False
buffer_size: 100
buffer_size: 1000
drop_remainder: True
stage: eval
metadata: {{metadata}}
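
Note: buffer_size here presumably sets the tf.data shuffle buffer, so raising it from 100 to 1000 trades host memory for better example mixing, and the eval split is no longer shuffled at all. Below is a minimal sketch of how such a config typically drives the input pipeline; the function and argument names are illustrative, not the repository's actual loader.

import tensorflow as tf

# Minimal sketch, assuming buffer_size feeds tf.data.Dataset.shuffle and
# drop_remainder feeds batching; names are illustrative only.
def build_pipeline(dataset: tf.data.Dataset, stage: str, batch_size: int,
                   shuffle: bool, buffer_size: int, drop_remainder: bool) -> tf.data.Dataset:
    if shuffle and stage == "train":
        # A larger buffer mixes examples better at the cost of host memory.
        dataset = dataset.shuffle(buffer_size, reshuffle_each_iteration=True)
    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
    return dataset.prefetch(tf.data.AUTOTUNE)
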
8 changes: 3 additions & 5 deletions examples/train.py
@@ -36,11 +36,12 @@ def main(
jit_compile: bool = False,
ga_steps: int = None,
verbose: int = 1,
tpu_address: str = None,
repodir: str = os.path.realpath(os.path.join(os.path.dirname(__file__), "..")),
):
keras.backend.clear_session()
env_util.setup_seed()
strategy = env_util.setup_strategy(devices)
strategy = env_util.setup_strategy(devices, tpu_address=tpu_address)
env_util.setup_mxp(mxp=mxp)

config = Config(config_path, training=True, repodir=repodir, datadir=datadir, modeldir=modeldir)
@@ -68,10 +69,7 @@ def main(
ga_steps = ga_steps or config.learning_config.ga_steps or 1

train_data_loader = train_dataset.create(train_batch_size, ga_steps=ga_steps, padded_shapes=padded_shapes)
if train_dataset.use_ga:
logger.info(f"train_data_loader.element_spec = {json.dumps(train_data_loader.element_spec.element_spec, indent=2, default=str)}")
else:
logger.info(f"train_data_loader.element_spec = {json.dumps(train_data_loader.element_spec, indent=2, default=str)}")
logger.info(f"train_data_loader.element_spec = {json.dumps(train_data_loader.element_spec, indent=2, default=str)}")

eval_data_loader = eval_dataset.create(eval_batch_size, padded_shapes=padded_shapes)
if eval_data_loader:
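
With the GA grouping removed from the dataset (see datasets.py below), train_data_loader is a flat dataset in both cases, so a single log line suffices. An illustrative eager check of what that line prints, using a toy dataset in place of train_data_loader:

import json
import tensorflow as tf

# Toy stand-in for the batched loader; element_spec is a nested structure of
# tf.TensorSpec objects, serialized with default=str exactly as in the log call.
ds = tf.data.Dataset.from_tensor_slices((tf.zeros([8, 80]), tf.zeros([8], tf.int32))).batch(2)
print(json.dumps(ds.element_spec, indent=2, default=str))
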
10 changes: 0 additions & 10 deletions tensorflow_asr/datasets.py
@@ -64,7 +64,6 @@
import logging
import os
from dataclasses import asdict, dataclass
from typing import List

import numpy as np
import tqdm
@@ -386,15 +385,6 @@ def process(

# only apply for training dataset, eval and test dataset should not use GA
if ga_steps > 1 and self.stage == "train":

def _key_fn(i, _):
return i // ga_steps

def _reduce_fn(_, ds):
elem = ds.map(lambda _, x: x)
return tf.data.Dataset.from_tensors(elem)

dataset = dataset.enumerate().group_by_window(key_func=_key_fn, reduce_func=_reduce_fn, window_size=ga_steps)
self.use_ga = True

# PREFETCH to improve speed of input length
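
The removed block packed each run of ga_steps consecutive batches into a nested dataset via enumerate() plus group_by_window, so the old train step could iterate the sub-batches itself; after this commit the accumulation loop lives in the model's train function instead. A standalone sketch of that (now removed) grouping pattern on a toy range dataset:

import tensorflow as tf

ga_steps = 4
dataset = tf.data.Dataset.range(10)  # stand-in for the batched training dataset

def _key_fn(i, _):
    return i // ga_steps  # consecutive elements share one window key

def _reduce_fn(_, window):
    # Drop the enumeration index and emit the whole window as one nested dataset.
    return tf.data.Dataset.from_tensors(window.map(lambda _, x: x))

grouped = dataset.enumerate().group_by_window(key_func=_key_fn, reduce_func=_reduce_fn, window_size=ga_steps)
for sub_ds in grouped:  # each element is itself a tf.data.Dataset of up to ga_steps batches
    print([x.numpy() for x in sub_ds])
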
72 changes: 49 additions & 23 deletions tensorflow_asr/models/base_model.py
@@ -15,7 +15,6 @@

import importlib
import logging
import typing

import numpy as np

@@ -152,6 +151,7 @@ def compile(
if isinstance(ga_steps, int) and ga_steps > 1:
self.use_ga = True
self.ga = GradientAccumulator(ga_steps=ga_steps)
kwargs["steps_per_execution"] = 1
logger.info(f"Using gradient accumulation with accumulate steps = {ga_steps}")
else:
self.use_ga = False
@@ -203,30 +203,24 @@ def _train_step(self, data: schemas.TrainData):

return gradients

def train_step(self, data_list: typing.Union[schemas.TrainData, typing.Iterable[schemas.TrainData]]):
if not self.use_ga:
data = data_list
gradients = self._train_step(data)
else:
iterator = iter(data_list)
data = next(iterator)
gradients = self._train_step(data)

for _ in range(1, self.ga.total_steps):
try:
data = next(iterator)
except StopIteration:
break
per_ga_gradients = self._train_step(data)
gradients = self.ga.accumulate(gradients, per_ga_gradients)

def _apply_gradients(self, gradients):
if self.gradn is not None:
gradients = self.gradn(step=self.optimizer.iterations, gradients=gradients)
self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

def train_step(self, data):
gradients = self._train_step(data)
self._apply_gradients(gradients)
metrics = self.get_metrics_result()
return metrics

def train_step_ga(self, data, prev_gradients):
gradients = self._train_step(data)
if prev_gradients is not None:
gradients = self.ga.accumulate(prev_gradients, gradients)
metrics = self.get_metrics_result()
return metrics, gradients

def _test_step(self, data: schemas.TrainData):
x = data[0]
y, _ = data_util.set_length(data[1].labels, data[1].labels_length)
@@ -278,6 +272,18 @@ def one_step_on_data(data):
if not self.run_eagerly:
one_step_on_data = tf.function(one_step_on_data, reduce_retracing=True, jit_compile=self.jit_compile)

@tf.autograph.experimental.do_not_convert
def one_ga_step_on_data(data, prev_gradients):
"""Runs a single training step on a batch of data."""
outputs, gradients = self.train_step_ga(data, prev_gradients)
# Ensure counter is updated only if `train_step` succeeds.
with tf.control_dependencies(_minimum_control_deps(outputs)):
self._train_counter.assign_add(1)
return outputs, gradients

if not self.run_eagerly:
one_ga_step_on_data = tf.function(one_ga_step_on_data, reduce_retracing=True, jit_compile=self.jit_compile)

@tf.autograph.experimental.do_not_convert
def one_step_on_iterator(iterator):
"""Runs a single training step given a Dataset iterator."""
@@ -292,11 +298,31 @@ def one_step_on_iterator(iterator):

@tf.autograph.experimental.do_not_convert
def multi_step_on_iterator(iterator):
for _ in range(self.steps_per_execution):
outputs = one_step_on_iterator(iterator)
return outputs
for _ in range(self.steps_per_execution.numpy().item()):
outputs, data = one_step_on_iterator(iterator)
return outputs, data

if self.steps_per_execution > 1:
@tf.autograph.experimental.do_not_convert
def ga_step_in_iterator(iterator):
data = next(iterator)
outputs, gradients = self.distribute_strategy.run(one_ga_step_on_data, args=(data, None))
for _ in range(1, self.ga.total_steps):
try:
data = next(iterator)
outputs, gradients = self.distribute_strategy.run(one_ga_step_on_data, args=(data, gradients))
except StopIteration:
break
self.distribute_strategy.run(self._apply_gradients, args=(gradients,))
outputs = keras_util.reduce_per_replica(
outputs,
self.distribute_strategy,
reduction=self.distribute_reduction_method,
)
return outputs, data

if self.use_ga:
train_function = ga_step_in_iterator
elif self.steps_per_execution > 1:
train_function = multi_step_on_iterator
else:
train_function = one_step_on_iterator
@@ -347,7 +373,7 @@ def one_step_on_iterator(iterator):

@tf.autograph.experimental.do_not_convert
def multi_step_on_iterator(iterator):
for _ in range(self.steps_per_execution):
for _ in range(self.steps_per_execution.numpy().item()):
outputs = one_step_on_iterator(iterator)
return outputs
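
Taken together, the new train function runs the per-batch step on ga_steps consecutive batches, accumulates the gradients, and applies them once per window. A self-contained sketch of the same accumulate-then-apply pattern with plain Keras objects (model, optimizer and loss names are illustrative, not the repository's API):

import tensorflow as tf

def ga_train_step(model, optimizer, loss_fn, iterator, ga_steps: int):
    accumulated = None
    for _ in range(ga_steps):
        x, y = next(iterator)
        with tf.GradientTape() as tape:
            # Dividing by ga_steps keeps the accumulated sum equivalent to one full-batch gradient.
            loss = loss_fn(y, model(x, training=True)) / ga_steps
        grads = tape.gradient(loss, model.trainable_variables)
        accumulated = grads if accumulated is None else [a + g for a, g in zip(accumulated, grads)]
    optimizer.apply_gradients(zip(accumulated, model.trainable_variables))
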

3 changes: 3 additions & 0 deletions tensorflow_asr/optimizers/accumulation.py
@@ -20,6 +20,9 @@ def __init__(self, ga_steps, name="ga"):
def total_steps(self):
return self._ga_steps

def is_apply_step(self, step):
return tf.math.equal(step % self._ga_steps, 0)

def accumulate(self, gradients, per_ga_gradients):
"""Accumulates :obj:`gradients` on the current replica."""
with tf.name_scope(self.name):
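
is_apply_step simply flags every ga_steps-th step. An eager-mode sketch of how such a check gates the optimizer update (the loop is illustrative; only the modulo check mirrors the method added above):

import tensorflow as tf

ga_steps = 4
for step in range(1, 9):
    if tf.math.equal(step % ga_steps, 0):
        # True on steps 4 and 8: apply the accumulated gradients and reset the accumulator.
        print(f"step {step}: apply")
    else:
        print(f"step {step}: accumulate")
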
5 changes: 1 addition & 4 deletions tensorflow_asr/utils/env_util.py
@@ -83,10 +83,7 @@ def setup_devices(
def setup_tpu(
tpu_address=None,
):
if tpu_address is None:
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
else:
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="grpc://" + tpu_address)
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_address)
tf.tpu.experimental.initialize_tpu_system(resolver)
return tf.distribute.TPUStrategy(resolver)
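
With this simplification the caller passes the TPU address straight through: TPUClusterResolver(tpu=None) auto-detects the TPU (e.g. on a Kaggle or Cloud TPU VM), while an explicit gRPC address targets a remote worker, so the old "grpc://" prefixing is now left to the caller. A hedged usage sketch (the model is a toy; only setup_tpu comes from this module):

import tensorflow as tf
from tensorflow_asr.utils import env_util

strategy = env_util.setup_tpu()  # auto-detect, e.g. on a Kaggle TPU or Cloud TPU VM
# strategy = env_util.setup_tpu("grpc://10.240.1.2:8470")  # or target an explicit address

with strategy.scope():
    # Variables created under the scope are replicated across the TPU cores.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")
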