Keras 3, Kaggle, CLI, Streaming #295

Draft · wants to merge 119 commits into base: main

Changes from 1 commit

Commits (119)
232a80c
feat(streaming): training, inference, gradient accumulation
nglehuy Jun 15, 2024
9ea9b87
fix: remove ds2 dropout on conv module
nglehuy Jun 18, 2024
7c21c1d
fix: add sync batch norm, remove wrong bn in ds2
nglehuy Jun 22, 2024
4e0e8f5
fix: only wrap tf.function in jit compile
nglehuy Jun 24, 2024
de38407
fix: use autograph do_not_convert for batchnorm sync to work
nglehuy Jun 24, 2024
43d6054
chore: config
nglehuy Jul 2, 2024
d55fd40
fix: update train/test step
nglehuy Jul 2, 2024
2f502bb
fix: nan to num
nglehuy Jul 2, 2024
7441c95
fix: update compute mask ds2
nglehuy Jul 2, 2024
f2241cf
fix: nan to num
nglehuy Jul 2, 2024
ebb6930
fix: ctc loss
nglehuy Jul 2, 2024
305ddab
fix: update train step
nglehuy Jul 3, 2024
8268afe
chore: config
nglehuy Jul 3, 2024
c9e4d38
chore: config
nglehuy Jul 3, 2024
9b46cbe
fix: ctc
nglehuy Jul 3, 2024
b5bfe92
fix: add custom batch norm to avoid tf.cond
nglehuy Jul 4, 2024
dc08e6a
fix: env utils
nglehuy Jul 4, 2024
cf37206
fix: log batch that cause invalid loss
nglehuy Jul 4, 2024
4596609
chore: buffer size
nglehuy Jul 4, 2024
ba9d6b2
fix: handle unknown dataset size with no metadata provided
nglehuy Jul 6, 2024
fe594ad
chore: add option use loss scale
nglehuy Jul 6, 2024
7aed458
fix: support log debug
nglehuy Jul 6, 2024
4330dec
fix: update gradient accumulation
nglehuy Jul 10, 2024
522f080
fix: ga
nglehuy Jul 14, 2024
90cabc2
feat: tf2.16 with keras 3
nglehuy Jul 14, 2024
8285f6d
fix: ga
nglehuy Jul 14, 2024
fd446d6
Merge branch 'tf2.16' into feat-streaming
nglehuy Jul 15, 2024
5037896
feat: fix layers, models to tf2.16 with keras 3
nglehuy Jul 15, 2024
a853eba
feat: update models to compatible with keras 3
nglehuy Jul 21, 2024
f8a7b91
fix: loss compute using add_loss, loss tracking
nglehuy Jul 28, 2024
03f0d60
fix: output shapes of models to log to summary
nglehuy Jul 28, 2024
8b0ed02
fix: contextnet
nglehuy Jul 28, 2024
77baaa5
fix: ds2
nglehuy Jul 28, 2024
3768878
fix: jasper
nglehuy Jul 28, 2024
f3cb239
fix: rnnt
nglehuy Jul 28, 2024
1d7e3a6
fix: transformer
nglehuy Jul 28, 2024
7be3eda
fix: update deps
nglehuy Jul 29, 2024
9b03b31
fix: requirements
nglehuy Aug 25, 2024
401180b
fix: super init
nglehuy Aug 25, 2024
786f5d4
fix: update regularizers
nglehuy Aug 25, 2024
a9d1733
fix: update regularizers
nglehuy Aug 25, 2024
689b366
fix: print shapes
nglehuy Nov 24, 2024
4e75c0f
fix: conformer ctc
nglehuy Nov 24, 2024
82d91c8
fix: add ctc tpu impl
nglehuy Nov 25, 2024
c915be3
fix: ctc tpu impl
nglehuy Nov 25, 2024
dda33b7
fix: save weights, tpu connect
nglehuy Nov 28, 2024
c667984
fix: save weights, tpu connect
nglehuy Nov 28, 2024
dc77b84
fix: update req
nglehuy Dec 3, 2024
d455ae1
fix: update req
nglehuy Dec 3, 2024
33394a2
fix: update req
nglehuy Dec 3, 2024
35160ce
fix: update req
nglehuy Dec 3, 2024
6ffb3b8
fix: update req
nglehuy Dec 3, 2024
bb732a7
fix: update req
nglehuy Dec 4, 2024
9179425
fix: strategy scope
nglehuy Dec 4, 2024
ace3887
fix: requirements
nglehuy Dec 4, 2024
67a8470
fix: update savings
nglehuy Dec 7, 2024
9824819
feat: bundle scripts inside package
nglehuy Dec 29, 2024
05b068b
feat: introduce chunk-wise masking for mha layer
nglehuy Dec 31, 2024
1edf16a
feat: introduce chunk-wise masking to conformer & transformer
nglehuy Dec 31, 2024
6338f55
chore: update install script
nglehuy Jan 1, 2025
e844f77
chore: add conformer small streaming
nglehuy Jan 1, 2025
0543c31
chore: add conformer small streaming
nglehuy Jan 1, 2025
a2e2022
fix: use history size instead of memory length
nglehuy Jan 1, 2025
4824929
chore: update logging
nglehuy Jan 1, 2025
eca664c
fix: streaming masking mha
nglehuy Jan 1, 2025
a302962
fix: conformer ctc configs
nglehuy Jan 7, 2025
ab83d87
feat: add kaggle backup and restore callback
nglehuy Jan 11, 2025
52de4c0
fix: support flash attention, update deps
nglehuy Jan 12, 2025
aaa06a5
chore: add conformer-ctc-small-streaming-kaggle
nglehuy Jan 12, 2025
2fd4f2b
fix: restore from kaggle model
nglehuy Jan 12, 2025
6e5c3b9
fix: restore from kaggle model
nglehuy Jan 12, 2025
f62fdf9
fix: ignore backup kaggle when nan loss occurs
nglehuy Jan 12, 2025
2947160
fix: only use tqdm when needed
nglehuy Jan 12, 2025
c532d5c
fix: deps
nglehuy Jan 23, 2025
0e5e826
fix: support static shape
nglehuy Jan 24, 2025
eb55a4b
fix: mha streaming mask
nglehuy Jan 24, 2025
12fbb85
fix: feature extraction mixed precision, configs
nglehuy Jan 25, 2025
2100d75
fix: expose relmha_causal, flash attention
nglehuy Jan 25, 2025
a5d1e84
fix: allow ctc to force use native tf impl
nglehuy Jan 25, 2025
5ff3163
chore: list devices
nglehuy Jan 25, 2025
3dbab33
fix: attention mask
nglehuy Feb 9, 2025
1545543
fix: general layers to show outputshape, invalid loss show outputs
nglehuy Feb 15, 2025
5f784b7
fix: models configs
nglehuy Feb 15, 2025
70ac41e
fix: config streaming
nglehuy Feb 20, 2025
eebc361
fix: configs
nglehuy Feb 23, 2025
6b0bec4
fix: configs
nglehuy Feb 23, 2025
2ac8e7f
Merge branch 'main' into feat-streaming
nglehuy Mar 9, 2025
7dcd145
fix: streaming masking mha
nglehuy Mar 9, 2025
e209040
fix: streaming masking mha
nglehuy Mar 9, 2025
5649fdd
fix: update mha attention mask
nglehuy Mar 13, 2025
26c4a5f
feat: add support for layer norm in conformer conv module
nglehuy Mar 13, 2025
4f77a52
chore: update configs
nglehuy Mar 13, 2025
c3ab865
fix: feature extraction layer dtype tf.float32 to ensure loss converg…
nglehuy Mar 17, 2025
778c1a2
fix: ctc loss tpu - case logits to float32
nglehuy Mar 17, 2025
e68ceee
fix: use auto mask
nglehuy Mar 19, 2025
f1e2a88
fix: pad logits length to label length
nglehuy Mar 20, 2025
f1a0ed6
fix: ctc loss tpu
nglehuy Mar 20, 2025
6f7f246
chore: config
nglehuy Mar 21, 2025
d538e69
fix: disable bias/activity regularizer as not needed
nglehuy Mar 21, 2025
7611ff8
chore: config
nglehuy Mar 23, 2025
0c8e7c1
chore: setup mxp
nglehuy Mar 24, 2025
56d2afa
chore: setup mxp
nglehuy Mar 24, 2025
454163c
fix: small kaggle
nglehuy Mar 25, 2025
a333bfc
chore: transformer-ctc streaming
nglehuy Mar 25, 2025
cf435a3
chore: config
nglehuy Mar 27, 2025
c35af45
fix: ctc-tpu clean label
nglehuy Mar 30, 2025
0556481
chore: configs
nglehuy Mar 30, 2025
3e88f65
chore: configs
nglehuy Mar 30, 2025
111f3ac
chore: configs
nglehuy Mar 30, 2025
8ee9813
fix: train step
nglehuy Mar 30, 2025
dde7760
fix: apply ga loss division before loss scaling
nglehuy Mar 30, 2025
aade071
fix: update train function with ga steps
nglehuy Mar 30, 2025
dc0c304
fix: update train step ga
nglehuy Mar 30, 2025
d541928
chore: configs
nglehuy Mar 30, 2025
91a39a2
chore: configs
nglehuy Mar 30, 2025
2a40da6
chore: configs
nglehuy Mar 30, 2025
8bcf0f3
chore: configs
nglehuy Mar 30, 2025
de58fed
fix: rnn kwargs
nglehuy Mar 30, 2025
a05494a
chore: update
nglehuy Mar 31, 2025
feat: update models to compatible with keras 3
nglehuy committed Jul 28, 2024
commit a853eba3bac704f58e2a383b65b4995933722459
4 changes: 2 additions & 2 deletions examples/models/ctc/conformer/char-small.yml.j2
@@ -63,7 +63,7 @@ model_config:

learning_config:
optimizer_config:
class_name: Custom>Adam
class_name: Adam
config:
learning_rate:
class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule
@@ -91,7 +91,7 @@ learning_config:
config: {}
- class_name: tensorflow_asr.callbacks>ModelCheckpoint
config:
filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5
filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5
save_best_only: False
save_weights_only: True
save_freq: epoch
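A note on the two config edits repeated across these YAML files: Keras 3 requires that a ModelCheckpoint used with save_weights_only=True point at a filepath ending in ".weights.h5", and the stock Adam optimizer is resolved by its plain class name rather than a "Custom>Adam" registration. A minimal sketch of the equivalent Python calls (the path and learning rate below are illustrative, not the project's defaults):

import keras

optimizer = keras.optimizers.get({"class_name": "Adam", "config": {"learning_rate": 1e-3}})

checkpoint = keras.callbacks.ModelCheckpoint(
    filepath="/tmp/checkpoints/{epoch:02d}.weights.h5",  # a bare ".h5" suffix raises a ValueError in Keras 3
    save_best_only=False,
    save_weights_only=True,
    save_freq="epoch",
)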
4 changes: 2 additions & 2 deletions examples/models/ctc/deepspeech2/base.yml.j2
@@ -60,7 +60,7 @@ model_config:

learning_config:
optimizer_config:
class_name: Custom>Adam
class_name: Adam
config:
learning_rate:
class_name: ExponentialDecay
@@ -81,7 +81,7 @@ learning_config:
callbacks:
- class_name: tensorflow_asr.callbacks>ModelCheckpoint
config:
filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5
filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5
save_best_only: False
save_weights_only: True
save_freq: epoch
4 changes: 2 additions & 2 deletions examples/models/ctc/deepspeech2/uni.yml.j2
@@ -59,7 +59,7 @@ model_config:

learning_config:
optimizer_config:
class_name: Custom>Adam
class_name: Adam
config:
learning_rate:
class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule
@@ -84,7 +84,7 @@ learning_config:
callbacks:
- class_name: tensorflow_asr.callbacks>ModelCheckpoint
config:
filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5
filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5
save_best_only: False
save_weights_only: True
save_freq: epoch
4 changes: 2 additions & 2 deletions examples/models/ctc/jasper/base.yml.j2
@@ -46,7 +46,7 @@ model_config:

learning_config:
optimizer_config:
class_name: Custom>Adam
class_name: Adam
config:
learning_rate: 0.001
beta_1: 0.9
@@ -66,7 +66,7 @@ learning_config:
config: {}
- class_name: tensorflow_asr.callbacks>ModelCheckpoint
config:
filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5
filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5
save_best_only: False
save_weights_only: True
save_freq: epoch
4 changes: 2 additions & 2 deletions examples/models/ctc/transformer/base.yml.j2
@@ -57,7 +57,7 @@ model_config:

learning_config:
optimizer_config:
class_name: Custom>Adam
class_name: Adam
config:
learning_rate:
class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule
@@ -83,7 +83,7 @@ learning_config:
config: {}
- class_name: tensorflow_asr.callbacks>ModelCheckpoint
config:
filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5
filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5
save_best_only: False
save_weights_only: True
save_freq: epoch
6 changes: 3 additions & 3 deletions examples/models/transducer/conformer/small-nfft.yml.j2
@@ -45,7 +45,7 @@ model_config:
encoder_interleave_relpe: True
encoder_use_attention_causal_mask: False
encoder_use_attention_auto_mask: True
encoder_mhsam_use_attention_bias: True
encoder_mhsam_use_attention_bias: False
encoder_kernel_size: 32
encoder_dropout: 0.1
encoder_padding: causal
@@ -78,7 +78,7 @@ model_config:

learning_config:
optimizer_config:
class_name: Custom>Adam
class_name: Adam
config:
learning_rate:
class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule
@@ -108,7 +108,7 @@ learning_config:
config: {}
- class_name: tensorflow_asr.callbacks>ModelCheckpoint
config:
filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5
filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5
save_best_only: False
save_weights_only: True
save_freq: epoch
4 changes: 2 additions & 2 deletions examples/models/transducer/conformer/small-no-decay.yml.j2
@@ -77,7 +77,7 @@ model_config:

learning_config:
optimizer_config:
class_name: Custom>Adam
class_name: Adam
config:
learning_rate:
class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule
@@ -105,7 +105,7 @@ learning_config:
config: {}
- class_name: tensorflow_asr.callbacks>ModelCheckpoint
config:
filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5
filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5
save_best_only: False
save_weights_only: True
save_freq: epoch
4 changes: 2 additions & 2 deletions examples/models/transducer/conformer/small.yml.j2
@@ -77,7 +77,7 @@ model_config:

learning_config:
optimizer_config:
class_name: Custom>Adam
class_name: Adam
config:
learning_rate:
class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule
@@ -107,7 +107,7 @@ learning_config:
config: {}
- class_name: tensorflow_asr.callbacks>ModelCheckpoint
config:
filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5
filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5
save_best_only: False
save_weights_only: True
save_freq: epoch
4 changes: 2 additions & 2 deletions examples/models/transducer/contextnet/small.yml.j2
@@ -226,7 +226,7 @@ model_config:

learning_config:
optimizer_config:
class_name: Custom>Adam
class_name: Adam
config:
learning_rate:
class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule
@@ -256,7 +256,7 @@ learning_config:
config: {}
- class_name: tensorflow_asr.callbacks>ModelCheckpoint
config:
filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5
filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5
save_best_only: False
save_weights_only: True
save_freq: epoch
4 changes: 2 additions & 2 deletions examples/models/transducer/rnnt/small.yml.j2
@@ -57,7 +57,7 @@ model_config:

learning_config:
optimizer_config:
class_name: Custom>Adam
class_name: Adam
config:
learning_rate:
class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule
@@ -87,7 +87,7 @@ learning_config:
config: {}
- class_name: tensorflow_asr.callbacks>ModelCheckpoint
config:
filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5
filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5
save_best_only: False
save_weights_only: True
save_freq: epoch
4 changes: 2 additions & 2 deletions examples/models/transducer/rnnt/tiny.yml.j2
@@ -56,7 +56,7 @@ model_config:

learning_config:
optimizer_config:
class_name: Custom>Adam
class_name: Adam
config:
learning_rate:
class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule
@@ -82,7 +82,7 @@ learning_config:
config: {}
- class_name: tensorflow_asr.callbacks>ModelCheckpoint
config:
filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5
filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5
save_best_only: False
save_weights_only: True
save_freq: epoch
2 changes: 1 addition & 1 deletion examples/models/transducer/transformer/base.yml.j2
@@ -93,7 +93,7 @@ learning_config:
batch_size: 2
num_epochs: 300
checkpoint:
filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5
filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5
save_best_only: False
save_weights_only: True
save_freq: epoch
3 changes: 3 additions & 0 deletions tensorflow_asr/losses/rnnt_loss.py
@@ -360,6 +360,9 @@ def rnnt_loss_tf(
orig_dtype = logits.dtype
if orig_dtype in (tf.float16, tf.bfloat16):
logits = tf.cast(logits, tf.float32)
logit_length = tf.cast(logit_length, tf.int32)
labels = tf.cast(labels, tf.int32)
label_length = tf.cast(label_length, tf.int32)

args = [logits, labels, label_length, logit_length]
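The added casts make the loss inputs' dtypes explicit before the heavy math: logits are promoted to float32 when a mixed-precision policy hands in float16/bfloat16, and labels and lengths are forced to int32. A rough, self-contained sketch of that normalization step (a simplified mirror of the lines above, not the full rnnt_loss_tf function):

import tensorflow as tf

def normalize_rnnt_loss_inputs(logits, labels, logit_length, label_length):
    # Mixed precision may produce float16/bfloat16 logits; the loss is computed in float32.
    if logits.dtype in (tf.float16, tf.bfloat16):
        logits = tf.cast(logits, tf.float32)
    # Label and length tensors are cast to int32, the integer dtype the loss expects.
    logit_length = tf.cast(logit_length, tf.int32)
    labels = tf.cast(labels, tf.int32)
    label_length = tf.cast(label_length, tf.int32)
    return logits, labels, logit_length, label_length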

14 changes: 4 additions & 10 deletions tensorflow_asr/models/base_layer.py
@@ -36,13 +36,7 @@ def call(self, inputs):
outputs = math_util.merge_two_last_dims(outputs)
return outputs, outputs_length

def compute_output_shape(self, input_shape):
output_shape, output_length_shape = input_shape
output_shape = output_shape[:2] + (output_shape[2] * output_shape[3],)
return output_shape, output_length_shape


@keras.utils.register_keras_serializable(package=__name__)
class Identity(Layer):
def call(self, inputs):
return inputs
# def compute_output_shape(self, input_shape):
# output_shape, output_length_shape = input_shape
# output_shape = output_shape[:2] + (output_shape[2] * output_shape[3],)
# return output_shape, output_length_shape
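The hand-written compute_output_shape and the custom Identity layer are dropped here because Keras 3 ships a built-in keras.layers.Identity and can typically infer output shapes for layers built from standard ops. A minimal sketch of the built-in replacement (the tensor below is only for illustration):

import keras

iden = keras.layers.Identity(name="iden")
x = keras.random.normal((2, 4, 8))
y = iden(x)  # no-op: y is x unchanged, shape (2, 4, 8)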
149 changes: 92 additions & 57 deletions tensorflow_asr/models/encoders/conformer.py
@@ -16,7 +16,7 @@

from tensorflow_asr import keras, tf
from tensorflow_asr.models.activations.glu import GLU
from tensorflow_asr.models.base_layer import Identity, Layer
from tensorflow_asr.models.base_layer import Layer
from tensorflow_asr.models.layers.convolution import DepthwiseConv1D
from tensorflow_asr.models.layers.multihead_attention import MultiHeadAttention, MultiHeadRelativeAttention
from tensorflow_asr.models.layers.positional_encoding import RelativeSinusoidalPositionalEncoding, SinusoidalPositionalEncoding
@@ -61,7 +61,7 @@ def __init__(
self.pre_norm = (
keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype)
if norm_position == "pre"
else Identity(name="preiden" if norm_position == "none" else "iden", dtype=self.dtype)
else keras.layers.Identity(name="preiden" if norm_position == "none" else "iden", dtype=self.dtype)
)
self.ffn1 = keras.layers.Dense(
units=scale_factor * input_dim,
@@ -83,7 +83,7 @@ def __init__(
self.post_norm = (
keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype)
if norm_position == "post"
else Identity(name="postiden" if norm_position == "none" else "iden", dtype=self.dtype)
else keras.layers.Identity(name="postiden" if norm_position == "none" else "iden", dtype=self.dtype)
)
self.residual = Residual(factor=residual_factor, regularizer=bias_regularizer, name="residual", dtype=self.dtype)

@@ -94,11 +94,11 @@ def call(self, inputs, training=False):
outputs = self.ffn2(outputs, training=training)
outputs = self.do2(outputs, training=training)
outputs = self.post_norm(outputs, training=training)
outputs = self.residual([inputs, outputs], training=training)
outputs = self.residual((inputs, outputs), training=training)
return outputs

def compute_output_shape(self, input_shape):
return input_shape
# def compute_output_shape(self, input_shape):
# return input_shape


@keras.utils.register_keras_serializable(package=__name__)
@@ -139,7 +139,7 @@ def __init__(
self.pre_norm = (
keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype)
if norm_position == "pre"
else Identity(name="preiden" if norm_position == "none" else "iden", dtype=self.dtype)
else keras.layers.Identity(name="preiden" if norm_position == "none" else "iden", dtype=self.dtype)
)
if mha_type == "relmha":
self.mha = MultiHeadRelativeAttention(
@@ -169,7 +169,7 @@ def __init__(
self.post_norm = (
keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype)
if norm_position == "post"
else Identity(name="postiden" if norm_position == "none" else "iden", dtype=self.dtype)
else keras.layers.Identity(name="postiden" if norm_position == "none" else "iden", dtype=self.dtype)
)
self.residual = Residual(factor=residual_factor, regularizer=bias_regularizer, name="residual", dtype=self.dtype)

@@ -179,30 +179,50 @@ def get_initial_state(self, batch_size: int):
def call(
self,
inputs,
content_attention_bias=None,
positional_attention_bias=None,
initial_state=None,
training=False,
attention_mask=None,
use_causal_mask=False,
use_auto_mask=True,
return_states=False,
):
_inputs, relative_position_encoding, content_attention_bias, positional_attention_bias = inputs
_inputs, relative_position_encoding = inputs
outputs = self.pre_norm(_inputs, training=training)
outputs, states = self.mha(
[outputs, outputs, outputs, relative_position_encoding, content_attention_bias, positional_attention_bias],
outputs, *states = self.mha(
[outputs, outputs, outputs, relative_position_encoding],
content_attention_bias=content_attention_bias,
positional_attention_bias=positional_attention_bias,
initial_state=initial_state,
training=training,
attention_mask=attention_mask,
use_causal_mask=use_causal_mask,
use_auto_mask=use_auto_mask,
return_states=return_states,
)
outputs = self.do(outputs, training=training)
outputs = self.post_norm(outputs, training=training)
outputs = self.residual([_inputs, outputs], training=training)
return outputs, states

def compute_output_shape(self, input_shape):
output_shape, *_ = input_shape
return output_shape
outputs = self.residual((_inputs, outputs), training=training)
if return_states:
return [outputs] + states
return [outputs]

# def compute_output_shape(self, input_shape):
# output_shape, *_ = input_shape
# return output_shape

# def compute_output_spec(
# self,
# inputs,
# initial_state=None,
# attention_mask=None,
# use_causal_mask=False,
# use_auto_mask=True,
# ):
# return self.mha.compute_output_spec(
# inputs, attention_mask=attention_mask, use_causal_mask=use_causal_mask, use_auto_mask=use_auto_mask, initial_state=initial_state
# )


@keras.utils.register_keras_serializable(package=__name__)
@@ -247,7 +267,7 @@ def __init__(
self.pre_norm = (
keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype)
if norm_position == "pre"
else Identity(name="preiden" if norm_position == "none" else "iden", dtype=self.dtype)
else keras.layers.Identity(name="preiden" if norm_position == "none" else "iden", dtype=self.dtype)
)
self.pw_conv_1 = keras.layers.Conv1D(
filters=scale_factor * input_dim,
@@ -304,7 +324,7 @@ def __init__(
self.post_norm = (
keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype)
if norm_position == "post"
else Identity(name="postiden" if norm_position == "none" else "iden", dtype=self.dtype)
else keras.layers.Identity(name="postiden" if norm_position == "none" else "iden", dtype=self.dtype)
)
self.residual = Residual(factor=residual_factor, regularizer=bias_regularizer, name="residual", dtype=self.dtype)

@@ -318,11 +338,11 @@ def call(self, inputs, training=False):
outputs = self.pw_conv_2(outputs, training=training)
outputs = self.do(outputs, training=training)
outputs = self.post_norm(outputs, training=training)
outputs = self.residual([inputs, outputs], training=training)
outputs = self.residual((inputs, outputs), training=training)
return outputs

def compute_output_shape(self, input_shape):
return input_shape
# def compute_output_shape(self, input_shape):
# return input_shape


@keras.utils.register_keras_serializable(package=__name__)
@@ -366,7 +386,7 @@ def __init__(
self.pre_norm = (
keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype)
if block_norm_position == "pre"
else Identity(name="preiden" if block_norm_position == "none" else "iden", dtype=self.dtype)
else keras.layers.Identity(name="preiden" if block_norm_position == "none" else "iden", dtype=self.dtype)
)
self.ffm1 = FFModule(
input_dim=input_dim,
@@ -423,7 +443,7 @@ def __init__(
self.post_norm = (
keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype)
if block_norm_position == "post"
else Identity(name="postiden" if block_norm_position == "none" else "iden", dtype=self.dtype)
else keras.layers.Identity(name="postiden" if block_norm_position == "none" else "iden", dtype=self.dtype)
)

def get_initial_state(self, batch_size: int):
@@ -432,31 +452,39 @@ def get_initial_state(self, batch_size: int):
def call(
self,
inputs,
content_attention_bias=None,
positional_attention_bias=None,
initial_state=None,
training=False,
attention_mask=None,
use_causal_mask=False,
use_auto_mask=True,
return_states=False,
):
inputs, relative_position_encoding, content_attention_bias, positional_attention_bias = inputs
outputs = self.pre_norm(inputs, training=training)
_inputs, relative_position_encoding = inputs
outputs = self.pre_norm(_inputs, training=training)
outputs = self.ffm1(outputs, training=training)
outputs, states = self.mhsam(
[outputs, relative_position_encoding, content_attention_bias, positional_attention_bias],
outputs, *states = self.mhsam(
[outputs, relative_position_encoding],
content_attention_bias=content_attention_bias,
positional_attention_bias=positional_attention_bias,
initial_state=initial_state,
training=training,
attention_mask=attention_mask,
use_causal_mask=use_causal_mask,
use_auto_mask=use_auto_mask,
return_states=return_states,
)
outputs = self.convm(outputs, training=training)
outputs = self.ffm2(outputs, training=training)
outputs = self.post_norm(outputs, training=training)
return outputs, states
if return_states:
return [outputs] + states
return [outputs]

def compute_output_shape(self, input_shape):
output_shape, *_ = input_shape
return output_shape
# def compute_output_shape(self, input_shape):
# output_shape, *_ = input_shape
# return output_shape


@keras.utils.register_keras_serializable(package=__name__)
@@ -578,29 +606,36 @@ def __init__(
else:
self.content_attention_bias, self.positional_attention_bias = None, None

def call(self, inputs, initial_state=None, training=False):
def call(
self,
inputs,
initial_state=None,
training=False,
return_states=False,
):
outputs, outputs_length = inputs
outputs, outputs_length = self.conv_subsampling([outputs, outputs_length], training=training)
outputs, outputs_length = self.conv_subsampling((outputs, outputs_length), training=training)
outputs = self.linear(outputs, training=training)
outputs = self.do(outputs, training=training)
outputs, relative_position_encoding = self.relpe([outputs, outputs_length], training=training)
outputs, relative_position_encoding = self.relpe((outputs, outputs_length), training=training)
states = None if self._memory_length is None else []
for i, cblock in enumerate(self.conformer_blocks):
outputs, _states = cblock(
[
outputs,
relative_position_encoding,
self.content_attention_bias,
self.positional_attention_bias,
],
outputs, *_states = cblock(
(outputs, relative_position_encoding),
content_attention_bias=self.content_attention_bias,
positional_attention_bias=self.positional_attention_bias,
initial_state=None if initial_state is None else initial_state[i],
training=training,
use_causal_mask=self._use_attention_causal_mask,
use_auto_mask=self._use_attention_auto_mask,
return_states=return_states,
)
if states is not None:
states.append(_states)
return outputs, outputs_length, states
if not states:
continue
states.extend(_states)
if return_states:
return outputs, outputs_length, states
return outputs, outputs_length

def call_next(self, features, features_length, previous_encoder_states, *args, **kwargs):
"""
@@ -617,17 +652,17 @@ def call_next(self, features, features_length, previous_encoder_states, *args, *
Outputs, outputs_length, new_states
"""
with tf.name_scope(f"{self.name}_call_next"):
return self.call((features, features_length), initial_state=previous_encoder_states, training=False)
return self((features, features_length), initial_state=previous_encoder_states, training=False, return_states=True)

def compute_mask(self, inputs, mask=None):
return *self.conv_subsampling.compute_mask(inputs, mask=mask), None

def compute_output_shape(self, input_shape):
output_shape, output_length_shape = input_shape
output_shape, output_length_shape = self.conv_subsampling.compute_output_shape((output_shape, output_length_shape))
output_shape = self.linear.compute_output_shape(output_shape)
output_shape, relative_position_encoding_shape = self.relpe.compute_output_shape((output_shape, output_length_shape))
output_shape = self.do.compute_output_shape(output_shape)
for cblock in self.conformer_blocks:
output_shape = cblock.compute_output_shape((output_shape, relative_position_encoding_shape, None, None))
return output_shape, output_length_shape
return self.conv_subsampling.compute_mask(inputs, mask=mask)

# def compute_output_shape(self, input_shape):
# output_shape, output_length_shape = input_shape
# output_shape, output_length_shape = self.conv_subsampling.compute_output_shape((output_shape, output_length_shape))
# output_shape = self.linear.compute_output_shape(output_shape)
# output_shape, relative_position_encoding_shape = self.relpe.compute_output_shape((output_shape, output_length_shape))
# output_shape = self.do.compute_output_shape(output_shape)
# for cblock in self.conformer_blocks:
# output_shape = cblock.compute_output_shape((output_shape, relative_position_encoding_shape, None, None))
# return output_shape, output_length_shape
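Across these encoder changes, the attention biases and streaming states move from positional inputs to keyword arguments, and each block now returns a plain list: [outputs] in the ordinary training path and [outputs] + states only when return_states=True, which is why callers unpack with `outputs, *states = block(...)`. A toy layer sketching that calling convention (hypothetical, not the repository's Conformer block):

import keras

class ToyStreamingBlock(keras.layers.Layer):
    """Illustrates the `[outputs]` vs `[outputs] + states` return convention."""

    def call(self, inputs, initial_state=None, return_states=False):
        outputs = inputs  # a real block would transform `inputs` here
        if return_states:
            new_state = outputs[:, -1:]  # pretend the last frame is the state carried to the next chunk
            return [outputs, new_state]
        return [outputs]

# Training: `outputs, *states = block(x)` leaves `states == []`.
# Streaming inference: `outputs, *states = block(x, return_states=True)` carries the state forward.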
32 changes: 16 additions & 16 deletions tensorflow_asr/models/encoders/deepspeech2.py
@@ -13,7 +13,7 @@
# limitations under the License.

from tensorflow_asr import keras, tf
from tensorflow_asr.models.base_layer import Identity, Layer, Reshape
from tensorflow_asr.models.base_layer import Layer, Reshape
from tensorflow_asr.models.layers.convolution import DepthwiseConv1D
from tensorflow_asr.utils import layer_util, math_util

@@ -121,12 +121,12 @@ def compute_mask(self, inputs, mask=None):
mask = tf.sequence_mask(outputs_length, maxlen=maxlen, dtype=tf.bool)
return mask, None

def compute_output_shape(self, input_shape):
output_shape, output_length_shape = input_shape
output_shape = self.conv.compute_output_shape(output_shape)
output_shape = self.bn.compute_output_shape(output_shape)
output_shape = self.act.compute_output_shape(output_shape)
return output_shape, output_length_shape
# def compute_output_shape(self, input_shape):
# output_shape, output_length_shape = input_shape
# output_shape = self.conv.compute_output_shape(output_shape)
# output_shape = self.bn.compute_output_shape(output_shape)
# output_shape = self.act.compute_output_shape(output_shape)
# return output_shape, output_length_shape


@keras.utils.register_keras_serializable(package=__name__)
@@ -148,7 +148,7 @@ def __init__(
assert conv_type in ("conv1d", "conv2d")
assert len(kernels) == len(strides) == len(filters)

self.pre = Reshape(name="preprocess", dtype=self.dtype) if conv_type == "conv1d" else Identity(name="iden", dtype=self.dtype)
self.pre = Reshape(name="preprocess", dtype=self.dtype) if conv_type == "conv1d" else keras.layers.Identity(name="iden", dtype=self.dtype)

self.convs = []
self.time_reduction_factor = 1
@@ -169,7 +169,7 @@ def __init__(
self.convs.append(conv_block)
self.time_reduction_factor *= conv_block.time_reduction_factor

self.post = Reshape(name="postprocess", dtype=self.dtype) if conv_type == "conv2d" else Identity(name="iden", dtype=self.dtype)
self.post = Reshape(name="postprocess", dtype=self.dtype) if conv_type == "conv2d" else keras.layers.Identity(name="iden", dtype=self.dtype)

def call(self, inputs, training=False):
outputs = self.pre(inputs, training=training)
@@ -178,13 +178,13 @@ def call(self, inputs, training=False):
outputs = self.post(outputs, training=training)
return outputs

def compute_output_shape(self, input_shape):
output_shape = input_shape
output_shape = self.pre.compute_output_shape(output_shape)
for conv in self.convs:
output_shape = conv.compute_output_shape(output_shape)
output_shape = self.post.compute_output_shape(output_shape)
return output_shape
# def compute_output_shape(self, input_shape):
# output_shape = input_shape
# output_shape = self.pre.compute_output_shape(output_shape)
# for conv in self.convs:
# output_shape = conv.compute_output_shape(output_shape)
# output_shape = self.post.compute_output_shape(output_shape)
# return output_shape


# ------------------------------------ RNN ----------------------------------- #
6 changes: 2 additions & 4 deletions tensorflow_asr/models/layers/blurpool.py
@@ -30,10 +30,9 @@ def __init__(
trainable=True,
name="blurpool2d",
dtype=None,
dynamic=False,
**kwargs,
):
super().__init__(trainable, name, dtype, dynamic, **kwargs)
super().__init__(trainable=trainable, name=name, dtype=dtype, **kwargs)
self.filters = filters
self.kernel_size = kernel_size
self.strides = strides
@@ -88,10 +87,9 @@ def __init__(
trainable=True,
name="blurpool1d",
dtype=None,
dynamic=False,
**kwargs,
):
super().__init__(trainable, name, dtype, dynamic, **kwargs)
super().__init__(trainable=trainable, name=name, dtype=dtype, **kwargs)
self.filters = filters
self.kernel_size = kernel_size
self.strides = strides
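The BlurPool layers switch to keyword arguments in super().__init__ and drop the dynamic flag because the Keras 3 base Layer constructor is keyword-only and no longer accepts dynamic. A minimal sketch of the pattern (the class name and attribute are placeholders):

import keras

class MyPool1D(keras.layers.Layer):
    def __init__(self, filters, trainable=True, name="mypool1d", dtype=None, **kwargs):
        # Keras 3: pass base-layer options by keyword; there is no `dynamic` argument anymore.
        super().__init__(trainable=trainable, name=name, dtype=dtype, **kwargs)
        self.filters = filters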
16 changes: 8 additions & 8 deletions tensorflow_asr/models/layers/embedding.py
@@ -52,10 +52,10 @@ def compute_mask(self, inputs, mask=None):
mask = tf.sequence_mask(outputs_length, maxlen=tf.shape(outputs)[1], dtype=tf.bool)
return mask, None

def compute_output_shape(self, input_shape):
output_shape, output_length_shape = input_shape
output_shape = super().compute_output_shape(output_shape)
return output_shape, output_length_shape
# def compute_output_shape(self, input_shape):
# output_shape, output_length_shape = input_shape
# output_shape = super().compute_output_shape(output_shape)
# return output_shape, output_length_shape


@keras.utils.register_keras_serializable(package=__name__)
@@ -87,7 +87,7 @@ def compute_mask(self, inputs, mask=None):
mask = tf.sequence_mask(outputs_length, maxlen=tf.shape(outputs)[1], dtype=tf.bool)
return mask, None

def compute_output_shape(self, input_shape):
output_shape, output_length_shape = input_shape
output_shape = output_shape + (self.depth,)
return output_shape, output_length_shape
# def compute_output_shape(self, input_shape):
# output_shape, output_length_shape = input_shape
# output_shape = output_shape + (self.depth,)
# return output_shape, output_length_shape
16 changes: 8 additions & 8 deletions tensorflow_asr/models/layers/feature_extraction.py
@@ -312,11 +312,11 @@ def compute_mask(self, inputs, mask=None):
padded_nframes = self.get_nframes(tf.shape(signals, tf.int32)[1])
return tf.sequence_mask(nframes, maxlen=padded_nframes, dtype=tf.bool), None

def compute_output_shape(self, input_shape):
signal_shape, signal_length_shape = input_shape
B, nsamples = signal_shape
if nsamples is None:
output_shape = [B, None, self.num_feature_bins, 1]
else:
output_shape = [B, self.get_nframes(nsamples + self.padding), self.num_feature_bins, 1]
return tf.TensorShape(output_shape), tf.TensorShape(signal_length_shape)
# def compute_output_shape(self, input_shape):
# signal_shape, signal_length_shape = input_shape
# B, nsamples = signal_shape
# if nsamples is None:
# output_shape = [B, None, self.num_feature_bins, 1]
# else:
# output_shape = [B, self.get_nframes(nsamples + self.padding), self.num_feature_bins, 1]
# return tf.TensorShape(output_shape), tf.TensorShape(signal_length_shape)
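The same pattern repeats through the layer files below: hand-written compute_output_shape methods are commented out and left to Keras 3, which derives output shapes by tracing the layer's call on symbolic tensors. A trivial sketch of that inference (the shapes shown are illustrative):

import keras

inputs = keras.Input(shape=(None, 80))      # (batch, time, features)
outputs = keras.layers.Dense(144)(inputs)   # no compute_output_shape override needed
model = keras.Model(inputs, outputs)
model.summary()                             # output shape (None, None, 144) is inferred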
8 changes: 4 additions & 4 deletions tensorflow_asr/models/layers/memory.py
@@ -79,8 +79,8 @@ def call(self, inputs, memories=None, training=False):
new_memory._keras_mask = new_memory_mask # pylint: disable=protected-access
return new_inputs, new_memory

def compute_output_shape(self, input_shape):
return input_shape[0], self.memory_length, self.dmodel
# def compute_output_shape(self, input_shape):
# return input_shape[0], self.memory_length, self.dmodel

def compute_output_spec(self, *args, **kwargs):
return super().compute_output_spec(*args, **kwargs)
# def compute_output_spec(self, *args, **kwargs):
# return super().compute_output_spec(*args, **kwargs)
45 changes: 37 additions & 8 deletions tensorflow_asr/models/layers/multihead_attention.py
@@ -244,6 +244,8 @@ def call(
training=None,
use_causal_mask=False,
initial_state=None,
return_states=False,
**kwargs,
):
query, key, value = inputs

@@ -269,14 +271,20 @@ def call(
# `value` = [B, S, N, H]
value = self._value_dense(value)

query, key, value, states = self._with_memory(query, key, value, initial_state, training)
if return_states:
query, key, value, states = self._with_memory(query, key, value, initial_state, training)

attention_output, attention_scores = self._compute_attention(query, key, value, attention_mask, training)
attention_output = self._output_dense(attention_output)

if return_attention_scores:
return attention_output, states, attention_scores
return attention_output, states
if return_states:
return attention_output, states, attention_scores
return attention_output, attention_scores

if return_states:
return attention_output, states
return (attention_output,)

def compute_output_shape(self, input_shape):
query_shape, key_shape, value_shape, *_ = input_shape
@@ -294,11 +302,22 @@ def compute_output_spec(
training=None,
use_causal_mask=False,
initial_state=None,
return_states=False,
):
query, value, key, *_ = inputs
output_spec, *attention_score_spec = super().compute_output_spec(
query, value, key, query_mask, value_mask, key_mask, attention_mask, return_attention_scores, training, use_causal_mask
)
if not return_states:
return [output_spec] + attention_score_spec
if self._memory_length is None:
return [output_spec, None] + attention_score_spec
states_shape = (query.shape[0], self._memory_length, query.shape[-1])
states_spec = {
"key": keras.KerasTensor(states_shape, dtype=self.compute_dtype),
"value": keras.KerasTensor(states_shape, dtype=self.compute_dtype),
}
return [output_spec, states_spec] + attention_score_spec


@keras.utils.register_keras_serializable(package=__name__)
@@ -348,7 +367,7 @@ def __init__(
self._causal = causal

def build(self, input_shape):
*rest_input_shape, relpe_shape, _, _ = input_shape
*rest_input_shape, relpe_shape = input_shape
relpe_rank = len(relpe_shape)
einsum_equation, bias_axes, output_rank = mha_module._build_proj_equation(relpe_rank - 1, bound_dims=1, output_dims=2)
self._relpe_dense = keras.layers.EinsumDense(
@@ -423,6 +442,8 @@ def _compute_attention(
def call(
self,
inputs,
content_attention_bias=None,
positional_attention_bias=None,
query_mask=None,
value_mask=None,
key_mask=None,
@@ -432,8 +453,10 @@ def call(
training=None,
use_causal_mask=False,
initial_state=None,
return_states=False,
**kwargs,
):
query, key, value, relpe, content_attention_bias, positional_attention_bias = inputs
query, key, value, relpe = inputs

if use_auto_mask:
attention_mask = self._compute_attention_mask(
@@ -460,7 +483,8 @@ def call(
# `position` = [B, R, N, H]
position = self._relpe_dense(relpe)

query, key, value, states = self._with_memory(query, key, value, initial_state, training)
if return_states:
query, key, value, states = self._with_memory(query, key, value, initial_state, training)

attention_output, attention_scores = self._compute_attention(
query,
@@ -475,5 +499,10 @@ def call(
attention_output = self._output_dense(attention_output)

if return_attention_scores:
return attention_output, states, attention_scores
return attention_output, states
if return_states:
return attention_output, states, attention_scores
return attention_output, attention_scores

if return_states:
return attention_output, states
return (attention_output,)
20 changes: 10 additions & 10 deletions tensorflow_asr/models/layers/positional_encoding.py
@@ -83,9 +83,9 @@ def call(self, inputs, training=False):
outputs += pe
return outputs, pe

def compute_output_shape(self, input_shape):
output_shape, _ = input_shape
return output_shape, output_shape
# def compute_output_shape(self, input_shape):
# output_shape, _ = input_shape
# return output_shape, output_shape


@keras.utils.register_keras_serializable(package=__name__)
@@ -172,10 +172,10 @@ def call(self, inputs, training=False):
pe = self.do(pe, training=training)
return outputs, pe

def compute_output_shape(self, input_shape):
output_shape, _ = input_shape
B, T, V = output_shape
pT = 2 * T - 1 if T is not None else None
if self._memory_length > 0 and T is not None:
pT += self._memory_length
return output_shape, (B, pT, V)
# def compute_output_shape(self, input_shape):
# output_shape, _ = input_shape
# B, T, V = output_shape
# pT = 2 * T - 1 if T is not None else None
# if self._memory_length > 0 and T is not None:
# pT += self._memory_length
# return output_shape, (B, pT, V)
8 changes: 4 additions & 4 deletions tensorflow_asr/models/layers/residual.py
@@ -52,14 +52,14 @@ def build(self, input_shape):
)
else:
assert isinstance(self._factor, (int, float))
self._alpha = tf.convert_to_tensor(self._factor, dtype=self.compute_dtype)
self._alpha = self._factor
return super().build(input_shape)

def call(self, inputs):
x, residual_x = inputs
alpha = tf.cast(self._alpha, residual_x.dtype)
alpha = tf.cast(tf.convert_to_tensor(self._alpha, dtype=self.dtype), residual_x.dtype)
x = x + alpha * residual_x
return x

def compute_output_shape(self, input_shape):
return input_shape[0]
# def compute_output_shape(self, input_shape):
# return input_shape[0]
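The residual layer now keeps its scalar factor as a plain Python number at build time and converts it inside call, so the constant follows the layer's dtype and the incoming tensor's dtype under mixed precision. A self-contained sketch of that behaviour (assuming the TensorFlow backend; the class name is illustrative):

import keras
import tensorflow as tf

class ToyResidual(keras.layers.Layer):
    """x + alpha * residual_x with a scalar, non-trainable factor."""

    def __init__(self, factor=1.0, **kwargs):
        super().__init__(**kwargs)
        self._factor = factor  # stored as a plain number, not a tensor

    def call(self, inputs):
        x, residual_x = inputs
        # Convert at call time so the constant matches the residual branch's dtype
        # (which may be float16/bfloat16 under a mixed-precision policy).
        alpha = tf.cast(tf.convert_to_tensor(self._factor, dtype=self.dtype), residual_x.dtype)
        return x + alpha * residual_x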
54 changes: 27 additions & 27 deletions tensorflow_asr/models/layers/subsampling.py
@@ -45,11 +45,11 @@ def compute_mask(self, inputs, mask=None):
mask = tf.sequence_mask(outputs_length, maxlen=maxlen, dtype=tf.bool)
return mask, None

def compute_output_shape(self, input_shape):
output_shape, output_length_shape = input_shape
reduced_time = math_util.legacy_get_reduced_length(output_shape[1], self.time_reduction_factor)
output_shape = output_shape[:1] + (reduced_time,) + output_shape[2:]
return output_shape, output_length_shape
# def compute_output_shape(self, input_shape):
# output_shape, output_length_shape = input_shape
# reduced_time = math_util.legacy_get_reduced_length(output_shape[1], self.time_reduction_factor)
# output_shape = output_shape[:1] + (reduced_time,) + output_shape[2:]
# return output_shape, output_length_shape


@keras.utils.register_keras_serializable(package=__name__)
@@ -141,16 +141,16 @@ def compute_mask(self, inputs, mask=None):
mask = tf.sequence_mask(outputs_length, maxlen=maxlen, dtype=tf.bool)
return mask, None

def compute_output_shape(self, input_shape):
output_shape, output_length_shape = input_shape
outputs_shape = self.conv1.compute_output_shape(output_shape)
outputs_shape = self.conv2.compute_output_shape(outputs_shape)
outputs_shape = self.maxpool1.compute_output_shape(outputs_shape)
outputs_shape = self.conv3.compute_output_shape(outputs_shape)
outputs_shape = self.conv4.compute_output_shape(outputs_shape)
outputs_shape = self.maxpool2.compute_output_shape(outputs_shape)
outputs_shape = outputs_shape[:2] + (outputs_shape[2] * outputs_shape[3],)
return outputs_shape, output_length_shape
# def compute_output_shape(self, input_shape):
# output_shape, output_length_shape = input_shape
# outputs_shape = self.conv1.compute_output_shape(output_shape)
# outputs_shape = self.conv2.compute_output_shape(outputs_shape)
# outputs_shape = self.maxpool1.compute_output_shape(outputs_shape)
# outputs_shape = self.conv3.compute_output_shape(outputs_shape)
# outputs_shape = self.conv4.compute_output_shape(outputs_shape)
# outputs_shape = self.maxpool2.compute_output_shape(outputs_shape)
# outputs_shape = outputs_shape[:2] + (outputs_shape[2] * outputs_shape[3],)
# return outputs_shape, output_length_shape


@keras.utils.register_keras_serializable(package=__name__)
@@ -235,12 +235,12 @@ def compute_mask(self, inputs, mask=None):
mask = tf.sequence_mask(outputs_length, maxlen=maxlen, dtype=tf.bool)
return mask, None

def compute_output_shape(self, input_shape):
output_shape, output_length_shape = input_shape
for block in self.convs:
output_shape = block.layers[0].compute_output_shape(output_shape)
output_shape = output_shape[:2] + (output_shape[2] * output_shape[3],)
return output_shape, output_length_shape
# def compute_output_shape(self, input_shape):
# output_shape, output_length_shape = input_shape
# for block in self.convs:
# output_shape = block.layers[0].compute_output_shape(output_shape)
# output_shape = output_shape[:2] + (output_shape[2] * output_shape[3],)
# return output_shape, output_length_shape


@keras.utils.register_keras_serializable(package=__name__)
@@ -325,9 +325,9 @@ def compute_mask(self, inputs, mask=None):
mask = tf.sequence_mask(outputs_length, maxlen=maxlen, dtype=tf.bool)
return mask, None

def compute_output_shape(self, input_shape):
output_shape, output_length_shape = input_shape
output_shape = output_shape[:2] + (output_shape[2] * output_shape[3],)
for block in self.convs:
output_shape = block.layers[0].compute_output_shape(output_shape)
return output_shape, output_length_shape
# def compute_output_shape(self, input_shape):
# output_shape, output_length_shape = input_shape
# output_shape = output_shape[:2] + (output_shape[2] * output_shape[3],)
# for block in self.convs:
# output_shape = block.layers[0].compute_output_shape(output_shape)
# return output_shape, output_length_shape
40 changes: 20 additions & 20 deletions tensorflow_asr/models/transducer/base_transducer.py
@@ -152,16 +152,16 @@ def call_next(self, inputs, previous_decoder_states):
def compute_mask(self, inputs, mask=None):
return self.label_encoder.compute_mask(inputs, mask=mask)

def compute_output_shape(self, input_shape):
output_shape, output_length_shape = input_shape
output_shape, output_length_shape = self.label_encoder.compute_output_shape((output_shape, output_length_shape))
for i, rnn in enumerate(self.rnns):
output_shape = (
self.projections[i].compute_output_shape(output_shape)
if self.projections[i] is not None
else rnn.compute_output_shape(output_shape)[0]
)
return tuple(output_shape), tuple(output_length_shape)
# def compute_output_shape(self, input_shape):
# output_shape, output_length_shape = input_shape
# output_shape, output_length_shape = self.label_encoder.compute_output_shape((output_shape, output_length_shape))
# for i, rnn in enumerate(self.rnns):
# output_shape = (
# self.projections[i].compute_output_shape(output_shape)
# if self.projections[i] is not None
# else rnn.compute_output_shape(output_shape)[0]
# )
# return tuple(output_shape), tuple(output_length_shape)


@keras.utils.register_keras_serializable(package=__name__)
@@ -197,9 +197,9 @@ def call(self, inputs):
outputs = tf.multiply(enc_out, pred_out) # broadcast operator
return outputs # [B, T, U, V]

def compute_output_shape(self, input_shape):
enc_shape, pred_shape = input_shape
return enc_shape[0], enc_shape[1], pred_shape[1], enc_shape[-1]
# def compute_output_shape(self, input_shape):
# enc_shape, pred_shape = input_shape
# return enc_shape[0], enc_shape[1], pred_shape[1], enc_shape[-1]


@keras.utils.register_keras_serializable(package=__name__)
@@ -281,11 +281,11 @@ def call(self, inputs, training=False):
def compute_mask(self, inputs, mask=None):
return self.joint.compute_mask(inputs, mask=mask)

def compute_output_shape(self, input_shape):
encoder_shape, prediction_shape = input_shape
batch_shape = encoder_shape[0]
encoder_time_shape, prediction_time_shape = encoder_shape[1], prediction_shape[1]
return batch_shape, encoder_time_shape, prediction_time_shape, self.ffn_out.units
# def compute_output_shape(self, input_shape):
# encoder_shape, prediction_shape = input_shape
# batch_shape = encoder_shape[0]
# encoder_time_shape, prediction_time_shape = encoder_shape[1], prediction_shape[1]
# return batch_shape, encoder_time_shape, prediction_time_shape, self.ffn_out.units


class Transducer(BaseModel):
@@ -407,8 +407,8 @@ def remove_gwn(self, original_weights):

def call(self, inputs: schemas.TrainInput, training=False):
features, features_length = self.feature_extraction((inputs.inputs, inputs.inputs_length), training=training)
enc, logits_length, _ = self.encoder((features, features_length), training=training)
pred, _ = self.predict_net((inputs.predictions, inputs.predictions_length), training=training)
enc, logits_length, *_ = self.encoder((features, features_length), training=training)
pred, *_ = self.predict_net((inputs.predictions, inputs.predictions_length), training=training)
logits = self.joint_net((enc, pred), training=training)
return schemas.TrainOutput(
logits=logits,
6 changes: 4 additions & 2 deletions tensorflow_asr/utils/layer_util.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

from tensorflow_asr import keras, tf
from tensorflow_asr.models.layers.convolution import Conv2D

@@ -37,13 +39,13 @@ def get_conv(


def add_gwn(
trainable_weights: list,
trainable_weights: List[tf.Variable],
stddev: float = 1.0,
):
original_weights = []
for weight in trainable_weights:
noise = tf.stop_gradient(tf.random.normal(mean=0.0, stddev=stddev, shape=weight.shape, dtype=weight.dtype))
original_weights.append(weight.value())
original_weights.append(weight)
weight.assign_add(noise)
return original_weights
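For context, add_gwn implements the Gaussian weight noise trick used with the transducer's remove_gwn counterpart: snapshot the clean weights, add zero-mean noise for the forward/backward pass, then restore before the optimizer applies its update. A generic illustration of that cycle on a single variable (not the helper's exact code, which operates on a list of trainable weights; the stddev is arbitrary here):

import tensorflow as tf

v = tf.Variable([1.0, 2.0])
snapshot = tf.identity(v)  # copy of the clean weights
noise = tf.stop_gradient(tf.random.normal(shape=v.shape, stddev=0.075, dtype=v.dtype))
v.assign_add(noise)        # noisy weights used for this step's forward/backward pass
# ... compute gradients with the noisy weights ...
v.assign(snapshot)         # restore the clean weights before the update is applied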

4 changes: 2 additions & 2 deletions tensorflow_asr/utils/shape_util.py
@@ -23,8 +23,8 @@ def shape_list(x, out_type=tf.int32):


def shape_list_per_replica(x, per_replica_batch_size):
shapes = x.shape.as_list()
shapes[0] = int(per_replica_batch_size)
_, *rest_shape = x.shape
shapes = (int(per_replica_batch_size),) + tuple(rest_shape)
return shapes
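shape_list_per_replica swaps the leading (global) batch dimension for the per-replica batch size while keeping the remaining static dimensions; the rewrite simply builds the result as a tuple instead of mutating a list. A small usage sketch (the tensor shape below is illustrative):

import tensorflow as tf

def shape_list_per_replica(x, per_replica_batch_size):
    # Replace the global batch dim with the per-replica batch size; keep the rest of the static shape.
    _, *rest_shape = x.shape
    return (int(per_replica_batch_size),) + tuple(rest_shape)

features = tf.zeros([8, 100, 80])                            # global batch of 8
shape_list_per_replica(features, per_replica_batch_size=2)   # -> (2, 100, 80)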