diff --git a/go/go.mod b/go/go.mod index 677dcc6bb7..bd901c6a5e 100644 --- a/go/go.mod +++ b/go/go.mod @@ -15,6 +15,7 @@ require ( firebase.google.com/go/v4 v4.14.1 github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.46.0 github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.22.0 + github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.10 github.com/aymerick/raymond v2.0.2+incompatible github.com/google/generative-ai-go v0.16.1-0.20240711222609-09946422abc6 github.com/google/go-cmp v0.6.0 @@ -81,6 +82,10 @@ require ( github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/oklog/ulid v1.3.1 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/tidwall/gjson v1.14.4 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect go.mongodb.org/mongo-driver v1.14.0 // indirect diff --git a/go/go.sum b/go/go.sum index cc53089587..1c8703a4b2 100644 --- a/go/go.sum +++ b/go/go.sum @@ -48,6 +48,8 @@ github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 h1:d+Bc7a5rLufV github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= github.com/ankane/disco-go v0.1.0 h1:nkz+y4O+UFKnEGH8FkJ8wcVwX5boZvaRzJN6EMK7NVw= github.com/ankane/disco-go v0.1.0/go.mod h1:nkR7DLW+KkXeRRAsWk6poMTpTOWp9/4iKYGDwg8dSS0= +github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.10 h1:myWicO7qECViRePrrsSijlakZK3q7vzHBCoS2hL+8V0= +github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.10/go.mod h1:GJxtdOs9K4neo8Gg65CjJ7jNautmldGli5/OFNabOoo= github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= @@ -258,7 +260,17 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= +github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= github.com/uptrace/bun v1.1.12 h1:sOjDVHxNTuM6dNGaba0wUuz7KvDE1BmNu9Gqs2gJSXQ= diff --git a/go/plugins/vertexai/modelgarden/anthropic/anthropic.go b/go/plugins/vertexai/modelgarden/anthropic/anthropic.go new file mode 100644 index 0000000000..4dd52ebfd2 --- /dev/null +++ b/go/plugins/vertexai/modelgarden/anthropic/anthropic.go @@ -0,0 +1,368 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 + +package anthropic + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "regexp" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/internal/gemini" + "github.com/firebase/genkit/go/plugins/internal/uri" + "github.com/firebase/genkit/go/plugins/vertexai/modelgarden/client" + "github.com/invopop/jsonschema" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/vertex" +) + +const ( + ProviderName = "anthropic" + MaxNumberOfTokens = 8192 + ToolNameRegex = `^[a-zA-Z0-9_-]{1,64}$` +) + +// supported anthropic models +var AnthropicModels = map[string]ai.ModelInfo{ + "claude-3-5-sonnet-v2": { + Label: "Vertex AI Model Garden - Claude 3.5 Sonnet", + Supports: &gemini.Multimodal, + Versions: []string{"claude-3-5-sonnet-v2@20241022"}, + }, + "claude-3-5-sonnet": { + Label: "Vertex AI Model Garden - Claude 3.5 Sonnet", + Supports: &gemini.Multimodal, + Versions: []string{"claude-3-5-sonnet@20240620"}, + }, + "claude-3-sonnet": { + Label: "Vertex AI Model Garden - Claude 3 Sonnet", + Supports: &gemini.Multimodal, + Versions: []string{"claude-3-sonnet@20240229"}, + }, + "claude-3-haiku": { + Label: "Vertex AI Model Garden - Claude 3 Haiku", + Supports: &gemini.Multimodal, + Versions: []string{"claude-3-haiku@20240307"}, + }, + "claude-3-opus": { + Label: "Vertex AI Model Garden - Claude 3 Opus", + Supports: &gemini.Multimodal, + Versions: []string{"claude-3-opus@20240229"}, + }, + "claude-3-7-sonnet": { + Label: "Vertex AI Model Garden - Claude 3.7 Sonnet", + Supports: &gemini.Multimodal, + Versions: []string{"claude-3-7-sonnet@20250219"}, + }, +} + +// AnthropicClientConfig is the required configuration to create an Anthropic +// client +type AnthropicClientConfig struct { + client.ClientConfig + // expand configuration as required +} + +// AnthropicClient is a mirror struct of Anthropic's client but implements +// [Client] interface +type AnthropicClient struct { + *anthropic.Client +} + +// Anthropic defines how an Anthropic client is created +var Anthropic = func(config any) (client.Client, error) { + cfg, ok := config.(*client.ClientConfig) + if !ok { + return nil, fmt.Errorf("invalid config for Anthropic %T", config) + } + c := anthropic.NewClient( + vertex.WithGoogleAuth(context.Background(), cfg.Location, cfg.Project), + ) + + return &AnthropicClient{c}, nil +} + +// DefineModel adds the model to the registry +func (a *AnthropicClient) DefineModel(g *genkit.Genkit, name string, info *ai.ModelInfo) (ai.Model, error) { + var mi ai.ModelInfo + if info == nil { + var ok bool + mi, ok = AnthropicModels[name] + if !ok { + return nil, fmt.Errorf("%s.DefineModel: called with unknown model %q and nil ModelInfo", ProviderName, name) + } + } else { + mi = *info + } + return defineModel(g, a, name, mi), nil +} + +func defineModel(g *genkit.Genkit, client *AnthropicClient, name string, info ai.ModelInfo) ai.Model { + meta := &ai.ModelInfo{ + Label: ProviderName + "-" + name, + Supports: info.Supports, + Versions: info.Versions, + } + return genkit.DefineModel(g, ProviderName, name, meta, func( + ctx context.Context, + input *ai.ModelRequest, + cb func(context.Context, *ai.ModelResponseChunk) error, + ) (*ai.ModelResponse, error) { + return generate(ctx, client, name, input, cb) + }) +} + +// generate function defines how a generate request is done in Anthropic models +func generate( + ctx context.Context, + client *AnthropicClient, + model string, + input *ai.ModelRequest, + cb func(context.Context, *ai.ModelResponseChunk) error, +) (*ai.ModelResponse, error) { + req, err := toAnthropicRequest(model, input) + if err != nil { + panic(fmt.Sprintf("unable to generate anthropic request: %v", err)) + } + + // no streaming + if cb == nil { + msg, err := client.Messages.New(ctx, req) + if err != nil { + return nil, err + } + + r := toGenkitResponse(msg) + r.Request = input + + return r, nil + } else { + stream := client.Messages.NewStreaming(ctx, req) + message := anthropic.Message{} + for stream.Next() { + event := stream.Current() + err := message.Accumulate(event) + if err != nil { + panic(err) + } + + switch event := event.AsUnion().(type) { + case anthropic.ContentBlockDeltaEvent: + cb(ctx, &ai.ModelResponseChunk{ + Content: []*ai.Part{ + { + Text: event.Delta.Text, + }, + }, + }) + case anthropic.MessageStopEvent: + r := toGenkitResponse(&message) + r.Request = input + return r, nil + } + } + if stream.Err() != nil { + panic(stream.Err()) + } + } + + return nil, nil +} + +func toAnthropicRole(role ai.Role) anthropic.MessageParamRole { + switch role { + case ai.RoleUser: + return anthropic.MessageParamRoleUser + case ai.RoleModel: + return anthropic.MessageParamRoleAssistant + case ai.RoleTool: + return anthropic.MessageParamRoleAssistant + default: + panic(fmt.Sprintf("unsupported role type: %v", role)) + } +} + +// toAnthropicRequest translates [ai.ModelRequest] to an Anthropic request +func toAnthropicRequest(model string, i *ai.ModelRequest) (anthropic.MessageNewParams, error) { + req := anthropic.MessageNewParams{} + messages := make([]anthropic.MessageParam, 0) + + // minimum required data to perform a request + req.Model = anthropic.F(anthropic.Model(model)) + req.MaxTokens = anthropic.F(int64(MaxNumberOfTokens)) + + if c, ok := i.Config.(*ai.GenerationCommonConfig); ok && c != nil { + if c.MaxOutputTokens != 0 { + req.MaxTokens = anthropic.F(int64(c.MaxOutputTokens)) + } + req.Model = anthropic.F(anthropic.Model(model)) + if c.Version != "" { + req.Model = anthropic.F(anthropic.Model(c.Version)) + } + if c.Temperature != 0 { + req.Temperature = anthropic.F(c.Temperature) + } + if c.TopK != 0 { + req.TopK = anthropic.F(int64(c.TopK)) + } + if c.TopP != 0 { + req.TopP = anthropic.F(float64(c.TopP)) + } + if len(c.StopSequences) > 0 { + req.StopSequences = anthropic.F(c.StopSequences) + } + } + + // configure system prompt (if given) + sysBlocks := []anthropic.TextBlockParam{} + for _, message := range i.Messages { + if message.Role == ai.RoleSystem { + // only text is supported for system messages + sysBlocks = append(sysBlocks, anthropic.NewTextBlock(message.Text())) + } else if message.Content[len(message.Content)-1].IsToolResponse() { + // if the last message is a ToolResponse, the conversation must continue + // and the ToolResponse message must be sent as a user + // see: https://docs.anthropic.com/en/docs/build-with-claude/tool-use#handling-tool-use-and-tool-result-content-blocks + parts, err := convertParts(message.Content) + if err != nil { + return req, err + } + messages = append(messages, anthropic.NewUserMessage(parts...)) + } else { + // handle the rest of the messages + parts, err := convertParts(message.Content) + if err != nil { + return req, err + } + messages = append(messages, anthropic.MessageParam{ + Role: anthropic.F(toAnthropicRole(message.Role)), + Content: anthropic.F(parts), + }) + } + } + + req.System = anthropic.F(sysBlocks) + req.Messages = anthropic.F(messages) + + // check tools + tools, err := convertTools(i.Tools) + if err != nil { + return req, err + } + req.Tools = anthropic.F(tools) + + return req, nil +} + +// convertTools translates [ai.ToolDefinition] to an anthropic.ToolParam type +func convertTools(tools []*ai.ToolDefinition) ([]anthropic.ToolParam, error) { + resp := make([]anthropic.ToolParam, 0) + regex := regexp.MustCompile(ToolNameRegex) + + for _, t := range tools { + if t.Name == "" { + return nil, fmt.Errorf("tool name is required") + } + if !regex.MatchString(t.Name) { + return nil, fmt.Errorf("tool name must match regex: %s", ToolNameRegex) + } + + resp = append(resp, anthropic.ToolParam{ + Name: anthropic.F(t.Name), + Description: anthropic.F(t.Description), + InputSchema: anthropic.F(generateSchema[map[string]any]()), + }) + } + + return resp, nil +} + +func generateSchema[T any]() interface{} { + reflector := jsonschema.Reflector{ + AllowAdditionalProperties: false, + DoNotReference: true, + } + var v T + return reflector.Reflect(v) +} + +// convertParts translates [ai.Part] to an anthropic.ContentBlockParamUnion type +func convertParts(parts []*ai.Part) ([]anthropic.ContentBlockParamUnion, error) { + blocks := []anthropic.ContentBlockParamUnion{} + + for _, p := range parts { + switch { + case p.IsText(): + blocks = append(blocks, anthropic.NewTextBlock(p.Text)) + case p.IsMedia(): + contentType, data, _ := uri.Data(p) + blocks = append(blocks, anthropic.NewImageBlockBase64(contentType, base64.StdEncoding.EncodeToString(data))) + case p.IsData(): + // todo: what is this? is this related to ContentBlocks? + panic("data content is unsupported by anthropic models") + case p.IsToolRequest(): + toolReq := p.ToolRequest + blocks = append(blocks, anthropic.NewToolUseBlockParam(toolReq.Ref, toolReq.Name, toolReq.Input)) + case p.IsToolResponse(): + toolResp := p.ToolResponse + output, err := json.Marshal(toolResp.Output) + if err != nil { + panic(fmt.Sprintf("unable to parse tool response: %v", err)) + } + blocks = append(blocks, anthropic.NewToolResultBlock(toolResp.Ref, string(output), false)) + default: + panic("unknown part type in the request") + } + } + + return blocks, nil +} + +// toGenkitResponse translates an Anthropic Message to [ai.ModelResponse] +func toGenkitResponse(m *anthropic.Message) *ai.ModelResponse { + r := &ai.ModelResponse{} + + switch m.StopReason { + case anthropic.MessageStopReasonMaxTokens: + r.FinishReason = ai.FinishReasonLength + case anthropic.MessageStopReasonStopSequence: + r.FinishReason = ai.FinishReasonStop + case anthropic.MessageStopReasonEndTurn: + r.FinishReason = ai.FinishReasonStop + case anthropic.MessageStopReasonToolUse: + r.FinishReason = ai.FinishReasonStop + default: + r.FinishReason = ai.FinishReasonUnknown + } + + msg := &ai.Message{} + msg.Role = ai.RoleModel + for _, part := range m.Content { + var p *ai.Part + switch part.Type { + case anthropic.ContentBlockTypeText: + p = ai.NewTextPart(string(part.Text)) + case anthropic.ContentBlockTypeToolUse: + p = ai.NewToolRequestPart(&ai.ToolRequest{ + Ref: part.ID, + Input: part.Input, + Name: part.Name, + }) + default: + panic(fmt.Sprintf("unknown part: %#v", part)) + } + msg.Content = append(msg.Content, p) + } + + r.Message = msg + r.Usage = &ai.GenerationUsage{ + InputTokens: int(m.Usage.InputTokens), + OutputTokens: int(m.Usage.OutputTokens), + } + return r +} diff --git a/go/plugins/vertexai/modelgarden/client/client.go b/go/plugins/vertexai/modelgarden/client/client.go new file mode 100644 index 0000000000..573ca6ddde --- /dev/null +++ b/go/plugins/vertexai/modelgarden/client/client.go @@ -0,0 +1,80 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 + +package client + +import ( + "errors" + "fmt" + "sync" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" +) + +// Generic Client interface for supported provider clients +type Client interface { + DefineModel(g *genkit.Genkit, name string, info *ai.ModelInfo) (ai.Model, error) +} + +type ClientFactory struct { + creators map[string]ClientCreator // cache for client creation functions + clients map[string]Client // cache for provider clients + mu sync.Mutex +} + +func NewClientFactory() *ClientFactory { + return &ClientFactory{ + creators: make(map[string]ClientCreator), + clients: make(map[string]Client), + } +} + +// Basic client configuration +type ClientConfig struct { + Provider string + Project string + Location string +} + +// ClientCreator is a function type that will be defined on every provider in order to create its +// client +type ClientCreator func(config any) (Client, error) + +// Register adds the client creator function to a cache for later use +func (f *ClientFactory) Register(provider string, creator ClientCreator) { + if _, ok := f.creators[provider]; !ok { + f.creators[provider] = creator + } +} + +// CreateClient creates a client with the given configuration +// A [ClientCreator] must have been previously registered +func (f *ClientFactory) CreateClient(config *ClientConfig) (Client, error) { + if config == nil { + return nil, errors.New("empty client config") + } + + f.mu.Lock() + defer f.mu.Unlock() + + // every client will be identified by its provider-region combination + key := fmt.Sprintf("%s-%s", config.Provider, config.Location) + if client, ok := f.clients[key]; ok { + return client, nil // return from cache + } + + creator, ok := f.creators[config.Provider] + if !ok { + return nil, fmt.Errorf("unknown client type: %s", key) + } + + client, err := creator(config) + if err != nil { + return nil, err + } + + f.clients[key] = client + + return client, nil +} diff --git a/go/plugins/vertexai/modelgarden/modelgarden.go b/go/plugins/vertexai/modelgarden/modelgarden.go new file mode 100644 index 0000000000..f6041148f2 --- /dev/null +++ b/go/plugins/vertexai/modelgarden/modelgarden.go @@ -0,0 +1,89 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 + +package modelgarden + +import ( + "context" + "fmt" + "os" + "sync" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/vertexai/modelgarden/anthropic" + "github.com/firebase/genkit/go/plugins/vertexai/modelgarden/client" +) + +// Config for Model Garden +type Config struct { + ProjectID string + Location string + Models []string +} + +var state struct { + initted bool + clients *client.ClientFactory // cache for all clients for available providers + projectID string + location string + mu sync.Mutex +} + +// Init initializes the ModelGarden plugin +// After calling Init, you may call [DefineModel] to create and register +// any additional generative models +func Init(ctx context.Context, g *genkit.Genkit, cfg *Config) error { + if cfg == nil { + cfg = &Config{} + } + + state.mu.Lock() + defer state.mu.Unlock() + if state.initted { + panic("modelgarden.Init already called") + } + + state.projectID = cfg.ProjectID + if state.projectID == "" { + state.projectID = os.Getenv("GCLOUD_PROJECT") + } + if state.projectID == "" { + state.projectID = os.Getenv("GOOGLE_CLOUD_PROJECT") + } + if state.projectID == "" { + return fmt.Errorf("modelgarden.Init: Model Garden requires setting GCLOUD_PROJECT or GOOGLE_CLOUD_PROJECT in the environment") + } + + state.location = cfg.Location + if state.location == "" { + state.location = "us-central1" + } + + state.clients = client.NewClientFactory() + state.initted = true + for _, m := range cfg.Models { + // ANTHROPIC + if info, ok := anthropic.AnthropicModels[m]; ok { + state.clients.Register(anthropic.ProviderName, anthropic.Anthropic) + + anthropicClient, err := state.clients.CreateClient(&client.ClientConfig{ + Provider: anthropic.ProviderName, + Project: state.projectID, + Location: state.location, + }) + if err != nil { + return fmt.Errorf("unable to create client: %v", err) + } + + anthropicClient.DefineModel(g, m, &info) + continue + } + } + + return nil +} + +func Model(g *genkit.Genkit, provider string, name string) ai.Model { + return genkit.LookupModel(g, provider, name) +} diff --git a/go/plugins/vertexai/modelgarden/modelgarden_test.go b/go/plugins/vertexai/modelgarden/modelgarden_test.go new file mode 100644 index 0000000000..e3e70a2052 --- /dev/null +++ b/go/plugins/vertexai/modelgarden/modelgarden_test.go @@ -0,0 +1,227 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 + +package modelgarden_test + +import ( + "context" + "encoding/base64" + "flag" + "io" + "net/http" + "strings" + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/vertexai/modelgarden" + "github.com/firebase/genkit/go/plugins/vertexai/modelgarden/anthropic" +) + +var ( + projectID = flag.String("projectid", "", "Modelgarden project") + location = flag.String("location", "us-east5", "Geographic location") +) + +// go test . -v -projectid="my_projectId" +func TestModelGarden(t *testing.T) { + if *projectID == "" { + t.Skipf("no -projectid provided") + } + + ctx := context.Background() + g, err := genkit.New(nil) + if err != nil { + t.Fatal(err) + } + + err = modelgarden.Init(ctx, g, &modelgarden.Config{ + ProjectID: *projectID, + Location: *location, + Models: []string{"claude-3-7-sonnet"}, + }) + if err != nil { + t.Fatal(err) + } + + t.Run("invalid model", func(t *testing.T) { + m := modelgarden.Model(g, anthropic.ProviderName, "claude-not-valid-v2") + if m != nil { + t.Fatal("model should have been invalid") + } + }) + + t.Run("model version ok", func(t *testing.T) { + m := modelgarden.Model(g, anthropic.ProviderName, "claude-3-7-sonnet") + resp, err := genkit.Generate(ctx, g, + ai.WithConfig(&ai.GenerationCommonConfig{ + Temperature: 1, + Version: "claude-3-7-sonnet@20250219", + }), + ai.WithModel(m), + ai.WithSystemPrompt("talk to me like an evil pirate and say ARR several times but be very short"), + ai.WithMessages(ai.NewUserMessage(ai.NewTextPart("I'm a fish"))), + ) + if err != nil { + t.Fatal(err) + } + + if !strings.Contains(resp.Text(), "ARR") { + t.Fatalf("not a pirate :( :%s", resp.Text()) + } + }) + + t.Run("model version nok", func(t *testing.T) { + m := modelgarden.Model(g, anthropic.ProviderName, "claude-3-5-sonnet-v2") + _, err := genkit.Generate(ctx, g, + ai.WithConfig(&ai.GenerationCommonConfig{ + Temperature: 1, + Version: "foo", + }), + ai.WithModel(m), + ) + if err == nil { + t.Fatal("should have failed due wrong model version") + } + }) + + t.Run("media content", func(t *testing.T) { + i, err := fetchImgAsBase64() + if err != nil { + t.Fatal(err) + } + m := modelgarden.Model(g, anthropic.ProviderName, "claude-3-7-sonnet") + resp, err := genkit.Generate(ctx, g, + ai.WithSystemPrompt("You are a professional image detective that talks like an evil pirate that does not like tv shows, your task is to tell the name of the character in the image but be very short"), + ai.WithModel(m), + ai.WithMessages( + ai.NewUserMessage( + ai.NewTextPart("do you know who's in the image?"), + ai.NewMediaPart("", "data:image/png;base64,"+i)))) + if err != nil { + t.Fatal(err) + } + + if !strings.Contains(resp.Text(), "Bluey") { + t.Fatalf("it should've said Bluey but got: %s", resp.Text()) + } + }) + + t.Run("tools", func(t *testing.T) { + m := modelgarden.Model(g, anthropic.ProviderName, "claude-3-7-sonnet") + myJokeTool := genkit.DefineTool( + g, + "myJoke", + "When the user asks for a joke, this tool must be used to generate a joke, try to come up with a joke that uses the output of the tool", + func(ctx *ai.ToolContext, input *any) (string, error) { + return "why did the chicken cross the road?", nil + }, + ) + resp, err := genkit.Generate(ctx, g, + ai.WithModel(m), + ai.WithTextPrompt("tell me a joke"), + ai.WithTools(myJokeTool)) + if err != nil { + t.Fatal(err) + } + + if len(resp.Text()) == 0 { + t.Fatal("expected a response but nothing was returned") + } + }) + + t.Run("streaming", func(t *testing.T) { + m := modelgarden.Model(g, anthropic.ProviderName, "claude-3-7-sonnet") + out := "" + parts := 0 + + final, err := genkit.Generate(ctx, g, + ai.WithTextPrompt("Tell me a short story about a frog and a princess"), + ai.WithModel(m), + ai.WithStreaming(func(ctx context.Context, c *ai.ModelResponseChunk) error { + parts++ + out += c.Content[0].Text + return nil + }), + ) + if err != nil { + t.Fatal(err) + } + + out2 := "" + for _, p := range final.Message.Content { + out2 += p.Text + } + + if out != out2 { + t.Fatalf("streaming and final should contain the same text.\nstreaming: %s\nfinal:%s\n", out, out2) + } + if final.Usage.InputTokens == 0 || final.Usage.OutputTokens == 0 { + t.Fatalf("empty usage stats: %#v", *final.Usage) + } + }) + + t.Run("tools streaming", func(t *testing.T) { + m := modelgarden.Model(g, anthropic.ProviderName, "claude-3-7-sonnet") + out := "" + parts := 0 + + myStoryTool := genkit.DefineTool( + g, + "myStory", + "When the user asks for a story, create a story about a frog and a fox that are good friends", + func(ctx *ai.ToolContext, input *any) (string, error) { + return "the fox is named Goph and the frog is called Fred", nil + }, + ) + + final, err := genkit.Generate(ctx, g, + ai.WithTextPrompt("Tell me a short story about a frog and a princess"), + ai.WithModel(m), + ai.WithTools(myStoryTool), + ai.WithStreaming(func(ctx context.Context, c *ai.ModelResponseChunk) error { + parts++ + out += c.Content[0].Text + return nil + }), + ) + if err != nil { + t.Fatal(err) + } + + out2 := "" + for _, p := range final.Message.Content { + out2 += p.Text + } + + if out != out2 { + t.Fatalf("streaming and final should contain the same text.\nstreaming: %s\nfinal:%s\n", out, out2) + } + if final.Usage.InputTokens == 0 || final.Usage.OutputTokens == 0 { + t.Fatalf("empty usage stats: %#v", *final.Usage) + } + }) +} + +// Bluey rocks +func fetchImgAsBase64() (string, error) { + imgUrl := "https://www.bluey.tv/wp-content/uploads/2023/07/Bluey.png" + resp, err := http.Get(imgUrl) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", err + } + + // keep the img in memory + imageBytes, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + base64string := base64.StdEncoding.EncodeToString(imageBytes) + return base64string, nil +} diff --git a/go/samples/basic-gemini/main.go b/go/samples/basic-gemini/main.go index daf2a67083..43d50a9d93 100644 --- a/go/samples/basic-gemini/main.go +++ b/go/samples/basic-gemini/main.go @@ -1,16 +1,5 @@ // Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// SPDX-License-Identifier: Apache-2.0 package main diff --git a/go/samples/model-garden/main.go b/go/samples/model-garden/main.go new file mode 100644 index 0000000000..62645b54dc --- /dev/null +++ b/go/samples/model-garden/main.go @@ -0,0 +1,59 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "context" + "errors" + "fmt" + "log" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/vertexai/modelgarden" + "github.com/firebase/genkit/go/plugins/vertexai/modelgarden/anthropic" +) + +func main() { + ctx := context.Background() + + g, err := genkit.New(nil) + if err != nil { + log.Fatal(err) + } + + cfg := &modelgarden.Config{ + Location: "us-east5", // or us-central1 + Models: []string{"claude-3-5-sonnet-v2", "claude-3-5-sonnet"}, + } + if err := modelgarden.Init(ctx, g, cfg); err != nil { + log.Fatal(err) + } + + // Define a simple flow that generates jokes about a given topic + genkit.DefineFlow(g, "jokesFlow", func(ctx context.Context, input string) (string, error) { + m := modelgarden.Model(g, anthropic.ProviderName, "claude-3-5-sonnet-v2") + if m == nil { + return "", errors.New("jokesFlow: failed to find model") + } + + resp, err := genkit.Generate(ctx, g, + ai.WithModel(m), + ai.WithConfig(&ai.GenerationCommonConfig{ + Temperature: 0.1, + Version: "claude-3-5-sonnet-v2@20241022", + }), + ai.WithTextPrompt(fmt.Sprintf(`Tell silly short jokes about %s`, input))) + if err != nil { + return "", err + } + + text := resp.Text() + return text, nil + }) + + if err := g.Start(ctx, nil); err != nil { + log.Fatal(err) + } +}