diff --git a/DeepTrace/ML/IMLProcessor.cs b/DeepTrace/ML/IMLProcessor.cs new file mode 100644 index 0000000..8d7ff5a --- /dev/null +++ b/DeepTrace/ML/IMLProcessor.cs @@ -0,0 +1,12 @@ +using DeepTrace.Data; +using PrometheusAPI; + +namespace DeepTrace.ML; + +public interface IMLProcessor +{ + void Fit(ModelDefinition modelDef, DataSourceDefinition dataSourceDef); + byte[] Export(); + void Import(byte[] data); + string Predict(TimeSeries[] data); +} diff --git a/DeepTrace/ML/SpikeDetector.cs b/DeepTrace/ML/SpikeDetector.cs new file mode 100644 index 0000000..25314c4 --- /dev/null +++ b/DeepTrace/ML/SpikeDetector.cs @@ -0,0 +1,102 @@ +using DeepTrace.Data; +using Microsoft.ML; +using Microsoft.ML.Data; +using PrometheusAPI; +using System.Data; +using System.Linq; + +namespace DeepTrace.ML +{ + public class SpikeDetector : IMLProcessor + { + private readonly Dictionary _model = new(); + + public void Fit(ModelDefinition modelDef, DataSourceDefinition dataSourceDef) + { + var models = dataSourceDef.Queries + .Select( (x,i) => + { + // since we are just detecting spikes here we can combine all the time series into one + + List data = modelDef.IntervalDefinitionList[i].Data + .Select(y => y.Data) + .Aggregate>((acc, list) => acc.Concat(list)) + .ToList(); + + return (Name: x.Query, Data: data); + }) + .ToList(); + + foreach (var (name, data) in models) + { + _model[name] = FitOne(data); + } + } + + private static string _signature = "DeepTrace-Model-v1-"+typeof(SpikeDetector).Name; + + public byte[] Export() + { + using var mem = new MemoryStream(); + mem.WriteString(_signature); + mem.WriteInt(_model.Count); + + foreach ( var (name, model) in _model) + { + mem.WriteString(name); + model.Context.Model.Save(model.Transformer, model.Schema, mem); + } + + return mem.ToArray(); + } + + public void Import(byte[] data) + { + var mem = new MemoryStream(data); + var sig = mem.ReadString(); + if (sig != _signature) + throw new ApplicationException($"Wrong data for {GetType().Name}"); + + var count = mem.ReadInt(); + + for ( var i = 0; i < count; i++ ) + { + var name = mem.ReadString(); + + var mlContext = new MLContext(); + var transformer = mlContext.Model.Load(mem, out var schema); + + _model[name] = (mlContext, schema, transformer); + } + } + + public string Predict(TimeSeries[] data) + { + throw new NotImplementedException(); + } + + // -------------------------- internals + + + + class SpikePrediction + { + [VectorType(3)] + public double[] Prediction { get; set; } = new double[3]; + } + + private static (MLContext Context, DataViewSchema Schema, ITransformer Transformer) FitOne(List dataSet) + { + var mlContext = new MLContext(); + var dataView = mlContext.Data.LoadFromEnumerable(dataSet); + + const string outputColumnName = nameof(SpikePrediction.Prediction); + const string inputColumnName = nameof(TimeSeries.Value); + + var iidSpikeEstimator = mlContext.Transforms.DetectIidSpike(outputColumnName,inputColumnName, 95.0d, dataSet.Count); + var transformer = iidSpikeEstimator.Fit(dataView); + + return (mlContext, dataView.Schema, transformer); + } + } +} diff --git a/DeepTrace/ML/StreamUtils.cs b/DeepTrace/ML/StreamUtils.cs new file mode 100644 index 0000000..34085c8 --- /dev/null +++ b/DeepTrace/ML/StreamUtils.cs @@ -0,0 +1,39 @@ +using System.Text; + +namespace DeepTrace.ML; + +public static class StreamUtils +{ + private static readonly int SizeOfInt = BitConverter.GetBytes(1).Length; + + public static Stream WriteInt(this Stream stream, int value) + { + stream.Write(BitConverter.GetBytes(value)); + return stream; + } + + public static int ReadInt(this Stream stream) + { + var buffer= new byte[SizeOfInt]; + stream.Read(buffer, 0, SizeOfInt); + return BitConverter.ToInt32(buffer); + } + + public static Stream WriteString(this Stream stream, string value) + { + var utf8 = Encoding.UTF8.GetBytes(value); + stream.WriteInt(utf8.Length); + stream.Write(utf8); + + return stream; + } + + public static string ReadString(this Stream stream) + { + var len = stream.ReadInt(); + var utf8 = new byte[len]; + stream.Read(utf8, 0, len); + + return Encoding.UTF8.GetString(utf8); + } +}