1. 准备训练数据和测试数据
2. 构建学习管道（加载数据、特征处理、选择回归算法）
3. 训练并评估模型
4. 预测
实现:
TaxiFarePrediction.cs:
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Models;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
namespace _01_TaxiFare
{
    /// <summary>
    /// Trains, evaluates and queries a taxi-fare regression model built with the
    /// ML.NET <c>LearningPipeline</c> API.
    /// </summary>
    public class TaxiFarePrediction
    {
        // Training data, resolved relative to the current working directory.
        static readonly string _datapath = Path.Combine(Environment.CurrentDirectory, "taxi-fare-train.csv");
        // Held-out data, used only by Evaluate.
        static readonly string _testdatapath = Path.Combine(Environment.CurrentDirectory, "taxi-fare-test.csv");
        // Where Train saves the trained model. NOTE(review): nothing in this file reads it back,
        // so every Predict call retrains from scratch.
        static readonly string _modelpath = Path.Combine(Environment.CurrentDirectory, "Model.zip");

        /// <summary>
        /// Trains a fresh model, prints its metrics on the test set, then predicts the fare
        /// for <paramref name="tt"/>.
        /// </summary>
        /// <param name="tt">The trip to price. Its FareAmount is not one of the model's features.</param>
        /// <returns>The prediction; its FareAmount holds the regressor's score.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="tt"/> is null.</exception>
        public static async Task<TaxiTripFarePrediction> Predict(TaxiTrip tt)
        {
            // Validate before doing any work: training, saving Model.zip and evaluating are
            // expensive, so a missing trip should fail immediately rather than afterwards.
            if (tt == null)
            {
                throw new ArgumentNullException(nameof(tt));
            }

            var model = await Train();
            Evaluate(model);
            return model.Predict(tt);
        }

        /// <summary>
        /// Builds the learning pipeline over the training CSV, trains a FastTree regressor,
        /// and writes the trained model to <see cref="_modelpath"/>.
        /// </summary>
        /// <returns>The trained prediction model.</returns>
        private static async Task<PredictionModel<TaxiTrip, TaxiTripFarePrediction>> Train()
        {
            var pipeline = new LearningPipeline
            {
                // CSV with a header row; fields are bound by the [Column] ordinals on TaxiTrip.
                new TextLoader(_datapath).CreateFrom<TaxiTrip>(useHeader: true, separator: ','),
                // Copy the target value into a column named "Label" for the trainer.
                new ColumnCopier(("FareAmount", "Label")),
                // One-hot encode the string-valued categorical columns.
                new CategoricalOneHotVectorizer(
                    "VendorId",
                    "RateCode",
                    "PaymentType"),
                // Assemble the feature vector. TripTime and FareAmount are not included.
                new ColumnConcatenator(
                    "Features",
                    "VendorId",
                    "RateCode",
                    "PassengerCount",
                    "TripDistance",
                    "PaymentType"),
                new FastTreeRegressor()
            };

            PredictionModel<TaxiTrip, TaxiTripFarePrediction> model = pipeline.Train<TaxiTrip, TaxiTripFarePrediction>();
            await model.WriteAsync(_modelpath);
            return model;
        }

        /// <summary>
        /// Scores <paramref name="model"/> against the test CSV and prints the RMS error
        /// and R-squared to the console.
        /// </summary>
        /// <param name="model">The trained model to evaluate.</param>
        private static void Evaluate(PredictionModel<TaxiTrip, TaxiTripFarePrediction> model)
        {
            var testData = new TextLoader(_testdatapath).CreateFrom<TaxiTrip>(useHeader: true, separator: ',');
            var evaluator = new RegressionEvaluator();
            RegressionMetrics metrics = evaluator.Evaluate(model, testData);
            Console.WriteLine($"Rms = {metrics.Rms}");
            Console.WriteLine($"RSquared = {metrics.RSquared}");
        }
    }
}
TaxiTrip.cs:
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.ML.Runtime.Api;
namespace _01_TaxiFare
{
    /// <summary>
    /// One row of the taxi-fare CSV files. Every field is bound to a CSV column by the
    /// ordinal given in its <c>[Column]</c> attribute ("0" through "6").
    /// </summary>
    public class TaxiTrip
    {
        [Column("0")] public string VendorId;
        [Column("1")] public string RateCode;
        [Column("2")] public float PassengerCount;
        [Column("3")] public float TripTime;
        [Column("4")] public float TripDistance;
        [Column("5")] public string PaymentType;

        // The fare actually charged: the value the model is trained to predict.
        [Column("6")] public float FareAmount;
    }

    /// <summary>
    /// Model output: the regressor's "Score" column, surfaced as <see cref="FareAmount"/>.
    /// </summary>
    public class TaxiTripFarePrediction
    {
        [ColumnName("Score")] public float FareAmount;
    }
}
调用示例（Program 类）:
using System;
namespace _01_TaxiFare
{
    /// <summary>
    /// Console entry point: prices one sample trip and prints the prediction next to the known fare.
    /// </summary>
    class Program
    {
        static void Main(string[] args)
        {
            // Sample trip whose real fare was 29.5. TripTime is left at its default value.
            var trip = new TaxiTrip
            {
                VendorId = "VTS",
                RateCode = "1",
                PassengerCount = 1,
                TripDistance = 10.33f,
                PaymentType = "CSH",
                FareAmount = 0 // predict it. actual = 29.5
            };

            // Main is synchronous, so it has to block on the task. GetAwaiter().GetResult()
            // rethrows the original exception, whereas .Result wraps it in an AggregateException.
            var prediction = TaxiFarePrediction.Predict(trip).GetAwaiter().GetResult();

            Console.WriteLine("Predicted fare: {0}, actual fare: 29.5", prediction.FareAmount);
            Console.ReadLine();
        }
    }
}
---------------------
作者:_iorilan
来源:CSDN
原文:https://blog.csdn.net/lan_liang/article/details/84680443
版权声明:本文为博主原创文章,转载请附上博文链接!
|
|