Skip to content

Commit b86e9fc

Browse files
authored
feat: add azure embedding to ai-cache (#1975)
1 parent 2014234 commit b86e9fc

File tree

10 files changed

+230
-54
lines changed

10 files changed

+230
-54
lines changed

plugins/wasm-go/extensions/ai-cache/core.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ func performEmbeddingQuery(key string, ctx wrapper.HttpContext, c config.PluginC
130130
return logAndReturnError(log, fmt.Sprintf("[performEmbeddingQuery] no embedding provider configured for similarity search"))
131131
}
132132

133-
return activeEmbeddingProvider.GetEmbedding(key, ctx, log, func(textEmbedding []float64, err error) {
133+
return activeEmbeddingProvider.GetEmbedding(key, ctx, func(textEmbedding []float64, err error) {
134134
log.Debugf("[%s] [performEmbeddingQuery] GetEmbedding success, length of embedding: %d, error: %v", PLUGIN_NAME, len(textEmbedding), err)
135135
if err != nil {
136136
handleInternalError(err, fmt.Sprintf("[%s] [performEmbeddingQuery] error getting embedding for key: %s", PLUGIN_NAME, key), log)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
package embedding
2+
3+
import (
4+
"encoding/json"
5+
"errors"
6+
"fmt"
7+
"net/http"
8+
"strings"
9+
10+
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
11+
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
12+
"github.com/tidwall/gjson"
13+
)
14+
15+
// Defaults and the request path template for the Azure embedding provider.
const (
	// AZURE_PORT is the service port used when the provider config leaves it unset (0).
	AZURE_PORT = 443
	// AZURE_DEFAULT_MODEL_NAME is the model used when the provider config leaves it empty.
	AZURE_DEFAULT_MODEL_NAME = "text-embedding-ada-002"
	// AZURE_ENDPOINT is the embeddings request path; "{model}" is replaced
	// with the configured model before the request is sent.
	AZURE_ENDPOINT = "/openai/deployments/{model}/embeddings"
)
20+
21+
// azureProviderInitializer loads and validates the Azure-specific settings
// and creates AzureProvider instances. It carries no state of its own; the
// parsed settings are kept in the package-level azureConfig.
type azureProviderInitializer struct{}
23+
24+
// azureConfig holds the Azure settings for the whole package. It is filled in
// by azureProviderInitializer.InitConfig and read when requests are built.
var azureConfig azureProviderConfig

// azureProviderConfig groups the settings that are specific to the Azure
// embedding service.
type azureProviderConfig struct {
	// apiKey is the API key of the embedding service; it is sent in the
	// "api-key" request header.
	// @Title zh-CN 文本特征提取服务 API Key
	// @Description zh-CN 文本特征提取服务 API Key
	apiKey string
	// apiVersion is the value sent as the "api-version" query parameter.
	// @Title zh-CN 文本特征提取 api-version
	// @Description zh-CN 文本特征提取服务 api-version
	apiVersion string
}
34+
35+
func (c *azureProviderInitializer) InitConfig(json gjson.Result) {
36+
azureConfig.apiKey = json.Get("apiKey").String()
37+
azureConfig.apiVersion = json.Get("apiVersion").String()
38+
}
39+
40+
func (c *azureProviderInitializer) ValidateConfig() error {
41+
if azureConfig.apiKey == "" {
42+
return errors.New("[Azure] apiKey is required")
43+
}
44+
if azureConfig.apiVersion == "" {
45+
return errors.New("[Azure] apiVersion is required")
46+
}
47+
return nil
48+
}
49+
50+
func (t *azureProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) {
51+
if c.servicePort == 0 {
52+
c.servicePort = AZURE_PORT
53+
}
54+
55+
if c.model == "" {
56+
c.model = AZURE_DEFAULT_MODEL_NAME
57+
}
58+
59+
return &AzureProvider{
60+
config: c,
61+
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
62+
FQDN: c.serviceName,
63+
Host: c.serviceHost,
64+
Port: c.servicePort,
65+
}),
66+
}, nil
67+
}
68+
69+
func (t *AzureProvider) GetProviderType() string {
70+
return PROVIDER_TYPE_AZURE
71+
}
72+
73+
type AzureProvider struct {
74+
config ProviderConfig
75+
client wrapper.HttpClient
76+
}
77+
78+
// AzureEmbeddingRequest is the JSON body of an embeddings request.
type AzureEmbeddingRequest struct {
	// Input is the text to embed; serialized as "input".
	Input string `json:"input"`
}
81+
82+
func (t *AzureProvider) constructParameters(text string) (string, [][2]string, []byte, error) {
83+
if text == "" {
84+
err := errors.New("queryString text cannot be empty")
85+
return "", nil, nil, err
86+
}
87+
88+
data := AzureEmbeddingRequest{
89+
Input: text,
90+
}
91+
92+
requestBody, err := json.Marshal(data)
93+
if err != nil {
94+
log.Errorf("failed to marshal request data: %v", err)
95+
return "", nil, nil, err
96+
}
97+
98+
model := t.config.model
99+
if model == "" {
100+
model = AZURE_DEFAULT_MODEL_NAME
101+
}
102+
103+
// 拼接 endpoint
104+
endpoint := strings.Replace(AZURE_ENDPOINT, "{model}", model, 1)
105+
endpoint = endpoint + "?" + "api-version=" + azureConfig.apiVersion
106+
107+
headers := [][2]string{
108+
{"api-key", azureConfig.apiKey},
109+
{"Content-Type", "application/json"},
110+
}
111+
112+
return endpoint, headers, requestBody, err
113+
}
114+
115+
// AzureEmbeddingResponse mirrors the parts of the embeddings response body
// that this package decodes.
type AzureEmbeddingResponse struct {
	// Object is the top-level "object" field of the response.
	Object string `json:"object"`
	// Model is the "model" field of the response.
	Model string `json:"model"`
	// Data lists the returned embeddings; GetEmbedding uses the first entry.
	Data []struct {
		// Object is the "object" field of this entry.
		Object string `json:"object"`
		// Embedding is the embedding vector of this entry.
		Embedding []float64 `json:"embedding"`
		// Index is the "index" field of this entry.
		Index int `json:"index"`
	} `json:"data"`
}
124+
125+
func (t *AzureProvider) parseTextEmbedding(responseBody []byte) (*AzureEmbeddingResponse, error) {
126+
var resp AzureEmbeddingResponse
127+
if err := json.Unmarshal(responseBody, &resp); err != nil {
128+
return nil, fmt.Errorf("failed to parse response: %w", err)
129+
}
130+
return &resp, nil
131+
}
132+
133+
func (t *AzureProvider) GetEmbedding(
134+
queryString string,
135+
ctx wrapper.HttpContext,
136+
callback func(emb []float64, err error)) error {
137+
embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString)
138+
if err != nil {
139+
log.Errorf("failed to construct parameters: %v", err)
140+
return err
141+
}
142+
143+
var resp *AzureEmbeddingResponse
144+
err = t.client.Post(embUrl, embHeaders, embRequestBody,
145+
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
146+
147+
if statusCode != http.StatusOK {
148+
err = fmt.Errorf("failed to get embedding due to status code: %d, resp: %s", statusCode, responseBody)
149+
callback(nil, err)
150+
return
151+
}
152+
153+
resp, err = t.parseTextEmbedding(responseBody)
154+
if err != nil {
155+
err = fmt.Errorf("failed to parse response: %v", err)
156+
callback(nil, err)
157+
return
158+
}
159+
160+
log.Debugf("get embedding response: %d, %s", statusCode, responseBody)
161+
162+
if len(resp.Data) == 0 {
163+
err = errors.New("no embedding found in response")
164+
callback(nil, err)
165+
return
166+
}
167+
168+
callback(resp.Data[0].Embedding, nil)
169+
170+
}, t.config.timeout)
171+
return err
172+
}

plugins/wasm-go/extensions/ai-cache/embedding/cohere.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"net/http"
88
"strconv"
99

10+
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
1011
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
1112
"github.com/tidwall/gjson"
1213
)
@@ -79,7 +80,7 @@ type CohereProvider struct {
7980
func (t *CohereProvider) GetProviderType() string {
8081
return PROVIDER_TYPE_COHERE
8182
}
82-
func (t *CohereProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) {
83+
func (t *CohereProvider) constructParameters(texts []string) (string, [][2]string, []byte, error) {
8384
model := t.config.model
8485

8586
if model == "" {
@@ -118,9 +119,8 @@ func (t *CohereProvider) parseTextEmbedding(responseBody []byte) (*cohereRespons
118119
func (t *CohereProvider) GetEmbedding(
119120
queryString string,
120121
ctx wrapper.HttpContext,
121-
log wrapper.Log,
122122
callback func(emb []float64, err error)) error {
123-
embUrl, embHeaders, embRequestBody, err := t.constructParameters([]string{queryString}, log)
123+
embUrl, embHeaders, embRequestBody, err := t.constructParameters([]string{queryString})
124124
if err != nil {
125125
log.Errorf("failed to construct parameters: %v", err)
126126
return err

plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"net/http"
88
"strconv"
99

10+
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
1011
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
1112
"github.com/tidwall/gjson"
1213
)
@@ -103,7 +104,7 @@ type DSProvider struct {
103104
client wrapper.HttpClient
104105
}
105106

106-
func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) {
107+
func (d *DSProvider) constructParameters(texts []string) (string, [][2]string, []byte, error) {
107108

108109
model := d.config.model
109110

@@ -159,9 +160,8 @@ func (d *DSProvider) parseTextEmbedding(responseBody []byte) (*Response, error)
159160
func (d *DSProvider) GetEmbedding(
160161
queryString string,
161162
ctx wrapper.HttpContext,
162-
log wrapper.Log,
163163
callback func(emb []float64, err error)) error {
164-
embUrl, embHeaders, embRequestBody, err := d.constructParameters([]string{queryString}, log)
164+
embUrl, embHeaders, embRequestBody, err := d.constructParameters([]string{queryString})
165165
if err != nil {
166166
log.Errorf("failed to construct parameters: %v", err)
167167
return err

plugins/wasm-go/extensions/ai-cache/embedding/huggingface.go

+15-14
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@ import (
44
"encoding/json"
55
"errors"
66
"fmt"
7-
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
8-
"github.com/tidwall/gjson"
97
"net/http"
108
"strconv"
119
"strings"
10+
11+
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
12+
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
13+
"github.com/tidwall/gjson"
1214
)
1315

1416
const (
@@ -18,29 +20,29 @@ const (
1820
HUGGINGFACE_ENDPOINT = "/pipeline/feature-extraction/{modelId}"
1921
)
2022

21-
type HuggingFaceProviderInitializer struct {
23+
type huggingfaceProviderInitializer struct {
2224
}
2325

24-
var HuggingFaceConfig HuggingFaceProviderConfig
26+
var huggingfaceConfig huggingfaceProviderConfig
2527

26-
type HuggingFaceProviderConfig struct {
28+
type huggingfaceProviderConfig struct {
2729
// @Title zh-CN 文本特征提取服务 API Key
2830
// @Description zh-CN 文本特征提取服务 API Key。在HuggingFace定义为 hf_token
2931
apiKey string
3032
}
3133

32-
func (c *HuggingFaceProviderInitializer) InitConfig(json gjson.Result) {
33-
HuggingFaceConfig.apiKey = json.Get("apiKey").String()
34+
func (c *huggingfaceProviderInitializer) InitConfig(json gjson.Result) {
35+
huggingfaceConfig.apiKey = json.Get("apiKey").String()
3436
}
3537

36-
func (c *HuggingFaceProviderInitializer) ValidateConfig() error {
37-
if HuggingFaceConfig.apiKey == "" {
38+
func (c *huggingfaceProviderInitializer) ValidateConfig() error {
39+
if huggingfaceConfig.apiKey == "" {
3840
return errors.New("[HuggingFace] hfTokens is required")
3941
}
4042
return nil
4143
}
4244

43-
func (t *HuggingFaceProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) {
45+
func (t *huggingfaceProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) {
4446
if c.servicePort == 0 {
4547
c.servicePort = HUGGINGFACE_PORT
4648
}
@@ -78,7 +80,7 @@ type HuggingFaceEmbeddingRequest struct {
7880
} `json:"options"`
7981
}
8082

81-
func (t *HuggingFaceProvider) constructParameters(text string, log wrapper.Log) (string, [][2]string, []byte, error) {
83+
func (t *HuggingFaceProvider) constructParameters(text string) (string, [][2]string, []byte, error) {
8284
if text == "" {
8385
err := errors.New("queryString text cannot be empty")
8486
return "", nil, nil, err
@@ -108,7 +110,7 @@ func (t *HuggingFaceProvider) constructParameters(text string, log wrapper.Log)
108110
endpoint := strings.Replace(HUGGINGFACE_ENDPOINT, "{modelId}", modelId, 1)
109111

110112
headers := [][2]string{
111-
{"Authorization", "Bearer " + HuggingFaceConfig.apiKey},
113+
{"Authorization", "Bearer " + huggingfaceConfig.apiKey},
112114
{"Content-Type", "application/json"},
113115
}
114116

@@ -127,9 +129,8 @@ func (t *HuggingFaceProvider) parseTextEmbedding(responseBody []byte) ([]float64
127129
func (t *HuggingFaceProvider) GetEmbedding(
128130
queryString string,
129131
ctx wrapper.HttpContext,
130-
log wrapper.Log,
131132
callback func(emb []float64, err error)) error {
132-
embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString, log)
133+
embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString)
133134
if err != nil {
134135
log.Errorf("failed to construct parameters: %v", err)
135136
return err

plugins/wasm-go/extensions/ai-cache/embedding/ollama.go

+6-5
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@ import (
44
"encoding/json"
55
"errors"
66
"fmt"
7-
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
8-
"github.com/tidwall/gjson"
97
"net/http"
108
"strconv"
9+
10+
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
11+
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
12+
"github.com/tidwall/gjson"
1113
)
1214

1315
const (
@@ -69,7 +71,7 @@ type ollamaEmbeddingRequest struct {
6971
Model string `json:"model"`
7072
}
7173

72-
func (t *ollamaProvider) constructParameters(text string, log wrapper.Log) (string, [][2]string, []byte, error) {
74+
func (t *ollamaProvider) constructParameters(text string) (string, [][2]string, []byte, error) {
7375
if text == "" {
7476
err := errors.New("queryString text cannot be empty")
7577
return "", nil, nil, err
@@ -105,9 +107,8 @@ func (t *ollamaProvider) parseTextEmbedding(responseBody []byte) (*ollamaRespons
105107
func (t *ollamaProvider) GetEmbedding(
106108
queryString string,
107109
ctx wrapper.HttpContext,
108-
log wrapper.Log,
109110
callback func(emb []float64, err error)) error {
110-
embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString, log)
111+
embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString)
111112
if err != nil {
112113
log.Errorf("failed to construct parameters: %v", err)
113114
return err

plugins/wasm-go/extensions/ai-cache/embedding/openai.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"net/http"
88

9+
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
910
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
1011
"github.com/tidwall/gjson"
1112
)
@@ -93,7 +94,7 @@ type OpenAIProvider struct {
9394
client wrapper.HttpClient
9495
}
9596

96-
func (t *OpenAIProvider) constructParameters(text string, log wrapper.Log) (string, [][2]string, []byte, error) {
97+
func (t *OpenAIProvider) constructParameters(text string) (string, [][2]string, []byte, error) {
9798
if text == "" {
9899
err := errors.New("queryString text cannot be empty")
99100
return "", nil, nil, err
@@ -130,9 +131,8 @@ func (t *OpenAIProvider) parseTextEmbedding(responseBody []byte) (*OpenAIRespons
130131
func (t *OpenAIProvider) GetEmbedding(
131132
queryString string,
132133
ctx wrapper.HttpContext,
133-
log wrapper.Log,
134134
callback func(emb []float64, err error)) error {
135-
embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString, log)
135+
embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString)
136136
if err != nil {
137137
log.Errorf("failed to construct parameters: %v", err)
138138
return err

0 commit comments

Comments
 (0)