Skip to content

Commit

Permalink
refactor(go): Moved model version validation to middleware (#2233)
Browse files Browse the repository at this point in the history
  • Loading branch information
apascal07 authored Mar 5, 2025
1 parent a5ef0c1 commit a842e5e
Show file tree
Hide file tree
Showing 11 changed files with 259 additions and 144 deletions.
2 changes: 1 addition & 1 deletion go/ai/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func defineProgrammableModel(r *registry.Registry) *programmableModel {
Tools: true,
Multiturn: true,
}
DefineModel(r, "", "programmableModel", &ModelInfo{Supports: supports}, func(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) {
DefineModel(r, "", "programmableModel", &ModelInfo{Supports: supports}, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) {
return pm.Generate(ctx, r, req, &ToolConfig{MaxTurns: 5}, cb)
})
return pm
Expand Down
110 changes: 32 additions & 78 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,40 +19,44 @@ import (
"github.com/firebase/genkit/go/internal/registry"
)

// Model represents a model that can perform content generation tasks.
type Model interface {
// Name returns the registry name of the model.
Name() string
// Generate applies the [Model] to provided request, handling tool requests and handles streaming.
Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, mw []ModelMiddleware, toolCfg *ToolConfig, cb ModelStreamingCallback) (*ModelResponse, error)
}
type (
// Model represents a model that can generate content based on a request.
Model interface {
// Name returns the registry name of the model.
Name() string
// Generate applies the [Model] to provided request, handling tool requests and handles streaming.
Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, mw []ModelMiddleware, toolCfg *ToolConfig, cb ModelStreamCallback) (*ModelResponse, error)
}

// ModelFunc is a function that generates a model response.
type ModelFunc = core.StreamingFunc[*ModelRequest, *ModelResponse, *ModelResponseChunk]
// ToolConfig handles configuration around tool calls during generation.
ToolConfig struct {
MaxTurns int // Maximum number of tool call iterations before erroring.
ReturnToolRequests bool // Whether to return tool requests instead of making the tool calls and continuing the generation.
}

// ModelMiddleware is middleware for model generate requests.
type ModelMiddleware = core.Middleware[*ModelRequest, *ModelResponse, *ModelResponseChunk]
// ModelFunc is a streaming function that takes in a ModelRequest and generates a ModelResponse, optionally streaming ModelResponseChunks.
ModelFunc = core.StreamingFunc[*ModelRequest, *ModelResponse, *ModelResponseChunk]

// ModelAction is an action for model generation.
type ModelAction = core.ActionDef[*ModelRequest, *ModelResponse, *ModelResponseChunk]
// ModelStreamCallback is a stream callback of a ModelAction.
ModelStreamCallback = func(context.Context, *ModelResponseChunk) error

type generateAction = core.ActionDef[*GenerateActionOptions, *ModelResponse, *ModelResponseChunk]
// ModelMiddleware is middleware for model generate requests that takes in a ModelFunc, does something, then returns another ModelFunc.
ModelMiddleware = core.Middleware[*ModelRequest, *ModelResponse, *ModelResponseChunk]

type modelActionDef core.ActionDef[*ModelRequest, *ModelResponse, *ModelResponseChunk]
// ModelAction is the type for model generation actions.
ModelAction = core.ActionDef[*ModelRequest, *ModelResponse, *ModelResponseChunk]

// ModelStreamingCallback is the type for the streaming callback of a model.
type ModelStreamingCallback = func(context.Context, *ModelResponseChunk) error
// modelActionDef is an action with functions specific to model generation such as Generate().
modelActionDef core.ActionDef[*ModelRequest, *ModelResponse, *ModelResponseChunk]

// ToolConfig handles configuration around tool calls during generation.
type ToolConfig struct {
MaxTurns int
ReturnToolRequests bool
}
// generateAction is the type for a utility model generation action that takes in a GenerateActionOptions instead of a ModelRequest.
generateAction = core.ActionDef[*GenerateActionOptions, *ModelResponse, *ModelResponseChunk]
)

// DefineGenerateAction defines a utility generate action.
func DefineGenerateAction(ctx context.Context, r *registry.Registry) *generateAction {
return (*generateAction)(core.DefineStreamingAction(r, "", "generate", atype.Util, nil,
func(ctx context.Context, req *GenerateActionOptions, cb ModelStreamingCallback) (output *ModelResponse, err error) {
func(ctx context.Context, req *GenerateActionOptions, cb ModelStreamCallback) (output *ModelResponse, err error) {
logger.FromContext(ctx).Debug("GenerateAction",
"input", fmt.Sprintf("%#v", req))
defer func() {
Expand Down Expand Up @@ -137,7 +141,7 @@ func DefineModel(
metadataMap["supports"] = supports
metadataMap["versions"] = info.Versions

generate = core.ChainMiddleware(ValidateSupport(name, info.Supports))(generate)
generate = core.ChainMiddleware(ValidateSupport(name, info))(generate)

return (*modelActionDef)(core.DefineStreamingAction(r, provider, name, atype.Model, map[string]any{"model": metadataMap}, generate))
}
Expand All @@ -161,7 +165,7 @@ func LookupModel(r *registry.Registry, provider, name string) Model {
type generateParams struct {
Request *ModelRequest
Model Model
Stream ModelStreamingCallback
Stream ModelStreamCallback
History []*Message
SystemPrompt *Message
MaxTurns int
Expand Down Expand Up @@ -287,7 +291,7 @@ func WithOutputFormat(format OutputFormat) GenerateOption {
}

// WithStreaming adds a streaming callback to the generate request.
func WithStreaming(cb ModelStreamingCallback) GenerateOption {
func WithStreaming(cb ModelStreamCallback) GenerateOption {
return func(req *generateParams) error {
if req.Stream != nil {
return errors.New("generate.WithStreaming: cannot set streaming callback more than once")
Expand Down Expand Up @@ -361,18 +365,6 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption)
return nil, errors.New("model is required")
}

var modelVersion string
if config, ok := req.Request.Config.(*GenerationCommonConfig); ok {
modelVersion = config.Version
}

if modelVersion != "" {
ok, err := validateModelVersion(r, modelVersion, req)
if !ok {
return nil, err
}
}

if req.History != nil {
prev := req.Request.Messages
req.Request.Messages = req.History
Expand All @@ -395,44 +387,6 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption)
return req.Model.Generate(ctx, r, req.Request, req.Middleware, toolCfg, req.Stream)
}

// validateModelVersion checks in the registry the action of the
// given model version and determines whether its supported or not.
func validateModelVersion(r *registry.Registry, v string, req *generateParams) (bool, error) {
parts := strings.Split(req.Model.Name(), "/")
if len(parts) != 2 {
return false, errors.New("wrong model name")
}

m := LookupModel(r, parts[0], parts[1])
if m == nil {
return false, fmt.Errorf("model %s not found", v)
}

// at the end, a Model is an action so type conversion is required
if a, ok := m.(*modelActionDef); ok {
if !(modelVersionSupported(v, (*ModelAction)(a).Desc().Metadata)) {
return false, fmt.Errorf("version %s not supported", v)
}
} else {
return false, errors.New("unable to validate model version")
}

return true, nil
}

// modelVersionSupported iterates over model's metadata to find the requested
// supported model version
func modelVersionSupported(modelVersion string, modelMetadata map[string]any) bool {
if md, ok := modelMetadata["model"].(map[string]any); ok {
for _, v := range md["versions"].([]string) {
if modelVersion == v {
return true
}
}
}
return false
}

// GenerateText run generate request for this model. Returns generated text only.
func GenerateText(ctx context.Context, r *registry.Registry, opts ...GenerateOption) (string, error) {
res, err := Generate(ctx, r, opts...)
Expand All @@ -459,7 +413,7 @@ func GenerateData(ctx context.Context, r *registry.Registry, value any, opts ...
}

// Generate applies the [Action] to provided request, handling tool requests and handles streaming.
func (m *modelActionDef) Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, mw []ModelMiddleware, toolCfg *ToolConfig, cb ModelStreamingCallback) (*ModelResponse, error) {
func (m *modelActionDef) Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, mw []ModelMiddleware, toolCfg *ToolConfig, cb ModelStreamCallback) (*ModelResponse, error) {
if m == nil {
return nil, errors.New("Generate called on a nil Model; check that all models are defined")
}
Expand Down Expand Up @@ -557,7 +511,7 @@ func cloneMessage(m *Message) *Message {
// handleToolRequests processes any tool requests in the response, returning
// either a new request to continue the conversation or nil if no tool requests
// need handling.
func handleToolRequests(ctx context.Context, r *registry.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamingCallback) (*ModelRequest, *Message, error) {
func handleToolRequests(ctx context.Context, r *registry.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamCallback) (*ModelRequest, *Message, error) {
toolCount := 0
for _, part := range resp.Message.Content {
if part.IsToolRequest() {
Expand Down
12 changes: 6 additions & 6 deletions go/ai/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ var (
Versions: []string{"echo-001", "echo-002"},
}

echoModel = DefineModel(r, "test", modelName, &metadata, func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
echoModel = DefineModel(r, "test", modelName, &metadata, func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) {
if msc != nil {
msc(ctx, &ModelResponseChunk{
Content: []*Part{NewTextPart("stream!")},
Expand Down Expand Up @@ -353,7 +353,7 @@ func TestGenerate(t *testing.T) {
},
}
interruptModel := DefineModel(r, "test", "interrupt", info,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) {
return &ModelResponse{
Request: gr,
Message: &Message{
Expand Down Expand Up @@ -412,7 +412,7 @@ func TestGenerate(t *testing.T) {
},
}
parallelModel := DefineModel(r, "test", "parallel", info,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) {
roundCount++
if roundCount == 1 {
return &ModelResponse{
Expand Down Expand Up @@ -477,7 +477,7 @@ func TestGenerate(t *testing.T) {
},
}
multiRoundModel := DefineModel(r, "test", "multiround", info,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) {
roundCount++
if roundCount == 1 {
return &ModelResponse{
Expand Down Expand Up @@ -545,7 +545,7 @@ func TestGenerate(t *testing.T) {
},
}
infiniteModel := DefineModel(r, "test", "infinite", info,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) {
return &ModelResponse{
Request: gr,
Message: &Message{
Expand Down Expand Up @@ -578,7 +578,7 @@ func TestGenerate(t *testing.T) {
t.Run("applies middleware", func(t *testing.T) {
middlewareCalled := false
testMiddleware := func(next ModelFunc) ModelFunc {
return func(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) {
return func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) {
middlewareCalled = true
req.Messages = append(req.Messages, NewUserTextMessage("middleware was here"))
return next(ctx, req, cb)
Expand Down
61 changes: 52 additions & 9 deletions go/ai/model_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,23 @@ package ai

import (
"context"
"encoding/json"
"fmt"
"slices"
)

// ValidateSupport creates middleware that validates whether a model supports the requested features.
func ValidateSupport(model string, supports *ModelInfoSupports) ModelMiddleware {
func ValidateSupport(model string, info *ModelInfo) ModelMiddleware {
return func(next ModelFunc) ModelFunc {
return func(ctx context.Context, input *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) {
if supports == nil {
supports = &ModelInfoSupports{}
return func(ctx context.Context, input *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) {
if info == nil {
info = &ModelInfo{
Supports: &ModelInfoSupports{},
Versions: []string{},
}
}

if !supports.Media {
if !info.Supports.Media {
for _, msg := range input.Messages {
for _, part := range msg.Content {
if part.IsMedia() {
Expand All @@ -26,29 +31,67 @@ func ValidateSupport(model string, supports *ModelInfoSupports) ModelMiddleware
}
}

if !supports.Tools && len(input.Tools) > 0 {
if !info.Supports.Tools && len(input.Tools) > 0 {
return nil, fmt.Errorf("model %q does not support tool use, but tools were provided. Request: %+v", model, input)
}

if !supports.Multiturn && len(input.Messages) > 1 {
if !info.Supports.Multiturn && len(input.Messages) > 1 {
return nil, fmt.Errorf("model %q does not support multiple messages, but %d were provided. Request: %+v", model, len(input.Messages), input)
}

if !supports.ToolChoice && input.ToolChoice != "" && input.ToolChoice != ToolChoiceAuto {
if !info.Supports.ToolChoice && input.ToolChoice != "" && input.ToolChoice != ToolChoiceAuto {
return nil, fmt.Errorf("model %q does not support tool choice, but tool choice was provided. Request: %+v", model, input)
}

if !supports.SystemRole {
if !info.Supports.SystemRole {
for _, msg := range input.Messages {
if msg.Role == RoleSystem {
return nil, fmt.Errorf("model %q does not support system role, but system role was provided. Request: %+v", model, input)
}
}
}

if err := validateVersion(model, info.Versions, input.Config); err != nil {
return nil, err
}

// TODO: Add validation for features that won't have simulated support via middleware.

return next(ctx, input, cb)
}
}
}

// validateVersion validates that the requested model version is supported.
func validateVersion(model string, versions []string, config any) error {
var configMap map[string]any

switch c := config.(type) {
case map[string]any:
configMap = c
default:
data, err := json.Marshal(config)
if err != nil {
return nil
}
if err := json.Unmarshal(data, &configMap); err != nil {
return nil
}
}

versionVal, exists := configMap["version"]
if !exists {
return nil
}

version, ok := versionVal.(string)
if !ok {
return fmt.Errorf("version must be a string, got %T", versionVal)
}

if slices.Contains(versions, version) {
return nil
}

return fmt.Errorf("model %q does not support version %q, supported versions: %v", model, version, versions)
}
Loading

0 comments on commit a842e5e

Please sign in to comment.