diff --git a/cmd/eval-dev-quality/cmd/evaluate.go b/cmd/eval-dev-quality/cmd/evaluate.go index afd810441..e61c215cd 100644 --- a/cmd/eval-dev-quality/cmd/evaluate.go +++ b/cmd/eval-dev-quality/cmd/evaluate.go @@ -321,6 +321,11 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate. } else { for _, model := range command.Models { p := strings.SplitN(model, provider.ProviderModelSeparator, 2)[0] + + if _, ok := providersSelected[p]; ok { + continue + } + if provider, ok := provider.Providers[p]; !ok { command.logger.Panicf("Provider %q does not exist", p) } else { @@ -361,6 +366,18 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate. }() } + // Check if a provider has the ability to pull models and do so if necessary. + if puller, ok := p.(provider.Puller); ok { + command.logger.Printf("Pulling available models for provider %q", p.ID()) + for _, modelID := range command.Models { + if strings.HasPrefix(modelID, p.ID()) { + if err := puller.Pull(command.logger, modelID); err != nil { + command.logger.Panicf("ERROR: could not pull model %q: %s", modelID, err) + } + } + } + } + ms, err := p.Models() if err != nil { command.logger.Panicf("ERROR: could not query models for provider %q: %s", p.ID(), err)