黑马程序员技术交流社区

标题: 【上海校区】scala实现K邻近算法 [打印本页]

作者: 梦缠绕的时候    时间: 2019-1-25 09:55
标题: 【上海校区】scala实现K邻近算法
算法流程
1.计算中的set中每一个点与Xt的距离。
2.按距离增序排。
3.选择距离最小的前k个点。
4.确定前k个点所在的label的出现频率。
5.返回频率最高的label作为测试的结果。

import scala.collection.mutable.Map

object kNN {

  def getGroup(): Array[Array[Double]] = {
    // 二维坐标中的点
    return Array(Array(1.0, 1.1), Array(1.0, 1.0), Array(0, 0), Array(0, 0.1))
  }
  def getLabels(): Array[Char] = {
    // 点所对应的label
    return Array('A', 'A', 'B', 'B')
  }

  def classify0(inX: Array[Double], dataSet: Array[Array[Double]], labels: Array[Char], k: Int): Char = {
    val dataSetSize = dataSet.length
    val sortedDisIndicies = dataSet.map { x =>
      val v1 = x(0) - inX(0)
      val v2 = x(1) - inX(1)
      // 求目标点到所有点的距离
      v1 * v1 + v2 * v2
    }
      .zipWithIndex // 将距离加上索引值结果(2.21,0)(2.0,1)....
      .sortBy(f => f._1)
      .map(f => f._2)   // 将排好序的index取出
    var classsCount: Map[Char, Int] = Map.empty
    for (i <- 0 to k - 1) {
      // 取出k个本地标签库
      val voteIlabel = labels(sortedDisIndicies(i))
      classsCount(voteIlabel) = classsCount.getOrElse(voteIlabel, 0) + 1

    }
    classsCount.toArray.sortBy(f => -f._2).head._1
  }
  def main(args: Array[String]) {
    println(classify0(Array(2, 0), getGroup(), getLabels(), 3))
  }
}



作者: 不二晨    时间: 2019-2-14 14:48
奈斯




欢迎光临 黑马程序员技术交流社区 (http://bbs.itheima.com/) 黑马程序员IT技术论坛 X3.2