From 0b26c256209b74ae625187cdbbe809cbb6b41701 Mon Sep 17 00:00:00 2001 From: Andrey Shabarshov Date: Tue, 25 Jul 2023 09:22:02 +0100 Subject: [PATCH] DEEP-38 Initial work for configurable Measures --- DeepTrace/Controllers/DownloadController.cs | 4 +- DeepTrace/Data/ModelDefinition.cs | 4 +- DeepTrace/Data/TrainedModelDefinition.cs | 13 +++ DeepTrace/ML/IMeasure.cs | 11 +++ DeepTrace/ML/MLHelpers.cs | 2 +- DeepTrace/ML/MLProcessor.cs | 4 +- DeepTrace/ML/Measures.cs | 89 +++++++++++++++++++ DeepTrace/ML/MedianHelper.cs | 80 +++++++++++++++++ DeepTrace/Pages/Training.razor | 39 +++++++- DeepTrace/Program.cs | 3 +- DeepTrace/Services/IModelStorageService.cs | 2 +- .../Services/ITrainedModelStorageService.cs | 11 +++ DeepTrace/Services/ModelStorageService.cs | 6 +- .../Services/TrainedModelStorageService.cs | 58 ++++++++++++ 14 files changed, 312 insertions(+), 14 deletions(-) create mode 100644 DeepTrace/Data/TrainedModelDefinition.cs create mode 100644 DeepTrace/ML/IMeasure.cs create mode 100644 DeepTrace/ML/Measures.cs create mode 100644 DeepTrace/ML/MedianHelper.cs create mode 100644 DeepTrace/Services/ITrainedModelStorageService.cs create mode 100644 DeepTrace/Services/TrainedModelStorageService.cs diff --git a/DeepTrace/Controllers/DownloadController.cs b/DeepTrace/Controllers/DownloadController.cs index 8613d49..55b0fff 100644 --- a/DeepTrace/Controllers/DownloadController.cs +++ b/DeepTrace/Controllers/DownloadController.cs @@ -8,9 +8,9 @@ namespace DeepTrace.Controllers [Route("api/[controller]")] public class DownloadController : Controller { - private readonly IModelDefinitionService _modelService; + private readonly IModelStorageService _modelService; - public DownloadController(IModelDefinitionService modelService) + public DownloadController(IModelStorageService modelService) { _modelService = modelService; } diff --git a/DeepTrace/Data/ModelDefinition.cs b/DeepTrace/Data/ModelDefinition.cs index ac86b59..6bc3a0f 100644 --- a/DeepTrace/Data/ModelDefinition.cs +++ b/DeepTrace/Data/ModelDefinition.cs @@ -29,14 +29,14 @@ public class ModelDefinition { columnNames.AddRange(measureNames.Select(x => $"{item.Query}_{x}")); } - + columnNames.Add("Name"); return columnNames; } public string ToCsv() { var current = IntervalDefinitionList.First(); - var headers = string.Join(",", GetColumnNames().Select(x=>$"\"{x}\"")) + ",Name"; + var headers = string.Join(",", GetColumnNames().Select(x=>$"\"{x}\"")); var writer = new StringBuilder(); diff --git a/DeepTrace/Data/TrainedModelDefinition.cs b/DeepTrace/Data/TrainedModelDefinition.cs new file mode 100644 index 0000000..8dde9cb --- /dev/null +++ b/DeepTrace/Data/TrainedModelDefinition.cs @@ -0,0 +1,13 @@ +using MongoDB.Bson.Serialization.Attributes; +using MongoDB.Bson; + +namespace DeepTrace.Data +{ + public class TrainedModelDefinition + { + [BsonId] + public ObjectId? Id { get; set; } + public string Name { get; set; } + public byte[] Value { get; set; } //base64 + } +} diff --git a/DeepTrace/ML/IMeasure.cs b/DeepTrace/ML/IMeasure.cs new file mode 100644 index 0000000..ce497d9 --- /dev/null +++ b/DeepTrace/ML/IMeasure.cs @@ -0,0 +1,11 @@ +using PrometheusAPI; + +namespace DeepTrace.ML +{ + public interface IMeasure + { + public string Name { get; } + void Reset(); + float Calculate(IEnumerable data); + } +} diff --git a/DeepTrace/ML/MLHelpers.cs b/DeepTrace/ML/MLHelpers.cs index c6d6639..d9a446e 100644 --- a/DeepTrace/ML/MLHelpers.cs +++ b/DeepTrace/ML/MLHelpers.cs @@ -34,7 +34,7 @@ public static class MLHelpers var columnNames = model.GetColumnNames(); var columns = columnNames - .Select((x,i) => new TextLoader.Column(x, DataKind.Double, i)) + .Select((x,i) => new TextLoader.Column(x, DataKind.String, i)) .ToArray() ; diff --git a/DeepTrace/ML/MLProcessor.cs b/DeepTrace/ML/MLProcessor.cs index 807d2d5..48bea6a 100644 --- a/DeepTrace/ML/MLProcessor.cs +++ b/DeepTrace/ML/MLProcessor.cs @@ -13,12 +13,12 @@ namespace DeepTrace.ML private DataViewSchema? _schema; private ITransformer? _transformer; - private string Name { get; set; } + private string Name { get; set; } = "TestModel"; public async Task Train(ModelDefinition modelDef) { var pipeline = _estimatorBuilder.BuildPipeline(_mlContext, modelDef); - var (data, filename) = await MLHelpers.Convert(_mlContext,modelDef); + var (data, filename) = await MLHelpers.Convert(_mlContext, modelDef); try { _schema = data.Schema; diff --git a/DeepTrace/ML/Measures.cs b/DeepTrace/ML/Measures.cs new file mode 100644 index 0000000..61d548f --- /dev/null +++ b/DeepTrace/ML/Measures.cs @@ -0,0 +1,89 @@ +using PrometheusAPI; + +namespace DeepTrace.ML +{ + public class MeasureMin : IMeasure + { + public string Name => "Min"; + public float Calculate(IEnumerable data) => + data + .Where(x => x.Value != 0.0f) + .Min( x => x.Value ) + ; + + public void Reset() { } + } + + public class MeasureMax : IMeasure + { + public string Name => "Max"; + public float Calculate(IEnumerable data) => data.Max(x => x.Value); + public void Reset() { } + } + + public class MeasureAvg : IMeasure + { + public string Name => "Avg"; + public float Calculate(IEnumerable data) => data.Average(x => x.Value); + public void Reset() { } + } + + /// + /// WARNING: Only works with fixed length interval + /// + public class MeasureSum : IMeasure + { + public string Name => "Sum"; + public float Calculate(IEnumerable data) => data.Sum(x => x.Value); + public void Reset() { } + } + + public class MeasureMedian : IMeasure + { + public string Name => "Median"; + + public float Calculate(IEnumerable data) + => MedianHelper.Median(data, x => x.Value); + + public void Reset() { } + + } + + public class MeasureDiff : IMeasure where T : IMeasure, new() + { + private T _measure = new(); + public string Name => "Diff_"+_measure.Name; + + private float _prev = float.NaN; + + public float Calculate(IEnumerable data) + { + var val = _measure.Calculate(data); + if (float.IsNaN(_prev)) + { + _prev = val; + return 0.0f; + } + + val = val - _prev; + _prev = val; + return val; + } + + public void Reset() + { + _measure.Reset(); + _prev = float.NaN; + } + } + + public class MeasureDiffMin : MeasureDiff { } + public class MeasureDiffMax : MeasureDiff { } + public class MeasureDiffAvg : MeasureDiff { } + /// + /// WARNING: Only works with fixed length interval + /// + public class MeasureDiffSum : MeasureDiff { } + public class MeasureDiffMedian : MeasureDiff { } + +} diff --git a/DeepTrace/ML/MedianHelper.cs b/DeepTrace/ML/MedianHelper.cs new file mode 100644 index 0000000..939b240 --- /dev/null +++ b/DeepTrace/ML/MedianHelper.cs @@ -0,0 +1,80 @@ +namespace DeepTrace.ML; + +/// +/// Calculate median. +/// https://stackoverflow.com/questions/4140719/calculate-median-in-c-sharp +/// +public static class MedianHelper +{ + /// + /// Partitions the given list around a pivot element such that all elements on left of pivot are <= pivot + /// and the ones at thr right are > pivot. This method can be used for sorting, N-order statistics such as + /// as median finding algorithms. + /// Pivot is selected ranodmly if random number generator is supplied else its selected as last element in the list. + /// Reference: Introduction to Algorithms 3rd Edition, Corman et al, pp 171 + /// + private static int Partition(this IList list, int start, int end, Random rnd = null) where T : IComparable + { + if (rnd != null) + list.Swap(end, rnd.Next(start, end + 1)); + + var pivot = list[end]; + var lastLow = start - 1; + for (var i = start; i < end; i++) + { + if (list[i].CompareTo(pivot) <= 0) + list.Swap(i, ++lastLow); + } + list.Swap(end, ++lastLow); + return lastLow; + } + + /// + /// Returns Nth smallest element from the list. Here n starts from 0 so that n=0 returns minimum, n=1 returns 2nd smallest element etc. + /// Note: specified list would be mutated in the process. + /// Reference: Introduction to Algorithms 3rd Edition, Corman et al, pp 216 + /// + public static T NthOrderStatistic(this IList list, int n, Random rnd = null) where T : IComparable + { + return NthOrderStatistic(list, n, 0, list.Count - 1, rnd); + } + private static T NthOrderStatistic(this IList list, int n, int start, int end, Random rnd) where T : IComparable + { + while (true) + { + var pivotIndex = list.Partition(start, end, rnd); + if (pivotIndex == n) + return list[pivotIndex]; + + if (n < pivotIndex) + end = pivotIndex - 1; + else + start = pivotIndex + 1; + } + } + + public static void Swap(this IList list, int i, int j) + { + if (i == j) //This check is not required but Partition function may make many calls so its for perf reason + return; + var temp = list[i]; + list[i] = list[j]; + list[j] = temp; + } + + /// + /// Note: specified list would be mutated in the process. + /// + public static T Median(IList list) where T : IComparable + { + return list.NthOrderStatistic((list.Count - 1) / 2); + } + + public static TValue Median(IEnumerable sequence, Func getValue) + where TValue : IComparable + { + var list = sequence.Select(getValue).ToList(); + var mid = (list.Count - 1) / 2; + return list.NthOrderStatistic(mid); + } +} diff --git a/DeepTrace/Pages/Training.razor b/DeepTrace/Pages/Training.razor index 9660beb..25276f7 100644 --- a/DeepTrace/Pages/Training.razor +++ b/DeepTrace/Pages/Training.razor @@ -6,11 +6,13 @@ @using DeepTrace.Controls; @using Microsoft.ML; @using PrometheusAPI; +@using System.Text; @inject PrometheusClient Prometheus @inject IDialogService DialogService @inject IDataSourceStorageService StorageService -@inject IModelDefinitionService ModelService +@inject IModelStorageService ModelService +@inject ITrainedModelStorageService TrainedModelService @inject IEstimatorBuilder EstimatorBuilder @inject NavigationManager NavManager @inject IJSRuntime Js @@ -64,11 +66,23 @@ Add Refresh + + + + Import + + + Export Train + files = new List(); + if (sources.Count > 0) _dataSources = sources; if (models.Count > 0) @@ -398,6 +415,18 @@ // await InvokeAsync(StateHasChanged); } + //Doesn't work + private async Task HandleImport(IBrowserFile file) + { + var result = new StringBuilder(); + var reader = new StreamReader(file.OpenReadStream(file.Size)); + + while (reader.Peek() >= 0) + result.AppendLine(await reader.ReadLineAsync()); + + result.ToString(); + } + private async Task HandleExport() { await Js.InvokeVoidAsync("open", $"{NavManager.BaseUri}api/download/mldata/{Uri.EscapeDataString(_modelForm.CurrentModel.Name)}", "_blank"); @@ -428,8 +457,14 @@ var mlProcessor = new MLProcessor(); await mlProcessor.Train(_modelForm!.CurrentModel); var bytes = mlProcessor.Export(); - + //save to Mongo + var trainedModel = new TrainedModelDefinition + { + Name = "TrainedModel", + Value = bytes + }; + await TrainedModelService.Store(trainedModel); } diff --git a/DeepTrace/Program.cs b/DeepTrace/Program.cs index a9db77a..20031db 100644 --- a/DeepTrace/Program.cs +++ b/DeepTrace/Program.cs @@ -16,7 +16,8 @@ builder.Services.AddHttpClient(c => c.BaseAddress = new UriBui builder.Services .AddSingleton( s => new MongoClient(builder.Configuration.GetValue("Connections:MongoDb") )) .AddSingleton() - .AddSingleton() + .AddSingleton() + .AddSingleton() .AddSingleton() ; diff --git a/DeepTrace/Services/IModelStorageService.cs b/DeepTrace/Services/IModelStorageService.cs index a94e915..bfa5d7a 100644 --- a/DeepTrace/Services/IModelStorageService.cs +++ b/DeepTrace/Services/IModelStorageService.cs @@ -6,7 +6,7 @@ using System.Text; namespace DeepTrace.Services { - public interface IModelDefinitionService + public interface IModelStorageService { Task Delete(ModelDefinition source, bool ignoreNotStored = false); Task> Load(); diff --git a/DeepTrace/Services/ITrainedModelStorageService.cs b/DeepTrace/Services/ITrainedModelStorageService.cs new file mode 100644 index 0000000..a719418 --- /dev/null +++ b/DeepTrace/Services/ITrainedModelStorageService.cs @@ -0,0 +1,11 @@ +using DeepTrace.Data; + +namespace DeepTrace.Services +{ + public interface ITrainedModelStorageService + { + Task Delete(TrainedModelDefinition source, bool ignoreNotStored = false); + Task> Load(); + Task Store(TrainedModelDefinition source); + } +} diff --git a/DeepTrace/Services/ModelStorageService.cs b/DeepTrace/Services/ModelStorageService.cs index eb4512e..f63dade 100644 --- a/DeepTrace/Services/ModelStorageService.cs +++ b/DeepTrace/Services/ModelStorageService.cs @@ -4,7 +4,7 @@ using MongoDB.Driver; namespace DeepTrace.Services { - public class ModelDefinitionService : IModelDefinitionService + public class ModelStorageService : IModelStorageService { private const string MongoDBDatabaseName = "DeepTrace"; @@ -12,7 +12,7 @@ namespace DeepTrace.Services private readonly IMongoClient _client; - public ModelDefinitionService(IMongoClient client) + public ModelStorageService(IMongoClient client) { _client = client; } @@ -51,7 +51,7 @@ namespace DeepTrace.Services } var db = _client.GetDatabase(MongoDBDatabaseName); - var collection = db.GetCollection(MongoDBCollection); + var collection = db.GetCollection(MongoDBCollection); await collection.DeleteOneAsync(filter: new BsonDocument("_id", source.Id)); } diff --git a/DeepTrace/Services/TrainedModelStorageService.cs b/DeepTrace/Services/TrainedModelStorageService.cs new file mode 100644 index 0000000..37ece57 --- /dev/null +++ b/DeepTrace/Services/TrainedModelStorageService.cs @@ -0,0 +1,58 @@ +using DeepTrace.Data; +using MongoDB.Bson; +using MongoDB.Driver; + +namespace DeepTrace.Services +{ + public class TrainedModelStorageService: ITrainedModelStorageService + { + private const string MongoDBDatabaseName = "DeepTrace"; + private const string MongoDBCollection = "TrainedModels"; + + private readonly IMongoClient _client; + + public TrainedModelStorageService(IMongoClient client) + { + _client = client; + } + + public async Task> Load() + { + var db = _client.GetDatabase(MongoDBDatabaseName); + var collection = db.GetCollection(MongoDBCollection); + + var res = await (await collection.FindAsync("{}")).ToListAsync(); + return res; + } + public async Task Store(TrainedModelDefinition source) + { + var db = _client.GetDatabase(MongoDBDatabaseName); + var collection = db.GetCollection(MongoDBCollection); + + if (source.Id == null) + source.Id = ObjectId.GenerateNewId(); + + // use upsert (insert or update) to automatically handle subsequent updates + await collection.ReplaceOneAsync( + filter: new BsonDocument("_id", source.Id), + options: new ReplaceOptions { IsUpsert = true }, + replacement: source + ); + } + + public async Task Delete(TrainedModelDefinition source, bool ignoreNotStored = false) + { + if (source.Id == null) + { + if (!ignoreNotStored) + throw new InvalidDataException("Source was not stored yet. There is nothing to delete"); + return; + } + + var db = _client.GetDatabase(MongoDBDatabaseName); + var collection = db.GetCollection(MongoDBCollection); + + await collection.DeleteOneAsync(filter: new BsonDocument("_id", source.Id)); + } + } +}