Skip to content

Commit

Permalink
Merge pull request #421 from symflower/allow-model-parameters
Browse files Browse the repository at this point in the history
fix, Allow selecting models with attributes for openRouter as well
  • Loading branch information
zimmski authored Feb 20, 2025
2 parents 99f5feb + 3291eac commit 377f295
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 7 deletions.
21 changes: 15 additions & 6 deletions cmd/eval-dev-quality/cmd/evaluate.go
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,8 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
command.logger.Panicf("ERROR: cannot split %q into provider and model name by %q", modelIDsWithProviderAndAttributes, provider.ProviderModelSeparator)
}

modelID, _ := model.ParseModelID(modelIDsWithAttributes)
modelID, attributes := model.ParseModelID(modelIDsWithAttributes)
modelIDWithProvider := providerID + provider.ProviderModelSeparator + modelID

p, ok := providers[providerID]
if !ok {
Expand All @@ -460,18 +461,18 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
}

// TODO If a model has not been pulled before, it was not available for at least the "Ollama" provider. Make this cleaner, we should not rebuild every time.
if _, ok := models[modelIDsWithProviderAndAttributes]; !ok {
if _, ok := models[modelIDWithProvider]; !ok {
ms, err := p.Models()
if err != nil {
command.logger.Panicf("ERROR: could not query models for provider %q: %s", p.ID(), err)
}
for _, m := range ms {
if _, ok := models[m.ID()]; ok {
if _, ok := models[m.ModelID()]; ok {
continue
}

models[m.ID()] = m
evaluationConfiguration.Models.Available = append(evaluationConfiguration.Models.Available, m.ID())
models[m.ModelID()] = m
evaluationConfiguration.Models.Available = append(evaluationConfiguration.Models.Available, m.ModelID())
}
modelIDs = maps.Keys(models)
sort.Strings(modelIDs)
Expand All @@ -489,10 +490,18 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
pc.AddModel(m)
} else {
var ok bool
m, ok = models[modelIDsWithProviderAndAttributes]
m, ok = models[modelIDWithProvider]
if !ok {
command.logger.Panicf("ERROR: model %q does not exist for provider %q. Valid models are: %s", modelIDsWithProviderAndAttributes, providerID, strings.Join(modelIDs, ", "))
}

// If a model with attributes is requested, we add the base model plus attributes as new model to our list.
if len(attributes) > 0 {
modelWithAttributes := m.Clone()
modelWithAttributes.SetAttributes(attributes)
models[modelWithAttributes.ID()] = modelWithAttributes
m = modelWithAttributes
}
}
evaluationContext.Models = append(evaluationContext.Models, m)
evaluationContext.ProviderForModel[m] = p
Expand Down
38 changes: 38 additions & 0 deletions cmd/eval-dev-quality/cmd/evaluate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1478,6 +1478,44 @@ func TestEvaluateInitialize(t *testing.T) {
}, config.Repositories.Selected)
},
})
validate(t, &testCase{
Name: "Model with attributes",

Command: makeValidCommand(func(command *Evaluate) {
command.ModelIDsWithProviderAndAttributes = []string{
"openrouter/openai/o3-mini@reasoning_effort=low",
"openrouter/openai/o3-mini@reasoning_effort=high",
}
command.ProviderTokens = map[string]string{
"openrouter": "fake-token",
}
}),

ValidateContext: func(t *testing.T, context *evaluate.Context) {
assert.Len(t, context.Models, 2)

assert.Equal(t, "openrouter/openai/o3-mini@reasoning_effort=high", context.Models[0].ID())
assert.Equal(t, "openrouter/openai/o3-mini", context.Models[0].ModelID())
expectedAttributes := map[string]string{
"reasoning_effort": "high",
}
assert.Equal(t, expectedAttributes, context.Models[0].Attributes())

assert.Equal(t, "openrouter/openai/o3-mini@reasoning_effort=low", context.Models[1].ID())
assert.Equal(t, "openrouter/openai/o3-mini", context.Models[1].ModelID())
expectedAttributes = map[string]string{
"reasoning_effort": "low",
}
assert.Equal(t, expectedAttributes, context.Models[1].Attributes())
},
ValidateConfiguration: func(t *testing.T, config *EvaluationConfiguration) {
expectedSelected := []string{
"openrouter/openai/o3-mini@reasoning_effort=high",
"openrouter/openai/o3-mini@reasoning_effort=low",
}
assert.Equal(t, expectedSelected, config.Models.Selected)
},
})
validate(t, &testCase{
Name: "Local runtime does not allow parallel parameter",

Expand Down
19 changes: 18 additions & 1 deletion model/llm/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,12 @@ var _ model.Model = (*Model)(nil)

// ID returns full identifier, including the provider and attributes.
func (m *Model) ID() (id string) {
return m.id
attributeString := ""
for key, value := range m.attributes {
attributeString += "@" + key + "=" + value
}

return m.id + attributeString
}

// ModelID returns the unique identifier of this model with its provider.
Expand All @@ -93,11 +98,23 @@ func (m *Model) Attributes() (attributes map[string]string) {
return m.attributes
}

// SetAttributes sets the given attributes.
func (m *Model) SetAttributes(attributes map[string]string) {
m.attributes = attributes
}

// MetaInformation returns the meta information of a model.
func (m *Model) MetaInformation() (metaInformation *model.MetaInformation) {
return m.metaInformation
}

// Clone returns a copy of the model.
func (m *Model) Clone() (clone model.Model) {
model := *m

return &model
}

// llmSourceFilePromptContext is the base template context for an LLM generation prompt.
type llmSourceFilePromptContext struct {
// Language holds the programming language name.
Expand Down
5 changes: 5 additions & 0 deletions model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,14 @@ type Model interface {

// Attributes returns query attributes.
Attributes() (attributes map[string]string)
// SetAttributes sets the given attributes.
SetAttributes(attributes map[string]string)

// MetaInformation returns the meta information of a model.
MetaInformation() *MetaInformation

// Clone returns a copy of the model.
Clone() (clone Model)
}

// ParseModelID takes a packaged model ID with optional attributes and converts it into its model ID and optional attributes.
Expand Down
11 changes: 11 additions & 0 deletions model/symflower/symflower.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,22 @@ func (m *Model) Attributes() (attributes map[string]string) {
return nil
}

// SetAttributes sets the given attributes.
func (m *Model) SetAttributes(attributes map[string]string) {
}

// MetaInformation returns the meta information of a model.
func (m *Model) MetaInformation() (metaInformation *model.MetaInformation) {
return nil
}

// Clone returns a copy of the model.
func (m *Model) Clone() (clone model.Model) {
model := *m

return &model
}

var _ model.CapabilityWriteTests = (*Model)(nil)

// WriteTests generates test files for the given implementation file in a repository.
Expand Down
25 changes: 25 additions & 0 deletions model/testing/Model_mock_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 377f295

Please sign in to comment.