feat: Add LM Studio compatibility #1302

Merged · 1 commit · Feb 24, 2025

3 changes: 2 additions & 1 deletion core/plugin_registry.go
@@ -18,6 +18,7 @@ import (
"github.com/danielmiessler/fabric/plugins/ai/dryrun"
"github.com/danielmiessler/fabric/plugins/ai/gemini"
"github.com/danielmiessler/fabric/plugins/ai/groq"
"github.com/danielmiessler/fabric/plugins/ai/lmstudio"
"github.com/danielmiessler/fabric/plugins/ai/mistral"
"github.com/danielmiessler/fabric/plugins/ai/ollama"
"github.com/danielmiessler/fabric/plugins/ai/openai"
@@ -54,7 +55,7 @@ func NewPluginRegistry(db *fsdb.Db) (ret *PluginRegistry, err error) {
        gemini.NewClient(),
        //gemini_openai.NewClient(),
        anthropic.NewClient(), siliconcloud.NewClient(),
        openrouter.NewClient(), mistral.NewClient(), deepseek.NewClient())
        openrouter.NewClient(), lmstudio.NewClient(), mistral.NewClient(), deepseek.NewClient())
    _ = ret.Configure()

    return
358 changes: 358 additions & 0 deletions plugins/ai/lmstudio/lmstudio.go
@@ -0,0 +1,358 @@
package lmstudio

import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"

goopenai "github.com/sashabaranov/go-openai"

"github.com/danielmiessler/fabric/common"
"github.com/danielmiessler/fabric/plugins"
)

// NewClient creates a new LM Studio client with default configuration.
func NewClient() (ret *Client) {
    return NewClientCompatible("LM Studio", "http://localhost:1234/v1", nil)
}

// NewClientCompatible creates a new LM Studio client with custom configuration.
func NewClientCompatible(vendorName string, defaultBaseUrl string, configureCustom func() error) (ret *Client) {
    ret = &Client{}

    if configureCustom == nil {
        configureCustom = ret.configure
    }
    ret.PluginBase = &plugins.PluginBase{
        Name:            vendorName,
        EnvNamePrefix:   plugins.BuildEnvVariablePrefix(vendorName),
        ConfigureCustom: configureCustom,
    }
    ret.ApiBaseURL = ret.AddSetupQuestion("API Base URL", false)
    ret.ApiBaseURL.Value = defaultBaseUrl
    return
}

// Client represents the LM Studio client.
type Client struct {
    *plugins.PluginBase
    ApiBaseURL *plugins.SetupQuestion
    HttpClient *http.Client
}

// configure sets up the HTTP client.
func (c *Client) configure() error {
    c.HttpClient = &http.Client{}
    return nil
}

// Configure sets up the client configuration.
func (c *Client) Configure() error {
    return c.ConfigureCustom()
}

// ListModels returns a list of available models.
func (c *Client) ListModels() ([]string, error) {
    url := fmt.Sprintf("%s/models", c.ApiBaseURL.Value)

    req, err := http.NewRequest("GET", url, nil)
    if err != nil {
        return nil, fmt.Errorf("failed to create request: %w", err)
    }

    resp, err := c.HttpClient.Do(req)
    if err != nil {
        return nil, fmt.Errorf("failed to send request: %w", err)
    }
    defer resp.Body.Close()

    if resp.StatusCode != http.StatusOK {
        return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
    }

    var result struct {
        Data []struct {
            ID string `json:"id"`
        } `json:"data"`
    }

    if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
        return nil, fmt.Errorf("failed to decode response: %w", err)
    }

    models := make([]string, len(result.Data))
    for i, model := range result.Data {
        models[i] = model.ID
    }

    return models, nil
}

// SendStream sends the chat request with streaming enabled and forwards the
// streamed content chunks to the provided channel.
func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) error {
    // Close the channel on every exit path so the reader never blocks,
    // including the early error returns below.
    defer close(channel)

    url := fmt.Sprintf("%s/chat/completions", c.ApiBaseURL.Value)

    payload := map[string]interface{}{
        "messages": msgs,
        "model":    opts.Model,
        "stream":   true, // Enable streaming
    }

    jsonPayload, err := json.Marshal(payload)
    if err != nil {
        return fmt.Errorf("failed to marshal payload: %w", err)
    }

    req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonPayload))
    if err != nil {
        return fmt.Errorf("failed to create request: %w", err)
    }

    req.Header.Set("Content-Type", "application/json")

    resp, err := c.HttpClient.Do(req)
    if err != nil {
        return fmt.Errorf("failed to send request: %w", err)
    }
    defer resp.Body.Close()

    if resp.StatusCode != http.StatusOK {
        return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
    }

    reader := bufio.NewReader(resp.Body)
    for {
        line, err := reader.ReadBytes('\n')
        if err != nil {
            if err == io.EOF {
                break
            }
            return fmt.Errorf("error reading response: %w", err)
        }

        // Trim the trailing newline and skip empty keep-alive lines.
        line = bytes.TrimSpace(line)
        if len(line) == 0 {
            continue
        }

        // Remove the OpenAI-style SSE prefix.
        if bytes.HasPrefix(line, []byte("data: ")) {
            line = bytes.TrimPrefix(line, []byte("data: "))
        }

        // Handle the [DONE] end-of-stream signal.
        if string(line) == "[DONE]" {
            break
        }

        // Parse the JSON chunk.
        var result map[string]interface{}
        if err := json.Unmarshal(line, &result); err != nil {
            continue
        }

        // Extract content from the streaming chunk.
        choices, ok := result["choices"].([]interface{})
        if !ok || len(choices) == 0 {
            continue
        }

        delta, ok := choices[0].(map[string]interface{})["delta"].(map[string]interface{})
        if !ok {
            continue
        }

        content, _ := delta["content"].(string)

        // Send the chunk to the channel.
        channel <- content
    }

    return nil
}

// Send sends a single message and returns the response.
func (c *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (string, error) {
    url := fmt.Sprintf("%s/chat/completions", c.ApiBaseURL.Value)

    payload := map[string]interface{}{
        "messages": msgs,
        "model":    opts.Model,
        // Add other options from opts if supported by LM Studio
    }

    jsonPayload, err := json.Marshal(payload)
    if err != nil {
        return "", fmt.Errorf("failed to marshal payload: %w", err)
    }

    req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonPayload))
    if err != nil {
        return "", fmt.Errorf("failed to create request: %w", err)
    }

    req.Header.Set("Content-Type", "application/json")

    resp, err := c.HttpClient.Do(req)
    if err != nil {
        return "", fmt.Errorf("failed to send request: %w", err)
    }
    defer resp.Body.Close()

    if resp.StatusCode != http.StatusOK {
        return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
    }

    var result map[string]interface{}
    if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
        return "", fmt.Errorf("failed to decode response: %w", err)
    }

    choices, ok := result["choices"].([]interface{})
    if !ok || len(choices) == 0 {
        return "", fmt.Errorf("invalid response format: missing or empty choices")
    }

    message, ok := choices[0].(map[string]interface{})["message"].(map[string]interface{})
    if !ok {
        return "", fmt.Errorf("invalid response format: missing message in first choice")
    }

    content, ok := message["content"].(string)
    if !ok {
        return "", fmt.Errorf("invalid response format: missing or non-string content in message")
    }

    return content, nil
}

// Complete sends a completion request and returns the response.
func (c *Client) Complete(ctx context.Context, prompt string, opts *common.ChatOptions) (string, error) {
    url := fmt.Sprintf("%s/completions", c.ApiBaseURL.Value)

    payload := map[string]interface{}{
        "prompt": prompt,
        "model":  opts.Model,
        // Add other options from opts if supported by LM Studio
    }

    jsonPayload, err := json.Marshal(payload)
    if err != nil {
        return "", fmt.Errorf("failed to marshal payload: %w", err)
    }

    req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonPayload))
    if err != nil {
        return "", fmt.Errorf("failed to create request: %w", err)
    }

    req.Header.Set("Content-Type", "application/json")

    resp, err := c.HttpClient.Do(req)
    if err != nil {
        return "", fmt.Errorf("failed to send request: %w", err)
    }
    defer resp.Body.Close()

    if resp.StatusCode != http.StatusOK {
        return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
    }

    var result map[string]interface{}
    if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
        return "", fmt.Errorf("failed to decode response: %w", err)
    }

    choices, ok := result["choices"].([]interface{})
    if !ok || len(choices) == 0 {
        return "", fmt.Errorf("invalid response format: missing or empty choices")
    }

    text, ok := choices[0].(map[string]interface{})["text"].(string)
    if !ok {
        return "", fmt.Errorf("invalid response format: missing or non-string text in first choice")
    }

    return text, nil
}

// GetEmbeddings returns embeddings for the given input.
func (c *Client) GetEmbeddings(ctx context.Context, input string, opts *common.ChatOptions) ([]float64, error) {
    url := fmt.Sprintf("%s/embeddings", c.ApiBaseURL.Value)

    payload := map[string]interface{}{
        "input": input,
        "model": opts.Model,
        // Add other options from opts if supported by LM Studio
    }

    jsonPayload, err := json.Marshal(payload)
    if err != nil {
        return nil, fmt.Errorf("failed to marshal payload: %w", err)
    }

    req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonPayload))
    if err != nil {
        return nil, fmt.Errorf("failed to create request: %w", err)
    }

    req.Header.Set("Content-Type", "application/json")

    resp, err := c.HttpClient.Do(req)
    if err != nil {
        return nil, fmt.Errorf("failed to send request: %w", err)
    }
    defer resp.Body.Close()

    if resp.StatusCode != http.StatusOK {
        return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
    }

    var result struct {
        Data []struct {
            Embedding []float64 `json:"embedding"`
        } `json:"data"`
    }

    if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
        return nil, fmt.Errorf("failed to decode response: %w", err)
    }

    if len(result.Data) == 0 {
        return nil, fmt.Errorf("no embeddings returned")
    }

    return result.Data[0].Embedding, nil
}

// GetName returns the name of the vendor.
func (c *Client) GetName() string {
    return c.Name
}

// IsConfigured checks if the client is configured.
func (c *Client) IsConfigured() bool {
    return c.ApiBaseURL != nil && c.ApiBaseURL.Value != ""
}

// Setup performs any necessary setup for the client.
func (c *Client) Setup() error {
    return c.Configure()
}

// SetupFillEnvFileContent fills the environment file content.
func (c *Client) SetupFillEnvFileContent(buffer *bytes.Buffer) {
    envName := fmt.Sprintf("%s_API_BASE_URL", c.EnvNamePrefix)
    buffer.WriteString(fmt.Sprintf("%s=%s\n", envName, c.ApiBaseURL.Value))
}
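
For reference, a minimal sketch of how the new client can be exercised outside the Fabric CLI. It assumes an LM Studio server is running at the default http://localhost:1234/v1 with at least one model loaded; the model ID "local-model" is a placeholder for one of the IDs returned by ListModels.

package main

import (
    "context"
    "fmt"
    "log"

    goopenai "github.com/sashabaranov/go-openai"

    "github.com/danielmiessler/fabric/common"
    "github.com/danielmiessler/fabric/plugins/ai/lmstudio"
)

func main() {
    // Build the client with the default base URL and initialize its HTTP client.
    client := lmstudio.NewClient()
    if err := client.Configure(); err != nil {
        log.Fatalf("configure: %v", err)
    }

    // List the models currently served by LM Studio.
    models, err := client.ListModels()
    if err != nil {
        log.Fatalf("list models: %v", err)
    }
    fmt.Println("available models:", models)

    // Send a single chat message. "local-model" is a placeholder; use one of
    // the IDs returned by ListModels.
    msgs := []*goopenai.ChatCompletionMessage{
        {Role: goopenai.ChatMessageRoleUser, Content: "Say hello in one short sentence."},
    }
    opts := &common.ChatOptions{Model: "local-model"}

    reply, err := client.Send(context.Background(), msgs, opts)
    if err != nil {
        log.Fatalf("send: %v", err)
    }
    fmt.Println(reply)
}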