[InferenceClient] Add text-to-video task and update supported tasks and models (#2786)

* update chat completion docstring

* add tasks and update supported models

* fix docstrings and update error message

* fix test
hanouticelina authored Jan 27, 2025
1 parent 803fa7b commit a259e88
Showing 9 changed files with 223 additions and 91 deletions.
73 changes: 62 additions & 11 deletions src/huggingface_hub/inference/_client.py
@@ -615,6 +615,10 @@ def chat_completion(
</Tip>
<Tip>
Some parameters might not be supported by some providers.
</Tip>
Args:
messages (List of [`ChatCompletionInputMessage`]):
Conversation history consisting of roles and content pairs.
@@ -628,14 +632,14 @@
Penalizes new tokens based on their existing frequency
in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
logit_bias (`List[float]`, *optional*):
UNUSED. Currently not implemented in text-generation-inference (TGI). Kept as a parameter for OpenAI compatibility.
Adjusts the likelihood of specific tokens appearing in the generated output.
logprobs (`bool`, *optional*):
Whether to return log probabilities of the output tokens or not. If true, returns the log
probabilities of each output token returned in the content of message.
max_tokens (`int`, *optional*):
Maximum number of tokens allowed in the response. Defaults to 100.
n (`int`, *optional*):
UNUSED. Currently not implemented in text-generation-inference (TGI). Kept as a parameter for OpenAI compatibility.
The number of completions to generate for each prompt.
presence_penalty (`float`, *optional*):
Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
text so far, increasing the model's likelihood to talk about new topics.
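Since `logit_bias` and `n` are no longer documented as UNUSED, here is a minimal sketch of a call that exercises them. The model ID is illustrative, and actual support for these parameters still depends on the provider and backend:

from huggingface_hub import InferenceClient

client = InferenceClient()
# `n` requests two completions; it was previously documented as UNUSED,
# and support may still vary by provider/backend.
output = client.chat_completion(
    messages=[{"role": "user", "content": "Name three colors."}],
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model choice
    max_tokens=50,
    n=2,
)
for choice in output.choices:
    print(choice.message.content)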
@@ -2054,15 +2058,6 @@ def text_generation(
"""
Given a prompt, generate the following text.
API endpoint is supposed to run with the `text-generation-inference` backend (TGI). This backend is the
go-to solution to run large language models at scale. However, for some smaller models (e.g. "gpt2") the
default `transformers` + `api-inference` solution is still in use. Both approaches have very similar APIs, but
not exactly the same. This method is compatible with both approaches but some parameters are only available for
`text-generation-inference`. If some parameters are ignored, a warning message is triggered but the process
continues correctly.
To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.
<Tip>
If you want to generate a response from chat messages, you should use the [`InferenceClient.chat_completion`] method.
@@ -2470,6 +2465,61 @@ def text_to_image(
response = provider_helper.get_response(response)
return _bytes_to_image(response)

def text_to_video(
self,
prompt: str,
*,
model: Optional[str] = None,
guidance_scale: Optional[float] = None,
negative_prompt: Optional[List[str]] = None,
num_frames: Optional[float] = None,
num_inference_steps: Optional[int] = None,
seed: Optional[int] = None,
) -> bytes:
"""
Generate a video from a given text prompt.
Args:
prompt (`str`):
The prompt to generate a video from.
model (`str`, *optional*):
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
Inference Endpoint. If not provided, the default recommended text-to-video model will be used.
Defaults to None.
guidance_scale (`float`, *optional*):
A higher guidance scale value encourages the model to generate videos closely linked to the text
prompt, but values too high may cause saturation and other artifacts.
negative_prompt (`List[str]`, *optional*):
One or several prompts to guide what NOT to include in video generation.
num_frames (`float`, *optional*):
The number of video frames to generate.
num_inference_steps (`int`, *optional*):
The number of denoising steps. More denoising steps usually lead to a higher quality video at the
expense of slower inference.
seed (`int`, *optional*):
Seed for the random number generator.
Returns:
`bytes`: The generated video.
"""
provider_helper = get_provider_helper(self.provider, task="text-to-video")
request_parameters = provider_helper.prepare_request(
inputs=prompt,
parameters={
"guidance_scale": guidance_scale,
"negative_prompt": negative_prompt,
"num_frames": num_frames,
"num_inference_steps": num_inference_steps,
"seed": seed,
},
headers=self.headers,
model=model or self.model,
api_key=self.token,
)
response = self._inner_post(request_parameters)
response = provider_helper.get_response(response)
return response
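A minimal usage sketch, assuming valid credentials for the chosen provider. The provider/model pair is illustrative (it matches the fal.ai mapping in fal_ai.py below), and the output container format depends on the model:

from huggingface_hub import InferenceClient

client = InferenceClient(provider="fal-ai")
video = client.text_to_video(
    "A young cat running in a garden",
    model="genmo/mochi-1-preview",  # mapped to fal-ai/mochi-v1 (see fal_ai.py below)
    num_frames=16,
    seed=42,
)
with open("video.mp4", "wb") as f:  # MP4 container assumed; depends on the model
    f.write(video)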

def text_to_speech(
self,
text: str,
@@ -2594,6 +2644,7 @@ def text_to_speech(
api_key=self.token,
)
response = self._inner_post(request_parameters)
response = provider_helper.get_response(response)
return response

def token_classification(
40 changes: 3 additions & 37 deletions src/huggingface_hub/inference/_common.py
@@ -258,45 +258,11 @@ def _bytes_to_image(content: bytes) -> "Image":
return Image.open(io.BytesIO(content))


## PAYLOAD UTILS
def _as_dict(response: Union[bytes, Dict]) -> Dict:
return json.loads(response) if isinstance(response, bytes) else response


def _prepare_payload(
inputs: Union[str, Dict[str, Any], ContentT],
parameters: Optional[Dict[str, Any]],
expect_binary: bool = False,
) -> Dict[str, Any]:
"""
Used in `InferenceClient` and `AsyncInferenceClient` to prepare the payload for an API request, handling various input types and parameters.
`expect_binary` is set to `True` when the inputs are a binary object or a local path or URL. This is the case for image and audio inputs.
"""
if parameters is None:
parameters = {}
parameters = {k: v for k, v in parameters.items() if v is not None}
has_parameters = len(parameters) > 0

is_binary = isinstance(inputs, (bytes, Path))
# If expect_binary is True, inputs must be a binary object or a local path or a URL.
if expect_binary and not is_binary and not isinstance(inputs, str):
raise ValueError(f"Expected binary inputs or a local path or a URL. Got {inputs}") # type: ignore
# Send inputs as raw content when no parameters are provided
if expect_binary and not has_parameters:
return {"data": inputs}
# If expect_binary is False, inputs must not be a binary object.
if not expect_binary and is_binary:
raise ValueError(f"Unexpected binary inputs. Got {inputs}") # type: ignore

json: Dict[str, Any] = {}
# If inputs is a bytes-like object, encode it to base64
if expect_binary:
json["inputs"] = _b64_encode(inputs) # type: ignore
# Otherwise (string, dict, list) send it as is
else:
json["inputs"] = inputs
# Add parameters to the json payload if any
if has_parameters:
json["parameters"] = parameters
return {"json": json}
## PAYLOAD UTILS
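`_as_dict` now lives in `_common.py` so provider helpers (e.g. fal_ai.py below) can share it; it simply normalizes a raw JSON response, as this small sketch shows:

from huggingface_hub.inference._common import _as_dict

# Bytes and already-decoded dicts normalize to the same result.
expected = {"video": {"url": "https://example.com/out.mp4"}}
assert _as_dict(b'{"video": {"url": "https://example.com/out.mp4"}}') == expected
assert _as_dict(expected) == expected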


## STREAMING UTILS
73 changes: 62 additions & 11 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -652,6 +652,10 @@ async def chat_completion(
</Tip>
<Tip>
Some parameters might not be supported by some providers.
</Tip>
Args:
messages (List of [`ChatCompletionInputMessage`]):
Conversation history consisting of roles and content pairs.
@@ -665,14 +669,14 @@
Penalizes new tokens based on their existing frequency
in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
logit_bias (`List[float]`, *optional*):
UNUSED. Currently not implemented in text-generation-inference (TGI). Kept as a parameter for OpenAI compatibility.
Adjusts the likelihood of specific tokens appearing in the generated output.
logprobs (`bool`, *optional*):
Whether to return log probabilities of the output tokens or not. If true, returns the log
probabilities of each output token returned in the content of message.
max_tokens (`int`, *optional*):
Maximum number of tokens allowed in the response. Defaults to 100.
n (`int`, *optional*):
UNUSED. Currently not implemented in text-generation-inference (TGI). Kept as a parameter for OpenAI compatibility.
The number of completions to generate for each prompt.
presence_penalty (`float`, *optional*):
Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
text so far, increasing the model's likelihood to talk about new topics.
@@ -2112,15 +2116,6 @@ async def text_generation(
"""
Given a prompt, generate the following text.
API endpoint is supposed to run with the `text-generation-inference` backend (TGI). This backend is the
go-to solution to run large language models at scale. However, for some smaller models (e.g. "gpt2") the
default `transformers` + `api-inference` solution is still in use. Both approaches have very similar APIs, but
not exactly the same. This method is compatible with both approaches but some parameters are only available for
`text-generation-inference`. If some parameters are ignored, a warning message is triggered but the process
continues correctly.
To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.
<Tip>
If you want to generate a response from chat messages, you should use the [`InferenceClient.chat_completion`] method.
@@ -2530,6 +2525,61 @@ async def text_to_image(
response = provider_helper.get_response(response)
return _bytes_to_image(response)

async def text_to_video(
self,
prompt: str,
*,
model: Optional[str] = None,
guidance_scale: Optional[float] = None,
negative_prompt: Optional[List[str]] = None,
num_frames: Optional[float] = None,
num_inference_steps: Optional[int] = None,
seed: Optional[int] = None,
) -> bytes:
"""
Generate a video from a given text prompt.
Args:
prompt (`str`):
The prompt to generate a video from.
model (`str`, *optional*):
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
Inference Endpoint. If not provided, the default recommended text-to-video model will be used.
Defaults to None.
guidance_scale (`float`, *optional*):
A higher guidance scale value encourages the model to generate videos closely linked to the text
prompt, but values too high may cause saturation and other artifacts.
negative_prompt (`List[str]`, *optional*):
One or several prompts to guide what NOT to include in video generation.
num_frames (`float`, *optional*):
The number of video frames to generate.
num_inference_steps (`int`, *optional*):
The number of denoising steps. More denoising steps usually lead to a higher quality video at the
expense of slower inference.
seed (`int`, *optional*):
Seed for the random number generator.
Returns:
`bytes`: The generated video.
"""
provider_helper = get_provider_helper(self.provider, task="text-to-video")
request_parameters = provider_helper.prepare_request(
inputs=prompt,
parameters={
"guidance_scale": guidance_scale,
"negative_prompt": negative_prompt,
"num_frames": num_frames,
"num_inference_steps": num_inference_steps,
"seed": seed,
},
headers=self.headers,
model=model or self.model,
api_key=self.token,
)
response = await self._inner_post(request_parameters)
response = provider_helper.get_response(response)
return response
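The async variant mirrors the sync client; a short sketch with the same caveats as above (illustrative model, container format assumed):

import asyncio

from huggingface_hub import AsyncInferenceClient

async def main() -> None:
    client = AsyncInferenceClient(provider="fal-ai")
    video = await client.text_to_video(
        "A boat sailing at sunset",
        model="tencent/HunyuanVideo",  # mapped to fal-ai/hunyuan-video (see fal_ai.py below)
    )
    with open("video.mp4", "wb") as f:  # container format assumed
        f.write(video)

asyncio.run(main())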

async def text_to_speech(
self,
text: str,
@@ -2655,6 +2705,7 @@ async def text_to_speech(
api_key=self.token,
)
response = await self._inner_post(request_parameters)
response = provider_helper.get_response(response)
return response

async def token_classification(
@@ -14,16 +14,16 @@ class TextToVideoParameters(BaseInferenceType):
"""Additional inference parameters for Text To Video"""

guidance_scale: Optional[float] = None
"""A higher guidance scale value encourages the model to generate images closely linked to
"""A higher guidance scale value encourages the model to generate videos closely linked to
the text prompt, but values too high may cause saturation and other artifacts.
"""
negative_prompt: Optional[List[str]] = None
"""One or several prompt to guide what NOT to include in image generation."""
"""One or several prompt to guide what NOT to include in video generation."""
num_frames: Optional[float] = None
"""The num_frames parameter determines how many video frames are generated."""
num_inference_steps: Optional[int] = None
"""The number of denoising steps. More denoising steps usually lead to a higher quality
image at the expense of slower inference.
video at the expense of slower inference.
"""
seed: Optional[int] = None
"""Seed for the random number generator."""
9 changes: 6 additions & 3 deletions src/huggingface_hub/inference/_providers/__init__.py
@@ -1,9 +1,9 @@
from typing import Dict

from .._common import TaskProviderHelper
from .fal_ai import FalAIAutomaticSpeechRecognitionTask, FalAITextToImageTask
from .fal_ai import FalAIAutomaticSpeechRecognitionTask, FalAITextToImageTask, FalAITextToVideoTask
from .hf_inference import HFInferenceBinaryInputTask, HFInferenceConversational, HFInferenceTask
from .replicate import ReplicateTextToImageTask
from .replicate import ReplicateTask, ReplicateTextToSpeechTask
from .sambanova import SambanovaConversationalTask
from .together import TogetherTextGenerationTask, TogetherTextToImageTask

@@ -12,6 +12,7 @@
"fal-ai": {
"text-to-image": FalAITextToImageTask(),
"automatic-speech-recognition": FalAIAutomaticSpeechRecognitionTask(),
"text-to-video": FalAITextToVideoTask(),
},
"hf-inference": {
"text-to-image": HFInferenceTask("text-to-image"),
@@ -42,7 +43,9 @@
"visual-question-answering": HFInferenceBinaryInputTask("visual-question-answering"),
},
"replicate": {
"text-to-image": ReplicateTextToImageTask(),
"text-to-image": ReplicateTask("text-to-image"),
"text-to-speech": ReplicateTextToSpeechTask(),
"text-to-video": ReplicateTask("text-to-video"),
},
"sambanova": {
"conversational": SambanovaConversationalTask(),
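This registry is what `get_provider_helper` consults in the client methods above; a quick sketch of the lookup (an unsupported provider/task pair raises an error):

from huggingface_hub.inference._providers import get_provider_helper

helper = get_provider_helper("fal-ai", task="text-to-video")
print(type(helper).__name__)  # FalAITextToVideoTask

helper = get_provider_helper("replicate", task="text-to-video")
print(type(helper).__name__)  # ReplicateTask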
28 changes: 24 additions & 4 deletions src/huggingface_hub/inference/_providers/fal_ai.py
@@ -1,10 +1,9 @@
import base64
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Union

from huggingface_hub import constants
from huggingface_hub.inference._common import RequestParameters, TaskProviderHelper
from huggingface_hub.inference._common import RequestParameters, TaskProviderHelper, _as_dict
from huggingface_hub.utils import build_hf_headers, get_session, logging


@@ -20,6 +19,18 @@
"text-to-image": {
"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
"black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
"playgroundai/playground-v2.5-1024px-aesthetic": "fal-ai/playground-v25",
"ByteDance/SDXL-Lightning": "fal-ai/lightning-models",
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS": "fal-ai/pixart-sigma",
"stabilityai/stable-diffusion-3-medium": "fal-ai/stable-diffusion-v3-medium",
"Warlord-K/Sana-1024": "fal-ai/sana",
"fal/AuraFlow-v0.2": "fal-ai/aura-flow",
"stabilityai/stable-diffusion-3.5-large": "fal-ai/stable-diffusion-v35-large",
"Kwai-Kolors/Kolors": "fal-ai/kolors",
},
"text-to-video": {
"genmo/mochi-1-preview": "fal-ai/mochi-v1",
"tencent/HunyuanVideo": "fal-ai/hunyuan-video",
},
}

@@ -131,5 +142,14 @@ def get_response(self, response: Union[bytes, Dict]) -> Any:
return get_session().get(url).content


def _as_dict(response: Union[bytes, Dict]) -> Dict:
return json.loads(response) if isinstance(response, bytes) else response
class FalAITextToVideoTask(FalAITask):
def __init__(self):
super().__init__("text-to-video")

def _prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]:
parameters = {k: v for k, v in parameters.items() if v is not None}
return {"prompt": inputs, **parameters}

def get_response(self, response: Union[bytes, Dict]) -> Any:
url = _as_dict(response)["video"]["url"]
return get_session().get(url).content
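Putting the two overrides together, here is an offline sketch of what the helper does, using a fabricated response whose {"video": {"url": ...}} shape is the one get_response assumes:

from huggingface_hub.inference._common import _as_dict
from huggingface_hub.inference._providers.fal_ai import FalAITextToVideoTask

task = FalAITextToVideoTask()

# _prepare_payload drops None parameters and nests the prompt under "prompt".
payload = task._prepare_payload("A cat playing piano", {"num_frames": 16, "seed": None})
assert payload == {"prompt": "A cat playing piano", "num_frames": 16}

# get_response extracts the video URL before downloading; shown here without the network call.
fake_response = b'{"video": {"url": "https://example.com/out.mp4"}}'
assert _as_dict(fake_response)["video"]["url"] == "https://example.com/out.mp4"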