Skip to content

Commit

Permalink
Added support for KernelMemory
Browse files Browse the repository at this point in the history
  • Loading branch information
durmisi committed Mar 17, 2024
1 parent b99541b commit 19309e5
Show file tree
Hide file tree
Showing 29 changed files with 298 additions and 112 deletions.
10 changes: 2 additions & 8 deletions Core/AIBroker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,15 @@

namespace OpenAIExtensions
{

public interface IAIBroker
{
public OpenAIClient GetClient();
public OpenAIClient GetClient(string endpoint, string? key);

public OpenAIClient GetClient(string endpoint, string? key);
}

public class AIBroker : IAIBroker
{

private readonly IConfiguration? _configuration;
private readonly string? _endpoint;
private readonly string? _key;
Expand All @@ -41,13 +39,10 @@ public AIBroker(string endpoint, string key)
_key = key;
}


public OpenAIClient GetClient()
{

if (_configuration != null)
{

var endpoint = _configuration.GetValue<string>("OpenAI:Endpoint");

if (string.IsNullOrEmpty(endpoint))
Expand Down Expand Up @@ -81,6 +76,5 @@ public OpenAIClient GetClient(string endpoint, string? key)
new Uri(endpoint),
new DefaultAzureCredential());
}

}
}
}
10 changes: 3 additions & 7 deletions Core/Chats/AIConversationManager.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;
Expand Down Expand Up @@ -34,21 +35,19 @@ public AIConversationManager(Kernel kernel,
OpenAIPromptExecutionSettings? executionSettings = null,
CancellationToken ct = default)
{

ArgumentNullException.ThrowIfNull(chatHistory);

OpenAIPromptExecutionSettings openAIPromptExecutionSettings = new()
{
ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
};


if (executionSettings != null)
{
openAIPromptExecutionSettings = executionSettings;
}

if (string.IsNullOrEmpty(openAIPromptExecutionSettings.ChatSystemPrompt)
if (string.IsNullOrEmpty(openAIPromptExecutionSettings.ChatSystemPrompt)
&& !string.IsNullOrEmpty(systemPropmpt))
{
openAIPromptExecutionSettings.ChatSystemPrompt = systemPropmpt;
Expand All @@ -63,17 +62,14 @@ public AIConversationManager(Kernel kernel,
cancellationToken: ct);

return result;

}


public async ValueTask<string?> ProcessConversationAsync(
ChatHistory chatHistory,
string? systemMessage = null,
OpenAIPromptExecutionSettings? executionSettings = null,
CancellationToken ct = default)
{

var result = ProcessConversationStreamAsync(chatHistory, systemMessage, executionSettings, ct);

if (result != null)
Expand All @@ -90,4 +86,4 @@ public AIConversationManager(Kernel kernel,
return null;
}
}
}
}
32 changes: 32 additions & 0 deletions Core/Chats/KernelPluginsImporter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Embeddings;
using Microsoft.SemanticKernel.Memory;
using Microsoft.SemanticKernel.Plugins.Memory;
using SemanticKernel.Connectors.Memory.SqlServer;

namespace OpenAIExtensions.Chats
{
public static class KernelPluginsImporter
{
public static async void ImportTextMemoryPlugin(
this Kernel kernel,
string connectionString,
string schema,
CancellationToken ct = default)
{
var config = new SqlServerConfig()
{
ConnectionString = connectionString,
Schema = schema ?? "ai"
};

var sqlMemoryStore = await SqlServerMemoryStore
.ConnectAsync(connectionString: connectionString, config, cancellationToken: ct);

var embeddingGenerator = kernel.GetRequiredService<ITextEmbeddingGenerationService>();

var semanticTextMemory = new SemanticTextMemory(sqlMemoryStore, embeddingGenerator);
kernel.ImportPluginFromObject(new TextMemoryPlugin(semanticTextMemory));
}
}
}
91 changes: 91 additions & 0 deletions Core/KernelMemoryFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
using KernelMemory.MemoryStorage.SqlServer;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory;
using Microsoft.KernelMemory.AI;
using Microsoft.KernelMemory.AI.AzureOpenAI;
using Microsoft.KernelMemory.AI.OpenAI;
using Microsoft.SemanticKernel;

namespace OpenAIExtensions
{
public class CreateKernelMemoryRequest
{
public required string Endpoint { get; set; }
public required string ApiKey { get; set; }
public required string ConnectionString { get; set; }
public string? Schema { get; set; }
}

public static class KernelMemoryFactory
{
public static IKernelMemory Create(CreateKernelMemoryRequest request)
{
if (request is null)
{
throw new ArgumentNullException(nameof(request));
}

var kernelMemoryBuilder = new KernelMemoryBuilder()
.WithAzureOpenAIDefaults(
request.Endpoint,
request.ApiKey);

var config = new SqlServerConfig()
{
ConnectionString = request.ConnectionString,
Schema = request.Schema ?? "ai"
};

var kernelMemory = kernelMemoryBuilder
.WithSqlServerMemoryDb(config)
.Build<MemoryServerless>();

return kernelMemory;
}

public static IKernelMemoryBuilder WithAzureOpenAIDefaults(
this IKernelMemoryBuilder builder,
string endpoint,
string apiKey,
ITextTokenizer? textGenerationTokenizer = null,
ITextTokenizer? textEmbeddingTokenizer = null,
ILoggerFactory? loggerFactory = null,
bool onlyForRetrieval = false,
HttpClient? httpClient = null)
{
textGenerationTokenizer ??= new DefaultGPTTokenizer();
textEmbeddingTokenizer ??= new DefaultGPTTokenizer();

var textEmbbedingAIConfig = new AzureOpenAIConfig
{
APIKey = apiKey,
Endpoint = endpoint,
Deployment = "text-embedding-ada-002",
MaxRetries = 3,
Auth = AzureOpenAIConfig.AuthTypes.APIKey,
MaxTokenTotal = 8191,
};

var textGenerationAIConfig = new AzureOpenAIConfig
{
APIKey = apiKey,
Endpoint = endpoint,
Deployment = "gpt-35-turbo-0613",
MaxRetries = 3,
Auth = AzureOpenAIConfig.AuthTypes.APIKey,
MaxTokenTotal = 16384,
};

textEmbbedingAIConfig.Validate();
textGenerationAIConfig.Validate();
builder.Services.AddAzureOpenAIEmbeddingGeneration(textEmbbedingAIConfig, textEmbeddingTokenizer, httpClient);
builder.Services.AddAzureOpenAITextGeneration(textGenerationAIConfig, textGenerationTokenizer, httpClient);
if (!onlyForRetrieval)
{
builder.AddIngestionEmbeddingGenerator(new AzureOpenAITextEmbeddingGenerator(textEmbbedingAIConfig, textEmbeddingTokenizer, loggerFactory, httpClient));
}

return builder;
}
}
}
10 changes: 7 additions & 3 deletions Core/OpenAIExtensions.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<NoWarn>SKEXP0001,SKEXP0010, SKEXP0011, SKEXP0050</NoWarn>
</PropertyGroup>

<PropertyGroup>
Expand All @@ -21,15 +22,18 @@
<ItemGroup>
<PackageReference Include="Azure.Identity" Version="1.10.4" />
<PackageReference Include="Microsoft.Extensions.Http" Version="8.0.0" />
<PackageReference Include="Microsoft.KernelMemory.SemanticKernelPlugin" Version="0.34.240313.1" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.SemanticKernel.Abstractions" Version="1.6.2" />
<PackageReference Include="Microsoft.SemanticKernel.Core" Version="1.6.2" />
<PackageReference Include="Microsoft.SemanticKernel.Connectors.OpenAI" Version="1.6.2" />
<PackageReference Include="Microsoft.KernelMemory.SemanticKernelPlugin" Version="0.34.240313.1">
<TreatAsUsed>true</TreatAsUsed>
</PackageReference>
<PackageReference Include="Microsoft.SemanticKernel.Plugins.Memory" Version="1.6.2-alpha" />
<PackageReference Include="Microsoft.KernelMemory.Abstractions" Version="0.34.240313.1" />
<PackageReference Include="Microsoft.KernelMemory.Core" Version="0.34.240313.1" />
<PackageReference Include="KernelMemory.MemoryStorage.SqlServer" Version="1.3.3" />
<PackageReference Include="SemanticKernel.Connectors.Memory.SqlServer" Version="1.3.3" />
</ItemGroup>

</Project>
20 changes: 3 additions & 17 deletions Core/SematicKernelBuilder.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Microsoft.Extensions.DependencyInjection;
using Microsoft.KernelMemory;
using Microsoft.SemanticKernel;
using OpenAIExtensions.HttpClients;

Expand Down Expand Up @@ -30,7 +31,6 @@ public static SematicKernelBuilder Create(string defaultEndpoint, string default
throw new ArgumentException($"'{nameof(defaultKey)}' cannot be null or empty.", nameof(defaultKey));
}


return new SematicKernelBuilder()
{
DefaultEndpoint = defaultEndpoint,
Expand All @@ -52,6 +52,7 @@ public SematicKernelBuilder AddAIChatCompletion(
deploymentName: deploymentName,
endpoint: GetEndpoint(endpoint),
apiKey: GetApiKey(apiKey));

return this;
}

Expand All @@ -60,12 +61,10 @@ public SematicKernelBuilder AddAITextGeneration(
string? apiKey = null,
string deploymentName = "gpt-35-turbo-0613")
{
#pragma warning disable SKEXP0011 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
_kernelBuilder.AddAzureOpenAITextGeneration(
deploymentName: deploymentName,
endpoint: GetEndpoint(endpoint),
apiKey: GetApiKey(apiKey));
#pragma warning restore SKEXP0011 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.

return this;
}
Expand All @@ -75,12 +74,10 @@ public SematicKernelBuilder AddAITextEmbeddingGeneration(
string? apiKey = null,
string deploymentName = "text-embedding-ada-002")
{
#pragma warning disable SKEXP0010 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
_kernelBuilder.AddAzureOpenAITextEmbeddingGeneration(
deploymentName: deploymentName,
endpoint: GetEndpoint(endpoint),
apiKey: GetApiKey(apiKey));
#pragma warning restore SKEXP0010 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.

return this;
}
Expand All @@ -90,13 +87,10 @@ public SematicKernelBuilder AddAIAudioToText(
string? apiKey = null,
string deploymentName = "whisper-001")
{
#pragma warning disable SKEXP0001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
_kernelBuilder.AddAzureOpenAIAudioToText(
deploymentName: deploymentName,
endpoint: GetEndpoint(endpoint),
apiKey: GetApiKey(apiKey));
#pragma warning restore SKEXP0001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.

return this;
}

Expand All @@ -105,26 +99,19 @@ public SematicKernelBuilder AddAITextToAudio(
string? apiKey = null,
string deploymentName = "gpt-35-turbo-0613")
{
#pragma warning disable SKEXP0001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
_kernelBuilder.AddAzureOpenAITextToAudio(
deploymentName: deploymentName,
endpoint: GetEndpoint(endpoint),
apiKey: GetApiKey(apiKey));
#pragma warning restore SKEXP0001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.

return this;
}

public SematicKernelBuilder AddAIFiles(string? apiKey = null)
{
#pragma warning disable SKEXP0010 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
_kernelBuilder.AddOpenAIFiles(apiKey: GetApiKey(apiKey));
#pragma warning restore SKEXP0010 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.

return this;
}


private string GetApiKey(string? apiKey)
{
if (!string.IsNullOrEmpty(apiKey))
Expand Down Expand Up @@ -152,12 +139,11 @@ public Kernel Build()
return _kernelBuilder.Build();
}


public SematicKernelBuilder AddPlugin<TPlugin>(string? pluginName = null)
where TPlugin : class
{
_kernelBuilder.Plugins.AddFromType<TPlugin>(pluginName);
return this;
}
}
}
}
3 changes: 1 addition & 2 deletions Core/ServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ public static void AddOpenAI(this IServiceCollection services)
services.AddScoped<IAISqlGenerator, AISqlGenerator>();
services.AddScoped<IAIAudioService, IAIAudioService>();
services.AddScoped<IAIEmbeddingsGenerator, AIEmbeddingsGenerator>();

}
}
}
}
Loading

0 comments on commit 19309e5

Please sign in to comment.