A股上市公司传智教育(股票代码 003032)旗下技术交流社区北京昌平校区

 找回密码
 加入黑马

QQ登录

只需一步,快速开始

算法流程
1.计算中的set中每一个点与Xt的距离。
2.按距离增序排。
3.选择距离最小的前k个点。
4.确定前k个点所在的label的出现频率。
5.返回频率最高的label作为测试的结果。

import scala.collection.mutable.Map

object kNN {

  def getGroup(): Array[Array[Double]] = {
    // 二维坐标中的点
    return Array(Array(1.0, 1.1), Array(1.0, 1.0), Array(0, 0), Array(0, 0.1))
  }
  def getLabels(): Array[Char] = {
    // 点所对应的label
    return Array('A', 'A', 'B', 'B')
  }

  def classify0(inX: Array[Double], dataSet: Array[Array[Double]], labels: Array[Char], k: Int): Char = {
    val dataSetSize = dataSet.length
    val sortedDisIndicies = dataSet.map { x =>
      val v1 = x(0) - inX(0)
      val v2 = x(1) - inX(1)
      // 求目标点到所有点的距离
      v1 * v1 + v2 * v2
    }
      .zipWithIndex // 将距离加上索引值结果(2.21,0)(2.0,1)....
      .sortBy(f => f._1)
      .map(f => f._2)   // 将排好序的index取出
    var classsCount: Map[Char, Int] = Map.empty
    for (i <- 0 to k - 1) {
      // 取出k个本地标签库
      val voteIlabel = labels(sortedDisIndicies(i))
      classsCount(voteIlabel) = classsCount.getOrElse(voteIlabel, 0) + 1

    }
    classsCount.toArray.sortBy(f => -f._2).head._1
  }
  def main(args: Array[String]) {
    println(classify0(Array(2, 0), getGroup(), getLabels(), 3))
  }
}


1 个回复

倒序浏览
奈斯
回复 使用道具 举报
您需要登录后才可以回帖 登录 | 加入黑马