chore(gateway): pass models.ChatCompletion struct from handlers
Signed-off-by: Praveen Yadav <pyadav9678@gmail.com>
pyadav committed Feb 28, 2024
1 parent 9ee6b61 commit 310e600
Showing 18 changed files with 935 additions and 479 deletions.
21 changes: 18 additions & 3 deletions gateway/internal/api/v1/chatcompletions.go
@@ -33,12 +33,12 @@ func (s *V1Handler) ChatCompletions(
 		return nil, ErrRequiredHeaderNotExit
 	}
 
-	startTime := time.Now()
-	payload, err := json.Marshal(req.Msg)
+	chatRequest, err := s.createChatRequestSchema(req.Msg)
 	if err != nil {
 		return nil, errors.New(err)
 	}
 
+	startTime := time.Now()
 	rsvc := router.NewRouterService(routerConfig)
 
 	data := &llmv1.ChatCompletionResponse{}
@@ -69,7 +69,7 @@ func (s *V1Handler) ChatCompletions(
 		return nil, ErrChatCompletionNotSupported
 	}
 
-	resp, err := chatCompletionProvider.ChatCompletion(ctx, payload)
+	resp, err := chatCompletionProvider.ChatCompletion(ctx, chatRequest)
 	if err != nil {
 		return nil, errors.New(err)
 	}
@@ -97,3 +97,18 @@ func (s *V1Handler) ChatCompletions(
 
 	return connect.NewResponse(data), nil
 }
+
+func (s *V1Handler) createChatRequestSchema(req *llmv1.ChatCompletionRequest) (*models.ChatCompletion, error) {
+	payload, err := json.Marshal(req)
+	if err != nil {
+		return nil, err
+	}
+
+	data := &models.ChatCompletion{}
+	err = json.Unmarshal(payload, data)
+	if err != nil {
+		return nil, err
+	}
+
+	return data, nil
+}
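
The new createChatRequestSchema helper converts the incoming proto request into the shared models.ChatCompletion struct by round-tripping through JSON: matching json tags on both types carry the fields across, with no hand-written field mapping. A minimal, self-contained sketch of that technique, using simplified stand-in types rather than the real llmv1.ChatCompletionRequest and models.ChatCompletion:

package main

import (
	"encoding/json"
	"fmt"
)

// Stand-in for the wire-level request type (hypothetical, for illustration).
type wireRequest struct {
	Model       string  `json:"model"`
	Temperature float32 `json:"temperature,omitempty"`
}

// Stand-in for the gateway's internal models.ChatCompletion.
type chatCompletion struct {
	Model       string  `json:"model"`
	Temperature float32 `json:"temperature,omitempty"`
}

// convert mirrors createChatRequestSchema: marshal the source type, then
// unmarshal the bytes into the target type.
func convert(in *wireRequest) (*chatCompletion, error) {
	payload, err := json.Marshal(in)
	if err != nil {
		return nil, err
	}
	out := &chatCompletion{}
	if err := json.Unmarshal(payload, out); err != nil {
		return nil, err
	}
	return out, nil
}

func main() {
	out, err := convert(&wireRequest{Model: "gpt-3.5-turbo", Temperature: 0.7})
	if err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", *out) // {Model:gpt-3.5-turbo Temperature:0.7}
}

The round trip costs an extra marshal/unmarshal per request, but it keeps the proto and internal structs decoupled without field-by-field copying.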
3 changes: 2 additions & 1 deletion gateway/internal/api/v1/models.go
@@ -4,14 +4,15 @@ import (
 	"context"
 
 	"connectrpc.com/connect"
+	"github.com/missingstudio/studio/backend/internal/providers/base"
 	"github.com/missingstudio/studio/backend/models"
 	llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
 )
 
 func (s *V1Handler) ListModels(ctx context.Context, req *connect.Request[llmv1.ModelRequest]) (*connect.Response[llmv1.ModelResponse], error) {
 	allProviderModels := map[string]*llmv1.ProviderModels{}
 
-	for name := range models.ProviderRegistry {
+	for name := range base.ProviderRegistry {
 		provider, err := s.providerService.GetProvider(models.Connection{Name: name})
 		if err != nil {
 			continue
11 changes: 9 additions & 2 deletions gateway/internal/providers/anyscale/anyscale.go
@@ -3,17 +3,24 @@ package anyscale
 import (
 	"bytes"
 	"context"
+	"encoding/json"
 	"fmt"
 	"net/http"
 
+	"github.com/missingstudio/studio/backend/models"
 	"github.com/missingstudio/studio/backend/pkg/requester"
 )
 
-func (anyscale *anyscaleProvider) ChatCompletion(ctx context.Context, payload []byte) (*http.Response, error) {
+func (anyscale *anyscaleProvider) ChatCompletion(ctx context.Context, payload *models.ChatCompletion) (*http.Response, error) {
 	client := requester.NewHTTPClient()
 
+	rawPayload, err := json.Marshal(payload)
+	if err != nil {
+		return nil, fmt.Errorf("unable to marshal openai chat request payload: %w", err)
+	}
+
 	requestURL := fmt.Sprintf("%s%s", anyscale.config.BaseURL, anyscale.config.ChatCompletions)
-	req, err := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(payload))
+	req, err := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(rawPayload))
 	if err != nil {
 		return nil, err
 	}
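
Every provider's ChatCompletion now follows the same shape: marshal the typed payload once at the edge, then POST the raw bytes with the caller's context. A runnable sketch of that pattern, with a placeholder payload type and the standard library's http.Client standing in for the gateway's requester package:

package main

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"time"
)

// Placeholder for *models.ChatCompletion.
type chatCompletion struct {
	Model string `json:"model"`
}

// postJSON marshals the typed payload and POSTs it, propagating ctx so the
// outbound request is cancelled when the caller's deadline fires.
func postJSON(ctx context.Context, url string, payload *chatCompletion) (*http.Response, error) {
	rawPayload, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("unable to marshal chat request payload: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(rawPayload))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	client := &http.Client{Timeout: 30 * time.Second}
	return client.Do(req)
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	// httpbin.org/post echoes the request body back, handy for a smoke test.
	resp, err := postJSON(ctx, "https://httpbin.org/post", &chatCompletion{Model: "test"})
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.Status)
}

The same change is repeated almost verbatim in the deepinfra, openai, and togetherai providers below.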
4 changes: 2 additions & 2 deletions gateway/internal/providers/anyscale/base.go
@@ -10,7 +10,7 @@ import (
 //go:embed schema.json
 var schema []byte
 
-var _ base.IProvider = &anyscaleProvider{}
+var _ base.ChatCompletionInterface = &anyscaleProvider{}
 
 type anyscaleProvider struct {
 	info base.ProviderInfo
@@ -47,7 +47,7 @@ func getAnyscaleConfig(baseURL string) base.ProviderConfig {
 }
 
 func init() {
-	models.ProviderRegistry["anyscale"] = func(connection models.Connection) base.IProvider {
+	base.ProviderRegistry["anyscale"] = func(connection models.Connection) base.IProvider {
 		config := getAnyscaleConfig("https://api.endpoints.anyscale.com")
 		return &anyscaleProvider{
 			info: getAnyscaleInfo(),
8 changes: 7 additions & 1 deletion gateway/internal/providers/base/base.go
@@ -3,6 +3,8 @@ package base
 import (
 	"context"
 	"net/http"
+
+	"github.com/missingstudio/studio/backend/models"
 )
 
 type ProviderConfig struct {
@@ -22,7 +24,11 @@ type IProvider interface {
 	Schema() []byte
 }
 
+// ProviderRegistry holds all supported providers for which connections
+// can be initialized
+var ProviderRegistry = map[string]func(models.Connection) IProvider{}
+
type ChatCompletionInterface interface {
 	IProvider
-	ChatCompletion(context.Context, []byte) (*http.Response, error)
+	ChatCompletion(context.Context, *models.ChatCompletion) (*http.Response, error)
 }
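
Two ideas work together in this file. ProviderRegistry moves from models into base, likely because base now imports models (for models.ChatCompletion and models.Connection), so a registry in models referencing base.IProvider would create an import cycle. And each provider's blank-identifier assertion (var _ base.ChatCompletionInterface = ...) forces a compile error if the provider stops satisfying the widened interface. A condensed sketch of both, with simplified stand-in types:

package main

import "fmt"

// Stand-ins for models.Connection and base.IProvider.
type connection struct{ Name string }

type provider interface{ Schema() string }

// Mirrors base.ProviderRegistry: provider name -> constructor.
var registry = map[string]func(connection) provider{}

type echoProvider struct{ conn connection }

func (e *echoProvider) Schema() string { return "schema for " + e.conn.Name }

// Compile-time check, the same trick as
// `var _ base.ChatCompletionInterface = &anyscaleProvider{}`.
var _ provider = &echoProvider{}

func init() {
	// Self-registration, as each provider package does in its init().
	registry["echo"] = func(c connection) provider { return &echoProvider{conn: c} }
}

func main() {
	ctor, ok := registry["echo"]
	if !ok {
		fmt.Println("unsupported connection")
		return
	}
	fmt.Println(ctor(connection{Name: "demo"}).Schema())
}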
4 changes: 2 additions & 2 deletions gateway/internal/providers/deepinfra/base.go
@@ -10,7 +10,7 @@ import (
 //go:embed schema.json
 var schema []byte
 
-var _ base.IProvider = &deepinfraProvider{}
+var _ base.ChatCompletionInterface = &deepinfraProvider{}
 
 type deepinfraProvider struct {
 	info base.ProviderInfo
@@ -47,7 +47,7 @@ func getDeepinfraConfig(baseURL string) base.ProviderConfig {
 }
 
 func init() {
-	models.ProviderRegistry["deepinfra"] = func(connection models.Connection) base.IProvider {
+	base.ProviderRegistry["deepinfra"] = func(connection models.Connection) base.IProvider {
 		config := getDeepinfraConfig("https://api.deepinfra.com/v1/openai")
 		return &deepinfraProvider{
 			info: getDeepinfraInfo(),
11 changes: 9 additions & 2 deletions gateway/internal/providers/deepinfra/deepinfra.go
@@ -3,17 +3,24 @@ package deepinfra
 import (
 	"bytes"
 	"context"
+	"encoding/json"
 	"fmt"
 	"net/http"
 
+	"github.com/missingstudio/studio/backend/models"
 	"github.com/missingstudio/studio/backend/pkg/requester"
 )
 
-func (deepinfra *deepinfraProvider) ChatCompletion(ctx context.Context, payload []byte) (*http.Response, error) {
+func (deepinfra *deepinfraProvider) ChatCompletion(ctx context.Context, payload *models.ChatCompletion) (*http.Response, error) {
 	client := requester.NewHTTPClient()
 
+	rawPayload, err := json.Marshal(payload)
+	if err != nil {
+		return nil, fmt.Errorf("unable to marshal openai chat request payload: %w", err)
+	}
+
 	requestURL := fmt.Sprintf("%s%s", deepinfra.config.BaseURL, deepinfra.config.ChatCompletions)
-	req, err := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(payload))
+	req, err := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(rawPayload))
 	if err != nil {
 		return nil, err
 	}
4 changes: 2 additions & 2 deletions gateway/internal/providers/openai/base.go
@@ -10,7 +10,7 @@ import (
 //go:embed schema.json
 var schema []byte
 
-var _ base.IProvider = &openAIProvider{}
+var _ base.ChatCompletionInterface = &openAIProvider{}
 
 type openAIProvider struct {
 	info base.ProviderInfo
@@ -46,7 +46,7 @@ func getOpenAIConfig(baseURL string) base.ProviderConfig {
 }
 
 func init() {
-	models.ProviderRegistry["openai"] = func(connection models.Connection) base.IProvider {
+	base.ProviderRegistry["openai"] = func(connection models.Connection) base.IProvider {
 		config := getOpenAIConfig("https://api.openai.com")
 		return &openAIProvider{
 			info: getOpenAIInfo(),
10 changes: 8 additions & 2 deletions gateway/internal/providers/openai/openai.go
@@ -3,6 +3,7 @@ package openai
 import (
 	"bytes"
 	"context"
+	"encoding/json"
 	"fmt"
 	"net/http"
 
@@ -26,11 +27,16 @@ var OpenAIModels = []string{
 	"gpt-3.5-turbo-instruct",
 }
 
-func (oai *openAIProvider) ChatCompletion(ctx context.Context, payload []byte) (*http.Response, error) {
+func (oai *openAIProvider) ChatCompletion(ctx context.Context, payload *models.ChatCompletion) (*http.Response, error) {
 	client := requester.NewHTTPClient()
 
+	rawPayload, err := json.Marshal(payload)
+	if err != nil {
+		return nil, fmt.Errorf("unable to marshal openai chat request payload: %w", err)
+	}
+
 	requestURL := fmt.Sprintf("%s%s", oai.config.BaseURL, oai.config.ChatCompletions)
-	req, err := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(payload))
+	req, err := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(rawPayload))
 	if err != nil {
 		return nil, err
 	}
4 changes: 2 additions & 2 deletions gateway/internal/providers/service.go
@@ -22,14 +22,14 @@ func NewService() *Service {
 func (s Service) GetProviders() map[string]base.IProvider {
 	providers := map[string]base.IProvider{}
 
-	for name, p := range models.ProviderRegistry {
+	for name, p := range base.ProviderRegistry {
 		providers[name] = p(models.Connection{})
 	}
 	return providers
 }
 
 func (s Service) GetProvider(conn models.Connection) (base.IProvider, error) {
-	if val, ok := models.ProviderRegistry[conn.Name]; ok {
+	if val, ok := base.ProviderRegistry[conn.Name]; ok {
 		return val(conn), nil
 	}
 	return nil, errors.New("unsupported connection")
4 changes: 2 additions & 2 deletions gateway/internal/providers/togetherai/base.go
@@ -10,7 +10,7 @@ import (
 //go:embed schema.json
 var schema []byte
 
-var _ base.IProvider = &togetherAIProvider{}
+var _ base.ChatCompletionInterface = &togetherAIProvider{}
 
 type togetherAIProvider struct {
 	info base.ProviderInfo
@@ -46,7 +46,7 @@ func getTogetherAIConfig(baseURL string) base.ProviderConfig {
 }
 
 func init() {
-	models.ProviderRegistry["togetherai"] = func(connection models.Connection) base.IProvider {
+	base.ProviderRegistry["togetherai"] = func(connection models.Connection) base.IProvider {
 		config := getTogetherAIConfig("https://api.together.xyz")
 		return &togetherAIProvider{
 			info: getTogetherAIInfo(),
11 changes: 9 additions & 2 deletions gateway/internal/providers/togetherai/togetherai.go
@@ -3,17 +3,24 @@ package togetherai
 import (
 	"bytes"
 	"context"
+	"encoding/json"
 	"fmt"
 	"net/http"
 
+	"github.com/missingstudio/studio/backend/models"
 	"github.com/missingstudio/studio/backend/pkg/requester"
 )
 
-func (ta *togetherAIProvider) ChatCompletion(ctx context.Context, payload []byte) (*http.Response, error) {
+func (ta *togetherAIProvider) ChatCompletion(ctx context.Context, payload *models.ChatCompletion) (*http.Response, error) {
 	client := requester.NewHTTPClient()
 
+	rawPayload, err := json.Marshal(payload)
+	if err != nil {
+		return nil, fmt.Errorf("unable to marshal openai chat request payload: %w", err)
+	}
+
 	requestURL := fmt.Sprintf("%s%s", ta.config.BaseURL, ta.config.ChatCompletions)
-	req, err := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(payload))
+	req, err := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(rawPayload))
 	if err != nil {
 		return nil, err
 	}
76 changes: 76 additions & 0 deletions gateway/models/chat.go
@@ -0,0 +1,76 @@
+package models
+
+type ResponseFormat struct {
+	Type string `json:"type,omitempty"`
+}
+
+type ChatCompletion struct {
+	Model            string                  `json:"model"`
+	Messages         []ChatCompletionMessage `json:"messages"`
+	Temperature      float32                 `json:"temperature,omitempty" default:"1"`
+	Suffix           string                  `json:"suffix,omitempty"`
+	Seed             uint64                  `json:"seed,omitempty"`
+	N                uint64                  `json:"n,omitempty"`
+	Echo             bool                    `json:"echo,omitempty"`
+	BestOf           uint64                  `json:"best_of,omitempty"`
+	PresencePenalty  float32                 `json:"presence_penalty,omitempty"`
+	FrequencyPenalty float32                 `json:"frequency_penalty,omitempty"`
+	Stream           bool                    `json:"stream,omitempty"`
+	TopK             float32                 `json:"top_k,omitempty"`
+	TopP             float32                 `json:"top_p,omitempty"`
+	Stop             []string                `json:"stop,omitempty"`
+	MaxTokens        uint64                  `json:"max_tokens,omitempty"`
+	LogProbs         bool                    `json:"logprobs,omitempty"`
+	TopLogprobs      uint64                  `json:"top_logprobs,omitempty"`
+	LogitBias        map[string]any          `json:"logit_bias,omitempty"`
+	ToolChoice       map[string]any          `json:"tool_choice,omitempty"`
+	User             string                  `json:"user,omitempty"`
+}
+
+type Usage struct {
+	PromptToken      int64 `json:"prompt_tokens,omitempty"`
+	CompletionTokens int64 `json:"completion_tokens,omitempty"`
+	TotalTokens      int64 `json:"total_tokens,omitempty"`
+}
+
+type LogprobResult struct {
+	Tokens        []string             `json:"tokens"`
+	TokenLogprobs []float32            `json:"token_logprobs"`
+	TopLogprobs   []map[string]float32 `json:"top_logprobs"`
+	TextOffset    []int                `json:"text_offset"`
+}
+
+type Function struct {
+	Name      string `json:"name"`
+	Arguments string `json:"arguments"`
+}
+
+type ToolCallMessage struct {
+	Id       string   `json:"id"`
+	Type     string   `json:"type"`
+	Function Function `json:"function"`
+}
+
+type ChatCompletionMessage struct {
+	Role      string            `json:"role"`
+	Content   string            `json:"content"`
+	LogProbs  map[string]any    `json:"logprobs,omitempty"`
+	ToolCalls []ToolCallMessage `json:"tool_calls,omitempty"`
+}
+
+type CompletionChoice struct {
+	Index        int                   `json:"index"`
+	Message      ChatCompletionMessage `json:"message"`
+	FinishReason string                `json:"finish_reason"`
+	LogProbs     LogprobResult         `json:"logprobs"`
+}
+
+type CompletionResponse struct {
+	ID                string             `json:"id"`
+	Object            string             `json:"object"`
+	Created           int64              `json:"created"`
+	Model             string             `json:"model"`
+	Choices           []CompletionChoice `json:"choices"`
+	Usage             Usage              `json:"usage"`
+	SystemFingerprint string             `json:"system_fingerprint"`
+}
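
models.ChatCompletion is an OpenAI-style chat request in which nearly every knob is tagged omitempty, so zero-valued fields never reach the provider payload. A quick demo with a trimmed local copy of the struct (note that the default:"1" tag on Temperature is plain metadata; encoding/json does not apply defaults):

package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed local copy of models.ChatCompletion for a standalone demo.
type chatCompletionMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type chatCompletion struct {
	Model       string                  `json:"model"`
	Messages    []chatCompletionMessage `json:"messages"`
	Temperature float32                 `json:"temperature,omitempty"`
	MaxTokens   uint64                  `json:"max_tokens,omitempty"`
	Stream      bool                    `json:"stream,omitempty"`
}

func main() {
	req := chatCompletion{
		Model:    "gpt-3.5-turbo",
		Messages: []chatCompletionMessage{{Role: "user", Content: "Hello!"}},
		// Temperature, MaxTokens, and Stream stay zero-valued, so
		// omitempty drops them from the serialized payload.
	}
	out, err := json.Marshal(req)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
	// {"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"Hello!"}]}
}

One tradeoff of omitempty on numeric fields: an explicit temperature of 0 is indistinguishable from unset, which is likely what the default:"1" tag is meant to document.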
9 changes: 0 additions & 9 deletions gateway/models/provider.go

This file was deleted.

10 changes: 10 additions & 0 deletions gateway/pkg/utils/helpers.go
@@ -0,0 +1,10 @@
+package utils
+
+import "reflect"
+
+func GetDefaultValue(ptr interface{}, defaultValue interface{}) interface{} {
+	if ptr == nil || reflect.ValueOf(ptr).IsNil() {
+		return defaultValue
+	}
+	return reflect.ValueOf(ptr).Elem().Interface()
+}
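
GetDefaultValue dereferences a pointer when it is set and falls back otherwise, a helper presumably meant for reading optional (pointer-typed) request fields with a default. A usage sketch (caveat: reflect.Value.IsNil panics for non-nilable kinds such as a plain int, so callers are expected to pass pointers):

package main

import (
	"fmt"
	"reflect"
)

// Copy of gateway/pkg/utils/helpers.go for a standalone demo.
func GetDefaultValue(ptr interface{}, defaultValue interface{}) interface{} {
	if ptr == nil || reflect.ValueOf(ptr).IsNil() {
		return defaultValue
	}
	return reflect.ValueOf(ptr).Elem().Interface()
}

func main() {
	var unset *float32 // nil pointer: fall back to the default
	set := new(float32)
	*set = 0.2

	fmt.Println(GetDefaultValue(unset, float32(1))) // 1
	fmt.Println(GetDefaultValue(set, float32(1)))   // 0.2
}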