DEEP-13, DEEP-14 Training dialog implemented. Dashboard page UI and functionality added

This commit is contained in:
Andrey Shabarshov 2023-07-28 17:11:58 +01:00
parent facecd5ed7
commit af6c75a5ac
13 changed files with 315 additions and 30 deletions

View File

@ -0,0 +1,169 @@
@using DeepTrace.Data;
@using DeepTrace.ML;
@using DeepTrace.Services;
@using PrometheusAPI;
@implements IDisposable
@inject PrometheusClient Prometheus
@inject IDialogService DialogService
@inject IModelStorageService ModelService
@inject ITrainedModelStorageService TrainedModelService
@inject ILogger<MLProcessor> MLProcessorLogger
@inject ILogger<ModelCard> Logger
@*
    Card for a single trained model: shows the model name, a switch bound to the
    model's IsEnabled flag and the label returned by the most recent prediction run.
*@
<MudCard Class="mb-3">
    <MudCardHeader>
        <CardHeaderContent>
            <MudText Typo="Typo.h6">@Model?.Name</MudText>
        </CardHeaderContent>
        <CardHeaderActions>
            <MudSwitch @bind-Checked="IsEnabled"> @Model?.IsEnabled</MudSwitch>
        </CardHeaderActions>
    </MudCardHeader>
    <MudCardContent>
        <MudText>Current state: @_prediction.PredictedLabel</MudText>
    </MudCardContent>
</MudCard>
@code{
    /// <summary>Pause between two consecutive prediction runs.</summary>
    private static readonly TimeSpan PollInterval = TimeSpan.FromSeconds(5);

    /// <summary>Timeout handed to every Prometheus range query.</summary>
    private static readonly TimeSpan QueryTimeout = TimeSpan.FromSeconds(2);

    /// <summary>Trained model rendered by this card.</summary>
    [Parameter]
    public TrainedModelDefinition? Model { get; set; }

    // Cancelled in Dispose so the background prediction loop ends with the component.
    private readonly System.Threading.CancellationTokenSource _loopCancellation = new();
    private ModelDefinition _modelDefinition = new();
    private Prediction _prediction = new();

    /// <summary>
    /// On the first render, loads the model definition that belongs to <see cref="Model"/>
    /// and starts the background prediction loop. Later renders are ignored.
    /// </summary>
    protected override async Task OnAfterRenderAsync(bool firstRender)
    {
        if (!firstRender || Model?.Id == null)
        {
            return;
        }

        // Take the token before awaiting: the component may be disposed while the
        // definition is loading, and a disposed source no longer hands out tokens.
        var token = _loopCancellation.Token;

        _modelDefinition = (await ModelService.Load(Model.Id)) ?? _modelDefinition;

        // Deliberately not awaited: the loop runs until the component is disposed.
        _ = Task.Run(() => PredictionLoop(token), token);
    }

    /// <summary>
    /// Switch binding. Writes the new value to <see cref="Model"/> and persists it;
    /// does nothing when there is no model or the value is unchanged.
    /// </summary>
    private bool IsEnabled
    {
        get => Model?.IsEnabled ?? false;
        set
        {
            if (Model == null || Model.IsEnabled == value)
            {
                return;
            }

            Model.IsEnabled = value;

            // A setter cannot await. SaveIsEnabled logs its own failures, so the
            // discarded task never faults unobserved.
            _ = InvokeAsync(SaveIsEnabled);
        }
    }

    /// <summary>Stores a copy of the current model, including its IsEnabled flag.</summary>
    private async Task SaveIsEnabled()
    {
        var model = Model;
        if (model == null)
        {
            return;
        }

        var trainedModel = new TrainedModelDefinition
        {
            Id = model.Id,
            IsEnabled = model.IsEnabled,
            Name = model.Name,
            Value = model.Value
        };

        try
        {
            await TrainedModelService.Store(trainedModel);
        }
        catch (Exception e)
        {
            Logger.LogError(e, "Failed to store the IsEnabled flag for model {ModelName}", model.Name);
        }
    }

    /// <summary>
    /// Runs a prediction every <see cref="PollInterval"/> over the time window that
    /// elapsed since the previous run, until <paramref name="cancellationToken"/> is cancelled.
    /// </summary>
    private async Task PredictionLoop(System.Threading.CancellationToken cancellationToken)
    {
        var startDate = DateTime.UtcNow;
        while (!cancellationToken.IsCancellationRequested)
        {
            try
            {
                await Task.Delay(PollInterval, cancellationToken);
                var endDate = DateTime.UtcNow;
                await PredictAnomaly(startDate, endDate);
                startDate = endDate;
            }
            catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
            {
                // Component disposed: leave quietly.
                break;
            }
            catch (Exception e)
            {
                // Best effort: one failed run must not stop the loop, but it should be visible.
                Logger.LogError(e, "Prediction run failed for model {ModelName}", Model?.Name);
            }
        }
    }

    /// <summary>
    /// Queries Prometheus for every query of the model definition between
    /// <paramref name="startDate"/> and <paramref name="endDate"/>, feeds the results to the
    /// ML processor and shows the returned prediction. Returns without a prediction when a
    /// query fails or does not yield exactly one non-empty series.
    /// </summary>
    private async Task PredictAnomaly(DateTime startDate, DateTime endDate)
    {
        var trainedModel = Model;
        if (trainedModel == null)
        {
            return;
        }

        var queries = _modelDefinition.DataSource.Queries;

        // use automatic step value to always request 500 elements (but never below one second)
        var step = TimeSpan.FromSeconds(Math.Max((endDate - startDate).TotalSeconds / 500.0, 1.0));

        var tasks = queries
            .Select(x => Prometheus.RangeQuery(x.Query, startDate, endDate, step, QueryTimeout))
            .ToArray();

        try
        {
            await Task.WhenAll(tasks);
        }
        catch (Exception e)
        {
            // This method runs on a thread-pool thread; dialogs must be opened on the renderer context.
            await InvokeAsync(() => ShowError(e.Message));
            return;
        }

        var data = new List<TimeSeriesDataSet>();
        for (var i = 0; i < tasks.Length; i++)
        {
            var res = tasks[i].Result; // already completed by Task.WhenAll
            var def = queries[i];

            if (res.Status != StatusType.Success)
            {
                Logger.LogError("Query {Query} failed: {Error}", def.Query, res.Error ?? "Error");
                return;
            }

            if (res.ResultType != ResultTypeType.Matrix)
            {
                Logger.LogError("Got {ResultType}, but Matrix expected for {Query}", res.ResultType, def.Query);
                return;
            }

            var series = res.AsMatrix().Result;
            if (series == null || series.Length != 1)
            {
                Logger.LogError(
                    "Expected exactly one time series for {Query}, got {Count}",
                    def.Query,
                    series?.Length ?? 0);
                return;
            }

            var values = series[0].Values?.ToList();
            if (values == null || values.Count == 0)
            {
                Logger.LogError("No data returned for {Query}", def.Query);
                return;
            }

            data.Add(
                new()
                {
                    Name = def.Query,
                    Color = def.Color,
                    Data = values
                }
            );
        }

        var mlProcessor = new MLProcessor(MLProcessorLogger);
        _prediction = await mlProcessor.Predict(trainedModel, _modelDefinition, data);

        // State was changed outside of a UI event, so the re-render has to be requested explicitly.
        await InvokeAsync(StateHasChanged);
    }

    /// <summary>Shows <paramref name="text"/> in an error dialog and waits until it is closed.</summary>
    private async Task ShowError(string text)
    {
        var options = new DialogOptions
        {
            CloseOnEscapeKey = true
        };
        var parameters = new DialogParameters();
        parameters.Add("Text", text);

        var d = DialogService.Show<Controls.Dialog>("Error", parameters, options);
        await d.Result;
    }

    /// <summary>Stops the background prediction loop when the component is removed.</summary>
    public void Dispose()
    {
        _loopCancellation.Cancel();
        _loopCancellation.Dispose();
    }
}

View File

@ -12,7 +12,15 @@
<DialogContent> <DialogContent>
<h4>@Text</h4> <h4>@Text</h4>
<MudTextField T="string" ReadOnly="true" Text="@_progressText"></MudTextField> <MudTextField T="string" ReadOnly="true" Text="@_progressText"></MudTextField>
@if (_isTraining == false)
{
<MudText>MicroAccuracy: @_evaluationMetrics!.MicroAccuracy.ToString("N2")</MudText>
<MudText>MacroAccuracy: @_evaluationMetrics!.MacroAccuracy.ToString("N2")</MudText>
<MudText>LogLoss: @_evaluationMetrics!.LogLoss.ToString("N2")</MudText>
<MudText>LogLossReduction: @_evaluationMetrics!.LogLossReduction.ToString("N2")</MudText>
}
</DialogContent> </DialogContent>
@ -28,6 +36,7 @@
private string _progressText = ""; private string _progressText = "";
private bool _isTraining = true; private bool _isTraining = true;
private MLEvaluationMetrics? _evaluationMetrics;
void Submit() => MudDialog?.Close(DialogResult.Ok(true)); void Submit() => MudDialog?.Close(DialogResult.Ok(true));
@ -38,7 +47,7 @@
return; return;
} }
await Processor.Train(Model, UpdateProgress); _evaluationMetrics = await Processor.Train(Model, UpdateProgress);
_isTraining = false; _isTraining = false;
await InvokeAsync(StateHasChanged); await InvokeAsync(StateHasChanged);
} }

View File

@ -43,12 +43,23 @@ public class ModelDefinition
writer.AppendLine(headers); writer.AppendLine(headers);
foreach (var currentInterval in IntervalDefinitionList) foreach (var currentInterval in IntervalDefinitionList)
{
var source = currentInterval.Data;
string data = ConvertToCsv(source);
data += "," + currentInterval.Name;
writer.AppendLine(data);
}
return writer.ToString();
}
public static string ConvertToCsv(List<TimeSeriesDataSet> source)
{ {
var data = ""; var data = "";
for (var i = 0; i < currentInterval.Data.Count; i++) for (var i = 0; i < source.Count; i++)
{ {
var queryData = currentInterval.Data[i]; var queryData = source[i];
var min = queryData.Data.Min(x => x.Value); var min = queryData.Data.Min(x => x.Value);
var max = queryData.Data.Max(x => x.Value); var max = queryData.Data.Max(x => x.Value);
var avg = queryData.Data.Average(x => x.Value); var avg = queryData.Data.Average(x => x.Value);
@ -57,10 +68,7 @@ public class ModelDefinition
data += min + "," + max + "," + avg + "," + mean + ","; data += min + "," + max + "," + avg + "," + mean + ",";
} }
data += currentInterval.Name;
writer.AppendLine(data);
}
return writer.ToString(); return data+"\"ignoreMe\"";
} }
} }

View File

@ -2,9 +2,11 @@
namespace DeepTrace.Data; namespace DeepTrace.Data;
public class MyPrediction public class Prediction
{ {
//vector to hold alert,score,p-value values [ColumnName(@"PredictedLabel")]
[VectorType(3)] public string PredictedLabel { get; set; }
public double[]? Prediction { get; set; }
[ColumnName(@"Score")]
public float[] Score { get; set; }
} }

View File

@ -7,6 +7,7 @@ namespace DeepTrace.Data
{ {
[BsonId] [BsonId]
public ObjectId? Id { get; set; } public ObjectId? Id { get; set; }
public bool IsEnabled { get; set; } = false;
public string Name { get; set; } = string.Empty; public string Name { get; set; } = string.Empty;
public byte[] Value { get; set; } = Array.Empty<byte>(); //base64 public byte[] Value { get; set; } = Array.Empty<byte>(); //base64
} }

View File

@ -5,8 +5,8 @@ namespace DeepTrace.ML;
public interface IMLProcessor public interface IMLProcessor
{ {
Task Train(ModelDefinition modelDef, Action<string> log); Task<MLEvaluationMetrics> Train(ModelDefinition modelDef, Action<string> log);
byte[] Export(); byte[] Export();
void Import(byte[] data); void Import(byte[] data);
string Predict(DataSourceDefinition dataSource); Task<Prediction> Predict(TrainedModelDefinition trainedModel, ModelDefinition model, List<TimeSeriesDataSet> data);
} }

View File

@ -0,0 +1,16 @@
namespace DeepTrace.ML
{
    /// <summary>
    /// Plain container for the quality figures of a trained model.
    /// All values default to zero until assigned.
    /// </summary>
    public class MLEvaluationMetrics
    {
        /// <summary>Micro-averaged accuracy.</summary>
        public double MicroAccuracy { get; set; }

        /// <summary>Macro-averaged accuracy.</summary>
        public double MacroAccuracy { get; set; }

        /// <summary>Logarithmic loss.</summary>
        public double LogLoss { get; set; }

        /// <summary>Reduction of the logarithmic loss.</summary>
        public double LogLossReduction { get; set; }
    }
}

View File

@ -32,9 +32,14 @@ public static class MLHelpers
await File.WriteAllTextAsync(fileName, csv); await File.WriteAllTextAsync(fileName, csv);
return LoadFromCsv(mlContext, model, fileName);
}
public static (IDataView View, string FileName) LoadFromCsv(MLContext mlContext, ModelDefinition model, string fileName)
{
var columnNames = model.GetColumnNames(); var columnNames = model.GetColumnNames();
var columns = columnNames var columns = columnNames
.Select((x,i) => new TextLoader.Column(x, DataKind.String, i)) .Select((x, i) => new TextLoader.Column(x, DataKind.String, i))
.ToArray() .ToArray()
; ;

View File

@ -3,6 +3,7 @@ using Microsoft.ML;
using Microsoft.ML.Data; using Microsoft.ML.Data;
using PrometheusAPI; using PrometheusAPI;
using System.Data; using System.Data;
using static DeepTrace.MLModel1;
namespace DeepTrace.ML namespace DeepTrace.ML
{ {
@ -22,15 +23,19 @@ namespace DeepTrace.ML
private string Name { get; set; } = "TestModel"; private string Name { get; set; } = "TestModel";
public async Task Train(ModelDefinition modelDef, Action<string> log) public async Task<MLEvaluationMetrics> Train(ModelDefinition modelDef, Action<string> log)
{ {
var pipeline = _estimatorBuilder.BuildPipeline(_mlContext, modelDef); var pipeline = _estimatorBuilder.BuildPipeline(_mlContext, modelDef);
var (data, filename) = await MLHelpers.Convert(_mlContext, modelDef); var (data, filename) = await MLHelpers.Convert(_mlContext, modelDef);
DataOperationsCatalog.TrainTestData dataSplit = _mlContext.Data.TrainTestSplit(data, testFraction: 0.2);
_mlContext.Log += (_,e) => LogEvents(log, e); _mlContext.Log += (_,e) => LogEvents(log, e);
try try
{ {
_schema = data.Schema; _schema = data.Schema;
_transformer = pipeline.Fit(data); _transformer = pipeline.Fit(dataSplit.TrainSet);
return Evaluate(dataSplit.TestSet);
} }
finally finally
{ {
@ -49,6 +54,20 @@ namespace DeepTrace.ML
} }
/// <summary>
/// Runs the fitted transformer over <paramref name="testData"/> and copies the resulting
/// multiclass-classification metrics (label column "Name") into an <see cref="MLEvaluationMetrics"/>.
/// </summary>
/// <param name="testData">Data set to score and evaluate.</param>
/// <returns>The accuracy and log-loss figures of the evaluation.</returns>
private MLEvaluationMetrics Evaluate(IDataView testData)
{
    var scored = _transformer!.Transform(testData);
    var result = _mlContext.MulticlassClassification.Evaluate(scored, "Name");

    return new MLEvaluationMetrics
    {
        MicroAccuracy = result.MicroAccuracy,
        MacroAccuracy = result.MacroAccuracy,
        LogLoss = result.LogLoss,
        LogLossReduction = result.LogLossReduction,
    };
}
public byte[] Export() public byte[] Export()
{ {
if(_schema == null) if(_schema == null)
@ -89,12 +108,30 @@ namespace DeepTrace.ML
mem.Read(bytes, 0, bytes.Length); mem.Read(bytes, 0, bytes.Length);
(_mlContext, _schema, _transformer) = MLHelpers.ImportSingleModel(bytes); (_mlContext, _schema, _transformer) = MLHelpers.ImportSingleModel(bytes);
} }
public string Predict(DataSourceDefinition dataSourceDefinition) public async Task<Prediction> Predict(TrainedModelDefinition trainedModel, ModelDefinition model, List<TimeSeriesDataSet> data)
{ {
throw new NotImplementedException(); Import(trainedModel.Value);
var headers = string.Join(",", model.GetColumnNames().Select(x => $"\"{x}\""));
var row = ModelDefinition.ConvertToCsv(data);
var csv = headers+"\n"+row;
var fileName = Path.GetTempFileName();
try
{
await File.WriteAllTextAsync(fileName, csv);
var (dataView, _) = MLHelpers.LoadFromCsv(_mlContext, model, fileName);
var predictionEngine = _mlContext.Model.CreatePredictionEngine<IDataView, Prediction>(_transformer);
var prediction = predictionEngine.Predict(dataView);
return prediction;
}
finally
{
File.Delete(fileName);
}
} }
} }
} }

View File

@ -1,4 +1,9 @@
@page "/" @page "/"
@using DeepTrace.Data;
@using DeepTrace.Services;
@using DeepTrace.Controls;
@inject ITrainedModelStorageService TrainedModelService
<PageTitle>Index</PageTitle> <PageTitle>Index</PageTitle>
@ -6,4 +11,26 @@
Welcome to your new app. Welcome to your new app.
<SurveyPrompt Title="How is Blazor working for you?" /> @if (_trainedModels != null)
{
@foreach(TrainedModelDefinition model in _trainedModels)
{
<ModelCard Model="@model"/>
}
} else
{
<MudText>Nothing to display</MudText>
}
@code{
private List<TrainedModelDefinition> _trainedModels = new();
/// <summary>Loads the stored trained models that the page renders as cards.</summary>
protected override async Task OnInitializedAsync()
{
    // Call the asynchronous base member that matches this override, not the synchronous OnInitialized().
    await base.OnInitializedAsync();
    _trainedModels = await TrainedModelService.Load();
}
}

View File

@ -553,6 +553,7 @@
var trainedModel = new TrainedModelDefinition var trainedModel = new TrainedModelDefinition
{ {
Id = _modelForm!.CurrentModel.Id, Id = _modelForm!.CurrentModel.Id,
IsEnabled = false,
Name = _modelForm!.CurrentModel.Name, Name = _modelForm!.CurrentModel.Name,
Value = bytes Value = bytes
}; };

View File

@ -10,6 +10,7 @@ namespace DeepTrace.Services
{ {
Task Delete(ModelDefinition source, bool ignoreNotStored = false); Task Delete(ModelDefinition source, bool ignoreNotStored = false);
Task<List<ModelDefinition>> Load(); Task<List<ModelDefinition>> Load();
Task<ModelDefinition?> Load(BsonObjectId id);
Task Store(ModelDefinition source); Task Store(ModelDefinition source);
} }
} }

View File

@ -25,6 +25,15 @@ namespace DeepTrace.Services
var res = await (await collection.FindAsync("{}")).ToListAsync(); var res = await (await collection.FindAsync("{}")).ToListAsync();
return res; return res;
} }
/// <summary>Loads the model definition whose <c>_id</c> equals <paramref name="id"/>.</summary>
/// <param name="id">Identifier of the document to load.</param>
/// <returns>The matching definition, or <c>null</c> when no document has that id.</returns>
public async Task<ModelDefinition?> Load(BsonObjectId id)
{
    var db = _client.GetDatabase(MongoDBDatabaseName);
    var collection = db.GetCollection<ModelDefinition>(MongoDBCollection);

    // Build the filter from a BSON value instead of formatting the id into a JSON string.
    FilterDefinition<ModelDefinition> filter = new BsonDocument("_id", id);

    // Ask the server for the first match only rather than materialising the whole result list.
    return await collection.Find(filter).FirstOrDefaultAsync();
}
public async Task Store(ModelDefinition source) public async Task Store(ModelDefinition source)
{ {
var db = _client.GetDatabase(MongoDBDatabaseName); var db = _client.GetDatabase(MongoDBDatabaseName);