Skip to content

Commit

Permalink
Refactor to deprecate _get_tts override and use synth method dire…
Browse files Browse the repository at this point in the history
…ctly
  • Loading branch information
NeonDaniel committed Feb 12, 2025
1 parent 222b64b commit 600c16e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
15 changes: 11 additions & 4 deletions neon_audio/tts/neon.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,19 @@
import os

from os.path import dirname
from queue import Empty
from time import time
from typing import List

from json_database import JsonStorageXDG
from ovos_bus_client.apis.enclosure import EnclosureAPI
from ovos_bus_client.message import Message
from ovos_bus_client.util import get_message_lang
from ovos_plugin_manager.language import OVOSLangDetectionFactory,\
OVOSLangTranslationFactory
from ovos_plugin_manager.templates.tts import TTS
from ovos_plugin_manager.templates.g2p import OutOfVocabulary
from ovos_plugin_manager.templates.tts import TTS, TTSContext
from ovos_utils.sound import play_audio

from neon_utils.file_utils import encode_file_to_base64_string
from neon_utils.message_utils import resolve_message
Expand All @@ -50,7 +55,7 @@
from ovos_config.config import Configuration


def get_requested_tts_languages(msg) -> list:
def get_requested_tts_languages(msg) -> List[dict]:
"""
Builds a list of the requested TTS for a given spoken response
:param msg: Message associated with request
Expand Down Expand Up @@ -212,7 +217,7 @@ def __new__(cls, base_engine, *args, **kwargs):
base_engine.execute = cls.execute
base_engine.get_multiple_tts = cls.get_multiple_tts
# TODO: Below method is only to bridge compatibility
base_engine._get_tts = cls._get_tts
# base_engine._get_tts = cls._get_tts
base_engine._init_playback = cls._init_playback
base_engine.lang = cls.lang
return cls._init_neon(base_engine, *args, **kwargs)
Expand Down Expand Up @@ -293,6 +298,7 @@ def _get_tts(self, sentence: str, request: dict = None, **kwargs):
os.makedirs(dirname(file), exist_ok=True)
if os.path.isfile(file):
LOG.info(f"Using cached TTS audio")
# TODO: In this case, playback is not reported properly
return file, None
plugin_kwargs = dict()
if "speaker" in inspect.signature(self.get_tts).parameters:
Expand Down Expand Up @@ -336,7 +342,8 @@ def get_multiple_tts(self, message, **kwargs) -> dict:
LOG.info(f"Got translated sentence: {tx_sentence}")
else:
tx_sentence = sentence
wav_file, phonemes = self._get_tts(tx_sentence, request, **kwargs)
kwargs['speaker'] = request
wav_file, phonemes = self.synth(tx_sentence, **kwargs)

# If this is the first response, populate translation and phonemes
responses.setdefault(tts_lang, {"sentence": tx_sentence,
Expand Down
1 change: 1 addition & 0 deletions tests/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def test_validator_invalid(self):
tts.shutdown()

def test_get_tts(self):
# TODO: Deprecate
test_file_path = join(dirname(__file__), "test.wav")
file, phonemes = self.tts._get_tts("test", wav_file=test_file_path,
speaker={})
Expand Down

0 comments on commit 600c16e

Please sign in to comment.