A股上市公司传智教育(股票代码 003032)旗下技术交流社区北京昌平校区

 找回密码
 加入黑马

QQ登录

只需一步,快速开始

1. 预备测试数据
2. 加载模型
3. 训练
4. 预测

实现:

TaxiFarePrediction.cs:
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Models;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;

namespace _01_TaxiFare
{
    public class TaxiFarePrediction
    {
        static readonly string _datapath = Path.Combine(Environment.CurrentDirectory,  "taxi-fare-train.csv");
        static readonly string _testdatapath = Path.Combine(Environment.CurrentDirectory,  "taxi-fare-test.csv");
        static readonly string _modelpath = Path.Combine(Environment.CurrentDirectory, "Model.zip");

        public static async Task<TaxiTripFarePrediction> Predict(TaxiTrip tt)
        {
            var model = await Train();
            Evaluate(model);

            return model.Predict(tt);
        }

        private static async Task<PredictionModel<TaxiTrip, TaxiTripFarePrediction>> Train()
        {

            var pipeline = new LearningPipeline
            {
                new TextLoader(_datapath).CreateFrom<TaxiTrip>(useHeader: true, separator: ','),
                new ColumnCopier(("FareAmount", "Label")),
                new CategoricalOneHotVectorizer(
                    "VendorId",
                    "RateCode",
                    "PaymentType"),
                new ColumnConcatenator(
                    "Features",
                    "VendorId",
                    "RateCode",
                    "PassengerCount",
                    "TripDistance",
                    "PaymentType"),
                new FastTreeRegressor()
            };
            PredictionModel<TaxiTrip, TaxiTripFarePrediction> model = pipeline.Train<TaxiTrip, TaxiTripFarePrediction>();
            await model.WriteAsync(_modelpath);
            return model;
        }
        private static void Evaluate(PredictionModel<TaxiTrip, TaxiTripFarePrediction> model)
        {
            var testData = new TextLoader(_testdatapath).CreateFrom<TaxiTrip>(useHeader: true, separator: ',');
            var evaluator = new RegressionEvaluator();
            RegressionMetrics metrics = evaluator.Evaluate(model, testData);
            Console.WriteLine($"Rms = {metrics.Rms}");
            Console.WriteLine($"RSquared = {metrics.RSquared}");
        }
    }
}
TaxiTrip.cs:

using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.ML.Runtime.Api;

namespace _01_TaxiFare
{
    public class TaxiTrip
    {
        [Column("0")]
        public string VendorId;

        [Column("1")]
        public string RateCode;

        [Column("2")]
        public float PassengerCount;

        [Column("3")]
        public float TripTime;

        [Column("4")]
        public float TripDistance;

        [Column("5")]
        public string PaymentType;

        [Column("6")]
        public float FareAmount;
    }

    public class TaxiTripFarePrediction
    {
        [ColumnName("Score")]
        public float FareAmount;
    }
}
调用:

using System;

namespace _01_TaxiFare
{
    class Program
    {
        static void Main(string[] args)
        {
            var prediction = TaxiFarePrediction.Predict(new TaxiTrip
            {
                VendorId = "VTS",
                RateCode = "1",
                PassengerCount = 1,
                TripDistance = 10.33f,
                PaymentType = "CSH",
                FareAmount = 0 // predict it. actual = 29.5
            }).Result;

            Console.WriteLine("Predicted fare: {0}, actual fare: 29.5", prediction.FareAmount);

            Console.ReadLine();
        }
    }
}


---------------------
作者:_iorilan
来源:CSDN
原文:https://blog.csdn.net/lan_liang/article/details/84680443
版权声明:本文为博主原创文章,转载请附上博文链接!

1 个回复

倒序浏览
回复 使用道具 举报
您需要登录后才可以回帖 登录 | 加入黑马