DEEP-13, DEEP-14 Training dialog implemented. Dashboard page UI and functionality added

This commit is contained in:
Andrey Shabarshov 2023-07-28 17:11:58 +01:00
parent facecd5ed7
commit af6c75a5ac
13 changed files with 315 additions and 30 deletions

View File

@ -0,0 +1,169 @@
@using DeepTrace.Data;
@using DeepTrace.ML;
@using DeepTrace.Services;
@using PrometheusAPI;
@inject PrometheusClient Prometheus
@inject IDialogService DialogService
@inject IModelStorageService ModelService
@inject ITrainedModelStorageService TrainedModelService
@inject ILogger<MLProcessor> MLProcessorLogger
@inject ILogger<ModelCard> Logger
<MudCard Class="mb-3">
<MudCardHeader>
<CardHeaderContent>
<MudText Typo="Typo.h6">@Model?.Name</MudText>
</CardHeaderContent>
<CardHeaderActions>
<MudSwitch @bind-Checked="IsEnabled"> @Model?.IsEnabled</MudSwitch>
</CardHeaderActions>
</MudCardHeader>
<MudCardContent>
<MudText>Current state: @_prediction.PredictedLabel</MudText>
</MudCardContent>
</MudCard>
@code{
[Parameter]
public TrainedModelDefinition? Model { get; set; }
private ModelDefinition _modelDefinition = new();
private Prediction _prediction = new();
protected override async Task OnAfterRenderAsync(bool firstRender)
{
if (!firstRender || Model?.Id == null)
{
return;
}
_modelDefinition = (await ModelService.Load(Model.Id)) ?? _modelDefinition;
#pragma warning disable CS4014
Task.Run(PredictionLoop);
#pragma warning restore CS4014
}
private bool IsEnabled
{
get => Model?.IsEnabled ?? false;
set
{
if (Model==null || Model.IsEnabled == value)
{
return;
}
Model.IsEnabled = value;
InvokeAsync(SaveIsEnabled);
}
}
private async Task SaveIsEnabled()
{
if(Model == null)
{
return;
}
var trainedModel = new TrainedModelDefinition
{
Id = Model.Id,
IsEnabled = Model.IsEnabled,
Name = Model.Name,
Value = Model.Value
};
await TrainedModelService.Store(trainedModel);
}
private async Task PredictionLoop()
{
var startDate = DateTime.UtcNow;
while (true)
{
try
{
await Task.Delay(TimeSpan.FromSeconds(5));
var endDate = DateTime.UtcNow;
await PredictAnomaly(startDate, endDate);
startDate = endDate;
}
catch(Exception)
{
//ignore
}
}
}
private async Task PredictAnomaly(DateTime startDate, DateTime endDate)
{
// use automatic step value to always request 500 elements
var seconds = (endDate - startDate).TotalSeconds / 500.0;
if (seconds < 1.0)
seconds = 1.0;
var step = TimeSpan.FromSeconds(seconds);
var tasks = _modelDefinition!.DataSource.Queries
.Select(x => Prometheus.RangeQuery(x.Query, startDate, endDate, step, TimeSpan.FromSeconds(2)))
.ToArray();
try
{
await Task.WhenAll(tasks);
}
catch (Exception e)
{
await ShowError(e.Message);
return;
}
var data = new List<TimeSeriesDataSet>();
foreach (var (res, def) in tasks.Select((x, i) => (x.Result, _modelDefinition.DataSource.Queries[i])))
{
if (res.Status != StatusType.Success)
{
Logger.LogError(res.Error ?? "Error");
return;
}
if (res.ResultType != ResultTypeType.Matrix)
{
Logger.LogError($"Got {res.ResultType}, but Matrix expected for {def.Query}");
return;
}
var m = res.AsMatrix().Result;
if (m == null || m.Length != 1)
{
Logger.LogError($"No data returned for {def.Query}");
return;
}
data.Add(
new()
{
Name = def.Query,
Color = def.Color,
Data = m[0].Values!.ToList()
}
);
}
var mlProcessor = new MLProcessor(MLProcessorLogger);
_prediction = await mlProcessor.Predict(Model, _modelDefinition, data);
}
private async Task ShowError(string text)
{
var options = new DialogOptions
{
CloseOnEscapeKey = true
};
var parameters = new DialogParameters();
parameters.Add("Text", text);
var d = DialogService.Show<Controls.Dialog>("Error", parameters, options);
await d.Result;
}
}

View File

@ -12,7 +12,15 @@
<DialogContent>
<h4>@Text</h4>
<MudTextField T="string" ReadOnly="true" Text="@_progressText"></MudTextField>
@if (_isTraining == false)
{
<MudText>MicroAccuracy: @_evaluationMetrics!.MicroAccuracy.ToString("N2")</MudText>
<MudText>MacroAccuracy: @_evaluationMetrics!.MacroAccuracy.ToString("N2")</MudText>
<MudText>LogLoss: @_evaluationMetrics!.LogLoss.ToString("N2")</MudText>
<MudText>LogLossReduction: @_evaluationMetrics!.LogLossReduction.ToString("N2")</MudText>
}
</DialogContent>
@ -28,6 +36,7 @@
private string _progressText = "";
private bool _isTraining = true;
private MLEvaluationMetrics? _evaluationMetrics;
void Submit() => MudDialog?.Close(DialogResult.Ok(true));
@ -38,7 +47,7 @@
return;
}
await Processor.Train(Model, UpdateProgress);
_evaluationMetrics = await Processor.Train(Model, UpdateProgress);
_isTraining = false;
await InvokeAsync(StateHasChanged);
}

View File

@ -44,23 +44,31 @@ public class ModelDefinition
foreach (var currentInterval in IntervalDefinitionList)
{
var data = "";
for (var i = 0; i < currentInterval.Data.Count; i++)
{
var queryData = currentInterval.Data[i];
var min = queryData.Data.Min(x => x.Value);
var max = queryData.Data.Max(x => x.Value);
var avg = queryData.Data.Average(x => x.Value);
var mean = queryData.Data.Sum(x => x.Value) / queryData.Data.Count;
data += min + "," + max + "," + avg + "," + mean + ",";
}
data += currentInterval.Name;
var source = currentInterval.Data;
string data = ConvertToCsv(source);
data += "," + currentInterval.Name;
writer.AppendLine(data);
}
return writer.ToString();
}
public static string ConvertToCsv(List<TimeSeriesDataSet> source)
{
var data = "";
for (var i = 0; i < source.Count; i++)
{
var queryData = source[i];
var min = queryData.Data.Min(x => x.Value);
var max = queryData.Data.Max(x => x.Value);
var avg = queryData.Data.Average(x => x.Value);
var mean = queryData.Data.Sum(x => x.Value) / queryData.Data.Count;
data += min + "," + max + "," + avg + "," + mean + ",";
}
return data+"\"ignoreMe\"";
}
}

View File

@ -2,9 +2,11 @@
namespace DeepTrace.Data;
public class MyPrediction
public class Prediction
{
//vector to hold alert,score,p-value values
[VectorType(3)]
public double[]? Prediction { get; set; }
[ColumnName(@"PredictedLabel")]
public string PredictedLabel { get; set; }
[ColumnName(@"Score")]
public float[] Score { get; set; }
}

View File

@ -7,6 +7,7 @@ namespace DeepTrace.Data
{
[BsonId]
public ObjectId? Id { get; set; }
public bool IsEnabled { get; set; } = false;
public string Name { get; set; } = string.Empty;
public byte[] Value { get; set; } = Array.Empty<byte>(); //base64
}

View File

@ -5,8 +5,8 @@ namespace DeepTrace.ML;
public interface IMLProcessor
{
Task Train(ModelDefinition modelDef, Action<string> log);
Task<MLEvaluationMetrics> Train(ModelDefinition modelDef, Action<string> log);
byte[] Export();
void Import(byte[] data);
string Predict(DataSourceDefinition dataSource);
Task<Prediction> Predict(TrainedModelDefinition trainedModel, ModelDefinition model, List<TimeSeriesDataSet> data);
}

View File

@ -0,0 +1,16 @@
namespace DeepTrace.ML
{
public class MLEvaluationMetrics
{
public MLEvaluationMetrics()
{
}
public double MicroAccuracy { get; set; }
public double MacroAccuracy { get; set; }
public double LogLoss { get; set; }
public double LogLossReduction { get; set; }
}
}

View File

@ -32,9 +32,14 @@ public static class MLHelpers
await File.WriteAllTextAsync(fileName, csv);
return LoadFromCsv(mlContext, model, fileName);
}
public static (IDataView View, string FileName) LoadFromCsv(MLContext mlContext, ModelDefinition model, string fileName)
{
var columnNames = model.GetColumnNames();
var columns = columnNames
.Select((x,i) => new TextLoader.Column(x, DataKind.String, i))
var columns = columnNames
.Select((x, i) => new TextLoader.Column(x, DataKind.String, i))
.ToArray()
;

View File

@ -3,6 +3,7 @@ using Microsoft.ML;
using Microsoft.ML.Data;
using PrometheusAPI;
using System.Data;
using static DeepTrace.MLModel1;
namespace DeepTrace.ML
{
@ -22,15 +23,19 @@ namespace DeepTrace.ML
private string Name { get; set; } = "TestModel";
public async Task Train(ModelDefinition modelDef, Action<string> log)
public async Task<MLEvaluationMetrics> Train(ModelDefinition modelDef, Action<string> log)
{
var pipeline = _estimatorBuilder.BuildPipeline(_mlContext, modelDef);
var (data, filename) = await MLHelpers.Convert(_mlContext, modelDef);
DataOperationsCatalog.TrainTestData dataSplit = _mlContext.Data.TrainTestSplit(data, testFraction: 0.2);
_mlContext.Log += (_,e) => LogEvents(log, e);
try
{
_schema = data.Schema;
_transformer = pipeline.Fit(data);
_transformer = pipeline.Fit(dataSplit.TrainSet);
return Evaluate(dataSplit.TestSet);
}
finally
{
@ -49,6 +54,20 @@ namespace DeepTrace.ML
}
private MLEvaluationMetrics Evaluate(IDataView testData)
{
var predictions = _transformer!.Transform(testData);
var metrics = _mlContext.MulticlassClassification.Evaluate(predictions, "Name");
var evaluationMetrics = new MLEvaluationMetrics()
{
MicroAccuracy = metrics.MicroAccuracy,
MacroAccuracy = metrics.MacroAccuracy,
LogLoss = metrics.LogLoss,
LogLossReduction = metrics.LogLossReduction,
};
return evaluationMetrics;
}
public byte[] Export()
{
if(_schema == null)
@ -89,12 +108,30 @@ namespace DeepTrace.ML
mem.Read(bytes, 0, bytes.Length);
(_mlContext, _schema, _transformer) = MLHelpers.ImportSingleModel(bytes);
}
public string Predict(DataSourceDefinition dataSourceDefinition)
public async Task<Prediction> Predict(TrainedModelDefinition trainedModel, ModelDefinition model, List<TimeSeriesDataSet> data)
{
throw new NotImplementedException();
Import(trainedModel.Value);
var headers = string.Join(",", model.GetColumnNames().Select(x => $"\"{x}\""));
var row = ModelDefinition.ConvertToCsv(data);
var csv = headers+"\n"+row;
var fileName = Path.GetTempFileName();
try
{
await File.WriteAllTextAsync(fileName, csv);
var (dataView, _) = MLHelpers.LoadFromCsv(_mlContext, model, fileName);
var predictionEngine = _mlContext.Model.CreatePredictionEngine<IDataView, Prediction>(_transformer);
var prediction = predictionEngine.Predict(dataView);
return prediction;
}
finally
{
File.Delete(fileName);
}
}
}
}

View File

@ -1,4 +1,9 @@
@page "/"
@using DeepTrace.Data;
@using DeepTrace.Services;
@using DeepTrace.Controls;
@inject ITrainedModelStorageService TrainedModelService
<PageTitle>Index</PageTitle>
@ -6,4 +11,26 @@
Welcome to your new app.
<SurveyPrompt Title="How is Blazor working for you?" />
@if (_trainedModels != null)
{
@foreach(TrainedModelDefinition model in _trainedModels)
{
<ModelCard Model="@model"/>
}
} else
{
<MudText>Nothing to display</MudText>
}
@code{
private List<TrainedModelDefinition> _trainedModels = new();
protected override async Task OnInitializedAsync()
{
base.OnInitialized();
_trainedModels = await TrainedModelService.Load();
}
}

View File

@ -553,6 +553,7 @@
var trainedModel = new TrainedModelDefinition
{
Id = _modelForm!.CurrentModel.Id,
IsEnabled = false,
Name = _modelForm!.CurrentModel.Name,
Value = bytes
};

View File

@ -10,6 +10,7 @@ namespace DeepTrace.Services
{
Task Delete(ModelDefinition source, bool ignoreNotStored = false);
Task<List<ModelDefinition>> Load();
Task<ModelDefinition?> Load(BsonObjectId id);
Task Store(ModelDefinition source);
}
}

View File

@ -25,6 +25,15 @@ namespace DeepTrace.Services
var res = await (await collection.FindAsync("{}")).ToListAsync();
return res;
}
public async Task<ModelDefinition?> Load(BsonObjectId id)
{
var db = _client.GetDatabase(MongoDBDatabaseName);
var collection = db.GetCollection<ModelDefinition>(MongoDBCollection);
var res = (await (await collection.FindAsync($"{{_id:ObjectId(\"{id}\")}}")).ToListAsync()).FirstOrDefault();
return res;
}
public async Task Store(ModelDefinition source)
{
var db = _client.GetDatabase(MongoDBDatabaseName);