Technical community of 传智教育 (Chuanzhi Education, an A-share listed company, stock code 003032), Beijing Changping campus

Last edited by 我是楠楠 on 2020-4-28 15:26

[Zhengzhou Campus] A Recommendation System Based on Mahout (Part 2)

3. Code Implementation
3.1 Data Preparation
Create the item table:
CREATE TABLE `tb_item` (
  `pid` bigint(11) NOT NULL AUTO_INCREMENT,
  `name` varchar(2000) CHARACTER SET latin1 DEFAULT NULL,
  `types` varchar(2000) CHARACTER SET latin1 DEFAULT NULL,
  PRIMARY KEY (`pid`)
) ENGINE=InnoDB AUTO_INCREMENT=65134 DEFAULT CHARSET=utf8;
Create the user preference table:
CREATE TABLE `user_pianhao_data1` (
  `id` bigint(11) NOT NULL AUTO_INCREMENT,
  `uid` bigint(11) DEFAULT NULL,
  `pid` bigint(11) DEFAULT NULL,
  `val` bigint(11) DEFAULT NULL,
  `time` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP,
  PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=1001 DEFAULT CHARSET=utf8;
The tables hold a fairly large amount of test data, which can be downloaded from Baidu Netdisk to initialize them. Link: https://pan.baidu.com/s/1VSgD2uiJ69zDsw-KCtfEUQ
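Whatever its backing store, a DataModel gives the recommender one view of `user_pianhao_data1`: for each user, the items they rated and the preference value. The JDK-only sketch below builds that view from (uid, pid, val) rows that are assumed to be already fetched; `PreferenceTable` is a hypothetical name, not part of Mahout.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper: groups (uid, pid, val) rows from user_pianhao_data1
// into userId -> (itemId -> preference), the view a DataModel exposes.
class PreferenceTable {
    static Map<Long, Map<Long, Double>> group(long[][] rows) {
        Map<Long, Map<Long, Double>> byUser = new HashMap<>();
        for (long[] row : rows) {
            long uid = row[0];
            long pid = row[1];
            double val = row[2];
            // A later row for the same (uid, pid) overwrites the earlier one.
            byUser.computeIfAbsent(uid, k -> new HashMap<>()).put(pid, val);
        }
        return byUser;
    }
}
```

For example, rows `{1, 122, 5}, {1, 185, 4}, {2, 122, 3}` yield two users, with user 1 holding preferences for items 122 and 185.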
3.2 Set Up the Spring Boot Project


3.3 Write the application.properties Configuration File
#DB Configuration:
spring.datasource.driverClassName=com.mysql.jdbc.Driver
spring.datasource.url=jdbc:mysql://127.0.0.1:3306/recommend?useUnicode=true&characterEncoding=utf8
spring.datasource.username=root
spring.datasource.password=123456

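# MyBatis integration
# Package scanned for POJO type aliases (lets the mapper XML use resultType="item")
mybatis.type-aliases-package=cn.itcast.domain
# Location of the MyBatis mapper XML files
mybatis.mapper-locations=classpath:mapper/*Mapper.xml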

3.4 Write the Domain Entity Class
package cn.itcast.domain;
public class Item {
    private Long pid;
    private String name;
    private String types;

    public Long getPid() {
        return pid;
    }
    public void setPid(Long pid) {
        this.pid = pid;
    }
    public String getName() {
        return name;
    }
    public void setName(String name) {
        this.name = name;
    }
    public String getTypes() {
        return types;
    }
    public void setTypes(String types) {
        this.types = types;
    }
    @Override
    public String toString() {
        return "Item{" +
                "pid=" + pid +
                ", name='" + name + '\'' +
                ", types='" + types + '\'' +
                '}';
    }
}

3.5 Write the DAO-Layer Mapper Interface and XML Mapping File
package cn.itcast.dao;

import cn.itcast.domain.Item;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import java.util.List;
@Mapper
public interface ItemMapper {
    public List<Item> findAllByIds(@Param("Ids") List<Long> Ids);
}

<?xml version="1.0" encoding="utf-8" ?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
        "http://mybatis.org/dtd/mybatis-3-mapper.dtd" >
<mapper namespace="cn.itcast.dao.ItemMapper">
    <select id="findAllByIds" resultType="item">
      select * from tb_item
        WHERE pid in
        <foreach collection="Ids" item="id" open="(" close=")" separator=",">
            #{id}
        </foreach>
    </select>
</mapper>
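The `<foreach>` element expands the `Ids` list into a parenthesized, comma-separated list of `#{id}` placeholders, each bound as a JDBC parameter rather than concatenated into the SQL. The plain-Java sketch below shows the equivalent SQL text; `InClause` is a hypothetical name, and the exact whitespace MyBatis emits may differ.

```java
import java.util.Collections;
import java.util.List;

// Hypothetical illustration of what
// <foreach collection="Ids" item="id" open="(" close=")" separator=","> produces:
// one "?" placeholder per id, later bound to the id values.
class InClause {
    static String findAllByIdsSql(List<Long> ids) {
        if (ids.isEmpty()) {
            // "pid in ()" is not valid MySQL syntax, which is why the service
            // skips the query when nothing was recommended.
            throw new IllegalArgumentException("ids must not be empty");
        }
        String placeholders = String.join(",", Collections.nCopies(ids.size(), "?"));
        return "select * from tb_item WHERE pid in (" + placeholders + ")";
    }
}
```

With three ids this returns `select * from tb_item WHERE pid in (?,?,?)`.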

3.6 Write the MyConfig Configuration Class and the Service Layer
MyConfig configuration class: sets up the database-backed DataModel.
package cn.itcast.myconfig;

import com.mysql.jdbc.jdbc2.optional.MysqlDataSource;
import org.apache.mahout.cf.taste.impl.model.jdbc.MySQLJDBCDataModel;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.JDBCDataModel;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class MyConfig {
    @Bean
    public DataModel getMySQLDataModel(){
        MysqlDataSource dataSource=new MysqlDataSource();
        dataSource.setServerName("localhost");
        dataSource.setUser("root");
        dataSource.setPassword("123456");
        dataSource.setDatabaseName("recommend"); // database name
        // Arguments: MySQL data source, table name, user-id column, item-id column, preference-value column, timestamp column
        JDBCDataModel dataModel = new MySQLJDBCDataModel(dataSource, "user_pianhao_data1", "uid", "pid", "val", "time");

        /*
         * A DataModel can be backed by a database or by a file.
         * File format, one preference per line:
         *   userId::itemId::preference::timestamp
         *   1::122::5::838985046
         *   1::185::5::838983525
         *   1::231::5::838983392
         *   ...
         */
        // File-based alternative (replaces the JDBC model above):
        // File file = new File("E:\\initData.dat");
        // try {
        //     return new GroupLensDataModel(file);
        // } catch (IOException e) {
        //     throw new UncheckedIOException(e);
        // }

        return dataModel;
    }
}
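The commented-out alternative reads a file in which each line is `userId::itemId::preference::timestamp`. A minimal JDK-only parser for one such line, assuming well-formed input, might look like this; `RatingLine` is a hypothetical name, not part of Mahout.

```java
// Hypothetical parser for one line of the "userId::itemId::preference::timestamp"
// format described in MyConfig (e.g. "1::122::5::838985046").
final class RatingLine {
    final long userId;
    final long itemId;
    final float preference;
    final long timestamp;

    private RatingLine(long userId, long itemId, float preference, long timestamp) {
        this.userId = userId;
        this.itemId = itemId;
        this.preference = preference;
        this.timestamp = timestamp;
    }

    static RatingLine parse(String line) {
        String[] parts = line.trim().split("::");
        if (parts.length != 4) {
            throw new IllegalArgumentException("expected 4 fields: " + line);
        }
        return new RatingLine(
                Long.parseLong(parts[0]),
                Long.parseLong(parts[1]),
                Float.parseFloat(parts[2]),
                Long.parseLong(parts[3]));
    }
}
```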
Service interface:
package cn.itcast.service;

import cn.itcast.domain.Item;
import java.util.List;

public interface RecommendService {
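    // User-based collaborative filtering recommendation
    List<Item> getRecommendItemsByUser(Long userId, int howMany);
    // Item-based collaborative filtering recommendation, starting from a given item
    List<Item> getRecommendItemsByItem(Long userId, Long itemId, int howMany);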
}
Service implementation:
package cn.itcast.service.impl;

import cn.itcast.dao.ItemMapper;
import cn.itcast.domain.Item;
import cn.itcast.service.RecommendService;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.neighborhood.NearestNUserNeighborhood;
import org.apache.mahout.cf.taste.impl.recommender.GenericItemBasedRecommender;
import org.apache.mahout.cf.taste.impl.recommender.GenericUserBasedRecommender;
import org.apache.mahout.cf.taste.impl.similarity.PearsonCorrelationSimilarity;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.neighborhood.UserNeighborhood;
import org.apache.mahout.cf.taste.recommender.RecommendedItem;
import org.apache.mahout.cf.taste.recommender.Recommender;
import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
import org.apache.mahout.cf.taste.similarity.UserSimilarity;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.util.ArrayList;
import java.util.List;

@Service
public class RecommendServiceImpl implements RecommendService {
    @Autowired
    private ItemMapper itemMapper;

    @Autowired
    private DataModel dataModel;

    @Override
    public List<Item> getRecommendItemsByUser(Long userId, int howMany) {
        // Start with an empty list so callers never receive null, even if Mahout throws
        List<Item> list = new ArrayList<>();
        try {
            // Compute user similarity; Mahout offers several measures, here Pearson correlation
            UserSimilarity similarity = new PearsonCorrelationSimilarity(dataModel);
            // Compute the neighborhood: either a fixed number of nearest users or all users above
            // a similarity threshold; this uses the fixed-size variant (the 100 most similar users)
            UserNeighborhood userNeighborhood = new NearestNUserNeighborhood(100, similarity, dataModel);
            // Build a user-based collaborative filtering recommender
            Recommender recommender = new GenericUserBasedRecommender(dataModel, userNeighborhood, similarity);
            long start = System.currentTimeMillis();
            // Recommend items, highest estimated preference first
            List<RecommendedItem> recommendedItemList = recommender.recommend(userId, howMany);
            List<Long> itemIds = new ArrayList<>();
            for (RecommendedItem recommendedItem : recommendedItemList) {
                System.out.println(recommendedItem);
                itemIds.add(recommendedItem.getItemID());
            }
            System.out.println("Recommended item ids: " + itemIds);

            // Look up the full item rows; skip the query when nothing was recommended,
            // because "pid in ()" is not valid SQL
            if (!itemIds.isEmpty()) {
                list = itemMapper.findAllByIds(itemIds);
            }
            System.out.println("Recommended count: " + list.size() + ", elapsed: " + (System.currentTimeMillis() - start) + " ms");
        } catch (TasteException e) {
            e.printStackTrace();
        }
        return list;
    }

    @Override
    public List<Item> getRecommendItemsByItem(Long userId, Long itemId, int howMany) {
        // Start with an empty list so callers never receive null, even if Mahout throws
        List<Item> list = new ArrayList<>();
        try {
            // Compute item similarity, again with Pearson correlation
            ItemSimilarity itemSimilarity = new PearsonCorrelationSimilarity(dataModel);
            // Build an item-based collaborative filtering recommender
            GenericItemBasedRecommender recommender = new GenericItemBasedRecommender(dataModel, itemSimilarity);
            long start = System.currentTimeMillis();
            // recommendedBecause lists the items from this user's own preferences that contributed most
            // to recommending itemId to the user, ranked by similarity to itemId weighted by the user's rating.
            // To list items similar to itemId regardless of the user, use recommender.mostSimilarItems(itemId, howMany).
            List<RecommendedItem> recommendedItemList = recommender.recommendedBecause(userId, itemId, howMany);
            // Print the results and collect the item ids
            List<Long> itemIds = new ArrayList<>();
            for (RecommendedItem recommendedItem : recommendedItemList) {
                System.out.println(recommendedItem);
                itemIds.add(recommendedItem.getItemID());
            }
            System.out.println("Recommended item ids: " + itemIds);

            // Look up the full item rows; skip the query when nothing was returned
            if (!itemIds.isEmpty()) {
                list = itemMapper.findAllByIds(itemIds);
            }
            System.out.println("Recommended count: " + list.size() + ", elapsed: " + (System.currentTimeMillis() - start) + " ms");
        } catch (TasteException e) {
            e.printStackTrace();
        }
        return list;
    }
}
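The user-based pipeline in `getRecommendItemsByUser` (similarity, then neighborhood, then recommender) can be made concrete with a JDK-only toy version of the same steps. It is a simplified sketch under stated assumptions, not Mahout's code: `ToyUserCF` and its methods are hypothetical, preferences are assumed to be held in a `userId -> (itemId -> value)` map, and Mahout's `PearsonCorrelationSimilarity`, `NearestNUserNeighborhood`, `ThresholdUserNeighborhood` and `GenericUserBasedRecommender` handle edge cases differently, so the numbers will not match exactly.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Toy user-based collaborative filtering (hypothetical, not Mahout's code):
// Pearson similarity -> neighborhood -> similarity-weighted estimate.
class ToyUserCF {
    // Pearson correlation over the items both users rated; NaN when it is undefined.
    static double pearson(Map<Long, Double> a, Map<Long, Double> b) {
        List<Long> common = new ArrayList<>();
        for (Long item : a.keySet()) {
            if (b.containsKey(item)) {
                common.add(item);
            }
        }
        int n = common.size();
        if (n < 2) {
            return Double.NaN;
        }
        double meanA = 0, meanB = 0;
        for (Long item : common) {
            meanA += a.get(item);
            meanB += b.get(item);
        }
        meanA /= n;
        meanB /= n;
        double num = 0, varA = 0, varB = 0;
        for (Long item : common) {
            double da = a.get(item) - meanA;
            double db = b.get(item) - meanB;
            num += da * db;
            varA += da * da;
            varB += db * db;
        }
        if (varA == 0 || varB == 0) {
            return Double.NaN; // constant ratings: correlation is undefined
        }
        return num / Math.sqrt(varA * varB);
    }

    // Fixed-size neighborhood: the n users most similar to userId.
    static List<Long> nearestN(long userId, int n, Map<Long, Map<Long, Double>> prefs) {
        Map<Long, Double> sims = new HashMap<>();
        for (Map.Entry<Long, Map<Long, Double>> e : prefs.entrySet()) {
            if (e.getKey() == userId) {
                continue;
            }
            double s = pearson(prefs.get(userId), e.getValue());
            if (!Double.isNaN(s)) {
                sims.put(e.getKey(), s);
            }
        }
        List<Long> users = new ArrayList<>(sims.keySet());
        users.sort((x, y) -> Double.compare(sims.get(y), sims.get(x))); // most similar first
        return users.subList(0, Math.min(n, users.size()));
    }

    // Threshold neighborhood: every user whose similarity is at least `threshold`.
    static List<Long> withinThreshold(long userId, double threshold, Map<Long, Map<Long, Double>> prefs) {
        List<Long> users = new ArrayList<>();
        for (Map.Entry<Long, Map<Long, Double>> e : prefs.entrySet()) {
            if (e.getKey() == userId) {
                continue;
            }
            double s = pearson(prefs.get(userId), e.getValue());
            if (!Double.isNaN(s) && s >= threshold) {
                users.add(e.getKey());
            }
        }
        return users;
    }

    // Estimated preference of userId for itemId: similarity-weighted average over neighbors who rated it.
    static double estimate(long userId, long itemId, int n, Map<Long, Map<Long, Double>> prefs) {
        double weighted = 0, totalSim = 0;
        for (long neighbor : nearestN(userId, n, prefs)) {
            Double pref = prefs.get(neighbor).get(itemId);
            if (pref == null) {
                continue; // this neighbor has not rated the item
            }
            double s = pearson(prefs.get(userId), prefs.get(neighbor));
            weighted += s * pref;
            totalSim += s;
        }
        return totalSim == 0 ? Double.NaN : weighted / totalSim;
    }
}
```

Roughly speaking, a user-based recommender estimates a value like this for each item its neighbors rated but the target user did not, and keeps the top `howMany`. Note that the plain weighted average can misbehave when negative similarities cancel positive ones, which is one of the details real implementations have to deal with.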
3.7 Write the Controller
package cn.itcast.controller;

import cn.itcast.domain.Item;
import cn.itcast.service.RecommendService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.List;

@RestController
public class RecommendController {

    @Autowired
    private RecommendService recommendService;

    /**
     * User-based recommendation
     * @param userId user id
     * @param num number of items to recommend
     * @return the recommended items
     */
    @RequestMapping("recommendByUser")
    public List<Item> getRecommendItemsByUser(Long userId, int num){
        List<Item> items= recommendService.getRecommendItemsByUser(userId,num);
        return items;
    }
    /**
     * Item-based recommendation
     * @param userId user id
     * @param itemId item id
     * @param num maximum number of items to return
     * @return the items returned by the item-based recommender
     */
    @RequestMapping("recommendByItem")
    public List<Item> getRecommendItemsByItem(Long userId,Long itemId, int num){
        List<Item> items= recommendService.getRecommendItemsByItem(userId,itemId, num);
        return items;
    }
}

3.8 Testing
Open a browser and visit http://localhost:8080/recommendByUser?userId=5&num=10 to recommend 10 items to the user with id 5. Output:
[{"pid":50,"name":"Usual Suspects, The (1995)","types":"Crime|Mystery|Thriller"},{"pid":260,"name":"Star Wars: Episode IV - A New Hope (a.k.a. Star Wars) (1977)","types":"Action|Adventure|Sci-Fi"},{"pid":590,"name":"Dances with Wolves (1990)","types":"Adventure|Drama|Western"},{"pid":1732,"name":"Big Lebowski, The (1998)","types":"Comedy|Crime|Mystery|Thriller"},{"pid":2335,"name":"Waterboy, The (1998)","types":"Comedy"},{"pid":2478,"name":"Three Amigos (1986)","types":"Comedy|Western"},{"pid":4027,"name":"O Brother, Where Art Thou? (2000)","types":"Adventure|Comedy|Crime"},{"pid":4226,"name":"Memento (2000)","types":"Crime|Drama|Mystery|Thriller"},{"pid":5481,"name":"Austin Powers in Goldmember (2002)","types":"Comedy"},{"pid":5502,"name":"Signs (2002)","types":"Sci-Fi|Thriller"}]
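The ten items above come back in ascending `pid` order, which suggests this is the database's row order rather than Mahout's ranking: `findAllByIds` uses `pid in (...)`, and without an `ORDER BY` clause SQL gives no ordering guarantee. If the ranking matters, one option is to reorder the rows in Java using the id list returned by the recommender (MySQL's `ORDER BY FIELD(pid, ...)` is another). A minimal sketch with a hypothetical generic helper `RankOrder.reorder`; in the service it could be called as `RankOrder.reorder(itemIds, list, Item::getPid)`.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical helper: restores the recommender's ranking after an IN (...) query,
// whose result rows may come back in any order.
class RankOrder {
    static <T> List<T> reorder(List<Long> rankedIds, List<T> rows, Function<T, Long> idOf) {
        Map<Long, T> byId = new HashMap<>();
        for (T row : rows) {
            byId.put(idOf.apply(row), row);
        }
        List<T> ordered = new ArrayList<>(rankedIds.size());
        for (Long id : rankedIds) {
            T row = byId.get(id);
            if (row != null) { // skip ids with no matching tb_item row
                ordered.add(row);
            }
        }
        return ordered;
    }
}
```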
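Visit http://localhost:8080/recommendByItem?userId=5&itemId=231&num=10 to run the item-based recommender for user 5 and item 231 with up to 10 results. Because it uses recommendedBecause, the result lists items from user 5's own preferences that most explain recommending item 231 to that user. Output: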
[{"pid":253,"name":"Interview with the Vampire: The Vampire Chronicles (1994)","types":"Drama|Horror"},{"pid":592,"name":"Batman (1989)","types":"Action|Crime|Sci-Fi|Thriller"}]
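As far as can be told from Mahout's item-based recommender, `recommendedBecause` considers only the items the user has already rated (excluding the target item) and scores each one roughly as `(1 + similarity(target, candidate)) * preference`, dropping candidates whose similarity is undefined. That would explain why the request above returns two items rather than ten. The JDK-only sketch below mirrors that scoring; `BecauseSketch` is hypothetical, and the similarity function is passed in as an assumption rather than computed.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.ToDoubleBiFunction;

// Hypothetical sketch of recommendedBecause-style ranking: candidates are the items the
// user already rated (minus itemId), scored as (1 + similarity(itemId, candidate)) * preference.
class BecauseSketch {
    static List<Long> recommendedBecause(Map<Long, Double> userPrefs, long itemId, int howMany,
                                         ToDoubleBiFunction<Long, Long> itemSimilarity) {
        Map<Long, Double> scores = new HashMap<>();
        for (Map.Entry<Long, Double> e : userPrefs.entrySet()) {
            long candidate = e.getKey();
            if (candidate == itemId) {
                continue; // the target item itself is not part of the explanation
            }
            double score = (1.0 + itemSimilarity.applyAsDouble(itemId, candidate)) * e.getValue();
            if (!Double.isNaN(score)) { // undefined similarity: candidate is dropped
                scores.put(candidate, score);
            }
        }
        List<Long> ranked = new ArrayList<>(scores.keySet());
        ranked.sort((x, y) -> Double.compare(scores.get(y), scores.get(x))); // highest score first
        return ranked.subList(0, Math.min(howMany, ranked.size()));
    }
}
```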
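4. Summary
As you can see, building a recommender with the Mahout library takes very little code. The general workflow is:
Step 1: Create the DataModel. It can be file-based or database-backed (JDBCDataModel). With a database-backed model, recommendations become very slow once the table holds a lot of data, since the model reads preferences from MySQL while the recommender runs. Because real data volumes are usually large, a file-based DataModel is generally the better choice; the files can also be uploaded to Hadoop and processed with MapReduce to improve performance.
Step 2: Compute similarity with a measure such as Euclidean distance or Pearson correlation.
Step 3: Build a recommender, either user-based or item-based.
Step 4: Look up the full details for the recommended item ids and return them to the user for display.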
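Step 2 names Euclidean distance as an alternative to Pearson correlation. A common way to turn a distance into a similarity is `1 / (1 + d)`, computed over the items both users rated, so identical ratings give 1 and larger differences approach 0. Mahout's `EuclideanDistanceSimilarity` follows the same idea, though its exact normalization may differ from this sketch; `EuclideanSketch` is a hypothetical name.

```java
import java.util.Map;

// Hypothetical Euclidean-distance similarity: 1 / (1 + distance over co-rated items).
class EuclideanSketch {
    static double similarity(Map<Long, Double> a, Map<Long, Double> b) {
        double sumSq = 0;
        int common = 0;
        for (Map.Entry<Long, Double> e : a.entrySet()) {
            Double other = b.get(e.getKey());
            if (other == null) {
                continue; // only items rated by both users count
            }
            double diff = e.getValue() - other;
            sumSq += diff * diff;
            common++;
        }
        // No co-rated items: similarity is undefined.
        return common == 0 ? Double.NaN : 1.0 / (1.0 + Math.sqrt(sumSq));
    }
}
```

For ratings `{1: 5, 2: 3}` and `{1: 2, 2: 7, 3: 1}` the differences are 3 and -4, the distance is 5, and the similarity is 1/6.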
