Skip to content

Commit

Permalink
feat(gateway): add prompt registry crud operation
Browse files Browse the repository at this point in the history
Signed-off-by: Praveen Yadav <pyadav9678@gmail.com>
  • Loading branch information
pyadav committed Feb 29, 2024
1 parent 9f79b0e commit 45c1d03
Show file tree
Hide file tree
Showing 35 changed files with 5,303 additions and 3,168 deletions.
6 changes: 5 additions & 1 deletion gateway/cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/missingstudio/studio/backend/internal/api"
"github.com/missingstudio/studio/backend/internal/connections"
"github.com/missingstudio/studio/backend/internal/ingester"
"github.com/missingstudio/studio/backend/internal/prompt"
"github.com/missingstudio/studio/backend/internal/providers"
"github.com/missingstudio/studio/backend/internal/ratelimiter"
"github.com/missingstudio/studio/backend/internal/server"
Expand Down Expand Up @@ -64,8 +65,11 @@ func Serve(cfg *config.Config) error {
connectionRepository := postgres.NewConnectionRepository(dbc)
connectionService := connections.NewService(connectionRepository)

promptRepository := postgres.NewPromptRepository(dbc)
promptService := prompt.NewService(promptRepository)

providerService := providers.NewService()
deps := api.NewDeps(logger, ingester, rl, providerService, connectionService)
deps := api.NewDeps(logger, ingester, rl, providerService, connectionService, promptService)

if err := server.Serve(ctx, logger, cfg.App, deps); err != nil {
logger.Error("error starting server", "error", err)
Expand Down
4 changes: 4 additions & 0 deletions gateway/internal/api/deps.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/missingstudio/studio/backend/internal/connections"
"github.com/missingstudio/studio/backend/internal/ingester"
"github.com/missingstudio/studio/backend/internal/prompt"
"github.com/missingstudio/studio/backend/internal/providers"
"github.com/missingstudio/studio/backend/internal/ratelimiter"
)
Expand All @@ -15,6 +16,7 @@ type Deps struct {
RateLimiter *ratelimiter.RateLimiter
ProviderService *providers.Service
ConnectionService *connections.Service
PromptService *prompt.Service
}

func NewDeps(
Expand All @@ -23,12 +25,14 @@ func NewDeps(
ratelimiter *ratelimiter.RateLimiter,
ps *providers.Service,
cs *connections.Service,
pms *prompt.Service,
) *Deps {
return &Deps{
Logger: logger,
Ingester: ingester,
RateLimiter: ratelimiter,
ProviderService: ps,
ConnectionService: cs,
PromptService: pms,
}
}
2 changes: 1 addition & 1 deletion gateway/internal/api/v1/chatcompletions.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"github.com/missingstudio/studio/backend/internal/router"
"github.com/missingstudio/studio/backend/models"
"github.com/missingstudio/studio/common/errors"
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
llmv1 "github.com/missingstudio/studio/protos/pkg/llm/v1"
)

var (
Expand Down
2 changes: 1 addition & 1 deletion gateway/internal/api/v1/logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (

"connectrpc.com/connect"
"github.com/missingstudio/studio/common/errors"
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
llmv1 "github.com/missingstudio/studio/protos/pkg/llm/v1"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/structpb"
)
Expand Down
2 changes: 1 addition & 1 deletion gateway/internal/api/v1/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"connectrpc.com/connect"
"github.com/missingstudio/studio/backend/internal/providers/base"
"github.com/missingstudio/studio/backend/models"
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
llmv1 "github.com/missingstudio/studio/protos/pkg/llm/v1"
)

func (s *V1Handler) ListModels(ctx context.Context, req *connect.Request[llmv1.ModelRequest]) (*connect.Response[llmv1.ModelResponse], error) {
Expand Down
104 changes: 104 additions & 0 deletions gateway/internal/api/v1/prompts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package v1

import (
"context"

"connectrpc.com/connect"
"github.com/missingstudio/studio/backend/internal/prompt"
"github.com/missingstudio/studio/backend/models"
"github.com/missingstudio/studio/common/errors"
promptv1 "github.com/missingstudio/studio/protos/pkg/prompt/v1"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/structpb"
)

func (s *V1Handler) ListPrompts(ctx context.Context, req *connect.Request[emptypb.Empty]) (*connect.Response[promptv1.ListPromptsResponse], error) {
prompts, err := s.promptService.GetAll(ctx)
if err != nil {
return nil, errors.NewInternalError(err.Error())
}

data := []*promptv1.Prompt{}
for _, p := range prompts {
pmetadata, _ := structpb.NewStruct(p.Metadata)

data = append(data, &promptv1.Prompt{
Id: p.ID.String(),
Name: p.Name,
Description: p.Description,
Template: p.Template,
Metadata: pmetadata,
})
}

return connect.NewResponse(&promptv1.ListPromptsResponse{
Prompt: data,
}), nil
}

func (s *V1Handler) CreatePrompt(ctx context.Context, req *connect.Request[promptv1.CreatePromptRequest]) (*connect.Response[promptv1.CreatePromptResponse], error) {
prompt := models.Prompt{
Name: req.Msg.Name,
Description: req.Msg.Description,
Template: req.Msg.Template,
Metadata: req.Msg.Metadata.AsMap(),
}

prompt, err := s.promptService.Upsert(ctx, prompt)
if err != nil {
return nil, errors.NewNotFound(err.Error())
}

stMetadata, err := structpb.NewStruct(prompt.Metadata)
if err != nil {
return nil, errors.NewInternalError(err.Error())
}

return connect.NewResponse(&promptv1.CreatePromptResponse{
Name: prompt.Name,
Description: prompt.Description,
Template: prompt.Template,
Metadata: stMetadata,
}), nil
}

func (s *V1Handler) GetPrompt(ctx context.Context, req *connect.Request[promptv1.GetPromptRequest]) (*connect.Response[promptv1.GetPromptResponse], error) {
prompt, err := s.promptService.GetByName(ctx, req.Msg.Name)
if err != nil {
return nil, errors.NewNotFound(err.Error())
}

stMetadata, err := structpb.NewStruct(prompt.Metadata)
if err != nil {
return nil, errors.NewInternalError(err.Error())
}

p := &promptv1.Prompt{
Id: prompt.ID.String(),
Name: prompt.Name,
Description: prompt.Description,
Template: prompt.Template,
Metadata: stMetadata,
}

return connect.NewResponse(&promptv1.GetPromptResponse{
Prompt: p,
}), nil
}

func (s *V1Handler) GetPromptValue(ctx context.Context, req *connect.Request[promptv1.GetPromptValueRequest]) (*connect.Response[promptv1.GetPromptValueResponse], error) {
p, err := s.promptService.GetByName(ctx, req.Msg.Name)
if err != nil {
return nil, errors.NewNotFound(err.Error())
}

prompt := prompt.NewPrompt(p.Template, req.Msg.Data.AsMap())
value, err := prompt.Run()
if err != nil {
return nil, errors.NewNotFound(err.Error())
}

return connect.NewResponse(&promptv1.GetPromptValueResponse{
Data: value,
}), nil
}
2 changes: 1 addition & 1 deletion gateway/internal/api/v1/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/jeremywohl/flatten"
"github.com/missingstudio/studio/backend/models"
"github.com/missingstudio/studio/common/errors"
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
llmv1 "github.com/missingstudio/studio/protos/pkg/llm/v1"
"github.com/xeipuuv/gojsonschema"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/structpb"
Expand Down
12 changes: 11 additions & 1 deletion gateway/internal/api/v1/v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,27 @@ import (
"github.com/missingstudio/studio/backend/internal/connections"
"github.com/missingstudio/studio/backend/internal/ingester"
"github.com/missingstudio/studio/backend/internal/interceptor"
"github.com/missingstudio/studio/backend/internal/prompt"
"github.com/missingstudio/studio/backend/internal/providers"
"github.com/missingstudio/studio/protos/pkg/llm/llmv1connect"
"github.com/missingstudio/studio/protos/pkg/llm/v1/llmv1connect"
"github.com/missingstudio/studio/protos/pkg/prompt/v1/promptv1connect"
)

type V1Handler struct {
llmv1connect.UnimplementedLLMServiceHandler
promptv1connect.UnimplementedPromptRegistryServiceHandler
ingester ingester.Ingester
providerService *providers.Service
connectionService *connections.Service
promptService *prompt.Service
}

func NewHandlerV1(d *api.Deps) *V1Handler {
return &V1Handler{
ingester: d.Ingester,
providerService: d.ProviderService,
connectionService: d.ConnectionService,
promptService: d.PromptService,
}
}

Expand Down Expand Up @@ -59,6 +64,11 @@ func Register(d *api.Deps) (http.Handler, error) {
compress1KB,
connect.WithInterceptors(stdInterceptors...),
)),
vanguard.NewService(promptv1connect.NewPromptRegistryServiceHandler(
v1Handler,
compress1KB,
connect.WithInterceptors(stdInterceptors...),
)),
}
transcoderOptions := []vanguard.TranscoderOption{
vanguard.WithUnknownHandler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
Expand Down
14 changes: 10 additions & 4 deletions gateway/internal/connectrpc/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ import (

"github.com/missingstudio/studio/backend/internal/api"
v1 "github.com/missingstudio/studio/backend/internal/api/v1"
"github.com/missingstudio/studio/protos/pkg/llm/llmv1connect"
"github.com/missingstudio/studio/protos/pkg/llm/v1/llmv1connect"
"github.com/missingstudio/studio/protos/pkg/prompt/v1/promptv1connect"
)

func NewConnectMux(d *api.Deps) (*http.ServeMux, error) {
Expand All @@ -28,13 +29,18 @@ func NewConnectMux(d *api.Deps) (*http.ServeMux, error) {
}))

mux.Handle(grpchealth.NewHandler(
grpchealth.NewStaticChecker(llmv1connect.LLMServiceName),
grpchealth.NewStaticChecker(
llmv1connect.LLMServiceName,
promptv1connect.PromptRegistryServiceName,
),
compress1KB,
))

reflector := grpcreflect.NewStaticReflector(llmv1connect.LLMServiceName)
reflector := grpcreflect.NewStaticReflector(
llmv1connect.LLMServiceName,
promptv1connect.PromptRegistryServiceName,
)
mux.Handle(grpcreflect.NewHandlerV1(reflector, compress1KB))
mux.Handle(grpcreflect.NewHandlerV1Alpha(reflector, compress1KB))

return mux, nil
}
39 changes: 39 additions & 0 deletions gateway/internal/prompt/prompt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package prompt

import (
"bytes"
"context"
"html/template"

"github.com/gofrs/uuid"
"github.com/missingstudio/studio/backend/models"
)

type Repository interface {
GetAll(context.Context) ([]models.Prompt, error)
Upsert(context.Context, models.Prompt) (models.Prompt, error)
GetByID(context.Context, uuid.UUID) (models.Prompt, error)
GetByName(context.Context, string) (models.Prompt, error)
DeleteByID(context.Context, uuid.UUID) error
}

type Prompt struct {
tmpl *template.Template
data map[string]any
}

func NewPrompt(text string, data map[string]any) *Prompt {
return &Prompt{
tmpl: template.Must(template.New("prompt").Parse(text)),
data: data,
}
}

func (p *Prompt) Run() (string, error) {
var buf bytes.Buffer
err := p.tmpl.Execute(&buf, p.data)
if err != nil {
return "", err
}
return buf.String(), nil
}
60 changes: 60 additions & 0 deletions gateway/internal/prompt/service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package prompt

import (
"context"
"fmt"

"github.com/gofrs/uuid"
"github.com/missingstudio/studio/backend/models"
)

var _ Repository = &Service{}

type Service struct {
promptRepo Repository
}

func NewService(promptRepo Repository) *Service {
return &Service{
promptRepo: promptRepo,
}
}

func (s *Service) DeleteByID(ctx context.Context, promptID uuid.UUID) error {
return s.promptRepo.DeleteByID(ctx, promptID)
}

func (s *Service) GetAll(ctx context.Context) ([]models.Prompt, error) {
prompts, err := s.promptRepo.GetAll(ctx)
if err != nil {
return nil, err
}
return prompts, nil
}

func (s *Service) GetByID(ctx context.Context, promptID uuid.UUID) (models.Prompt, error) {
prompt, err := s.promptRepo.GetByID(ctx, promptID)
if err != nil {
return models.Prompt{}, err
}

return prompt, err
}

func (s *Service) GetByName(ctx context.Context, name string) (models.Prompt, error) {
prompt, err := s.promptRepo.GetByName(ctx, name)
if err != nil {
return models.Prompt{}, err
}

return prompt, err
}

func (s *Service) Upsert(ctx context.Context, c models.Prompt) (models.Prompt, error) {
id, err := s.promptRepo.Upsert(ctx, c)
if err != nil {
return models.Prompt{}, fmt.Errorf("failed to save prompt: %w", err)
}

return id, err
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DROP TABLE IF EXISTS prompts;
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";

CREATE TABLE prompts (
id uuid PRIMARY KEY DEFAULT uuid_generate_v4(),
name text UNIQUE NOT NULL,
description text,
template text,
metadata jsonb,
created_at timestamp DEFAULT NOW(),
updated_at timestamp DEFAULT NOW()
);
1 change: 1 addition & 0 deletions gateway/internal/storage/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ var (

const (
TABLE_CONNECTIONS = "connections"
TABLE_PROMPTS = "prompts"
)

type Encryptor interface {
Expand Down
Loading

0 comments on commit 45c1d03

Please sign in to comment.