黑马程序员技术交流社区

标题: 模拟Mybatis逆向工程创建javaBean类:使用数据库的表来分别... [打印本页]

作者: yellowstar    时间: 2019-9-12 15:24
标题: 模拟Mybatis逆向工程创建javaBean类:使用数据库的表来分别...
之前学习过Mybatis框架,记忆最深的还是逆向工程,模拟这个实现一些简单的小功能

思路:查询数据库所有表的结构,封装成对象遍历,来无脑拼出bean类。

主要功能:

startTables()数据库下所有表创建bean类
startTable(String tablename)根据指定表创建bean类

测试类:

运行之后会在bean包下创建相应的Bean类



以下为实现代码
Table类:
package com.field;

public class Table {
    private String Field;
    private String Type;
    private String Null;
    private String Key;
    private String Default;
    private String Extra;

    public Table() {
    }

    public Table(String field, String type, String aNull, String key, String aDefault, String extra) {
        Field = field;
        Type = type;
        Null = aNull;
        Key = key;
        Default = aDefault;
        Extra = extra;
    }

    public String getField() {
        return Field;
    }

    public void setField(String field) {
        Field = field;
    }

    public String getType() {
        return Type;
    }

    public void setType(String type) {
        Type = type;
    }

    public String getNull() {
        return Null;
    }

    public void setNull(String aNull) {
        Null = aNull;
    }

    public String getKey() {
        return Key;
    }

    public void setKey(String key) {
        Key = key;
    }

    public String getDefault() {
        return Default;
    }

    public void setDefault(String aDefault) {
        Default = aDefault;
    }

    public String getExtra() {
        return Extra;
    }

    public void setExtra(String extra) {
        Extra = extra;
    }
}
Reverse类:
package com.util;



import com.field.Table;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;

public class Reverse {
    private DBCon dbc;

    //创建对象时初始画一个数据连接
    public Reverse() {
        dbc = new DBCon();
    }

    /**
     * 获取对应数据库的所有表
     */
    private List<String> getTables() {
        ArrayList<String> list = new ArrayList<>();
        String sql = "show tables";
        ResultSet rs = dbc.doQuery(sql, new Object[]{});
        try {
            while (rs.next()) {
                list.add(rs.getString(1));
            }
        } catch (SQLException e) {
            e.printStackTrace();
        }
        return list;
    }

    /**
     * 判断输入的表是否存在
     *
     * @param tablename
     * @return
     */
    private boolean isTableExists(String tablename) {
        List<String> list = getTables();
        return list.contains(tablename);

    }

    /**
     * 封装存储表的字段
     *
     * @param tablename
     * @return
     */
    private ArrayList<Table> getDesc(String tablename) {
        if (isTableExists(tablename)) {
            String sql = "desc " + tablename;
            ResultSet rs = dbc.doQuery(sql, new Object[]{});
            ArrayList<Table> list = new ArrayList<>();
            try {
                while (rs.next()) {
                    Table table = new Table();
                    table.setField(rs.getString("Field"));
                    table.setType(rs.getString("Type"));
                    table.setNull(rs.getString("Null"));
                    table.setKey(rs.getString("Key"));
                    table.setDefault(rs.getString("Default"));
                    table.setExtra(rs.getString("Extra"));
                    list.add(table);
                }
            } catch (SQLException e) {
                e.printStackTrace();
            }
            return list;
        }
        throw new RuntimeException("表不存在!");
    }

    /**
     * 创建对应数据库所有的表
     */
    public void startTables() {
        List<String> tables = getTables();
        for (String table : tables) {
            startTable(table);
        }
    }

    /**
     * 单个表的创建
     *
     * @param tablename
     */
    public void startTable(String tablename) {
        PrintWriter pw = null;
        try {
            //创建打印流,和指定打印路径
            pw = new PrintWriter(new FileWriter(getPath("MyHibernate03") + toUpper(tablename) + ".java"));
        } catch (IOException e) {
            e.printStackTrace();
        }
        StringBuilder sb = new StringBuilder();

        //包名注解
        pw.println("package com.bean;");
        pw.println();
        pw.println("import com.annoction.TableName;");
        pw.println("import com.annoction.Column;");
        pw.println("import com.annoction.Primarykey;");
        pw.println();
        pw.println("@TableName(\"" + tablename + "\")");
        pw.println("public class " + toUpper(tablename) + " {");
        ArrayList<Table> tables = getDesc(tablename);

        //成员变量和注解
        for (Table table : tables) {
            String type = changetype(table.getType());
            String filed = table.getField();
            if (getPrimaryKey(tablename).equals(filed)) {
                pw.print("\t");
                pw.println("@Primarykey(\"" + filed + "\")");
            }
            pw.print("\t");
            pw.println("@Column(\"" + filed + "\")");
            pw.print("\t");
            pw.println("private " + type + " " + filed + ";");
        }
        pw.println();
        //无参构造
        pw.println("public " + toUpper(tablename) + "(){}\r\n");
        pw.println();

        //全参构造
        pw.print("public " + toUpper(tablename) + "(");
        for (Table table : tables) {
            String field = table.getField();
            String type = changetype(table.getType());
            sb.append(type + " " + field + ",");
        }
        String substring = sb.substring(0, sb.length() - 1);
        pw.print(substring);
        pw.println(") {");
        for (Table table : tables) {
            String field = table.getField();
            pw.println("this." + field + "=" + field + ";");
        }
        pw.println("}");
        pw.println();

        //set/get方法
        for (Table table : tables) {
            String type = changetype(table.getType());
            String field = table.getField();
            pw.println("public void set" + toUpper(field) + "(" + type + " " + field + ") {");
            pw.println("this." + field + "=" + field + ";");
            pw.println("}");
            pw.println("public " + type + " get" + toUpper(field) + "(){");
            pw.println("return " + field + ";");
            pw.println("}");
            pw.println();
        }
        pw.println();

        //toString()
        pw.println("@Override");
        pw.println("public String toString(){");
        pw.println(" return \"" + toUpper(tablename) + "{\" +");
        sb = new StringBuilder();
        for (Table table : tables) {
            sb.append("  \"" + table.getField() + "=\" + " + table.getField() + "+\n" +
                    " \", \"+");
        }
        String str = sb.substring(0, sb.length() - 5).concat("\"}\"").concat(";");

        pw.println(str);
        pw.println("}");
        pw.println("}");
        pw.close();
    }

    /**
     * 拼接创建路径
     *
     * @return
     */
    private String getPath(String projectname) {
        String path = DBCon.class.getClassLoader().getResource("db.properties").getPath();
        String sub1 = path.substring(1, path.length() -29-projectname.length());
        String sub2=sub1.concat(projectname).concat("/src/com/bean/");
        File file = new File(sub2);
        if (!file.exists()) file.mkdirs();
        return sub2;
    }

    /**
     * 获取主键名称
     *
     * @param tablename
     * @return
     */
    private String getPrimaryKey(String tablename) {
        ArrayList<Table> desc = getDesc(tablename);
        for (Table table : desc) {
            if (table.getKey().equals("PRI")) {
                return table.getField();
            }

        }
        throw new RuntimeException("没有找到主键!");
    }

    /**
     * 类型转换
     *
     * @param type
     * @return
     */
    private String changetype(String type) {
        if (type.startsWith("int")) return "int";
        if (type.startsWith("varchar")) return "String";
//        if(type.startsWith(""))
        else return "String";
    }

    private String toUpper(String str) {
        return String.valueOf(Character.toUpperCase(str.charAt(0))).concat(str.substring(1));
    }

    /**
     * 释放资源
     */
    public void close() {
        dbc.close();
    }

}
db.properties(放在src下)
driver=com.mysql.cj.jdbc.Driver
url=jdbc:mysql://localhost:3306/myhibernate?serverTimezone=UTC
username=root
password=123456
database=myhibernate


DBCon类:
package com.util;

import java.io.IOException;
import java.io.InputStream;
import java.sql.*;
import java.util.Properties;

public class DBCon {
    private static Properties prop;
    private Connection conn;
    private ResultSet rs;
    private PreparedStatement pstmt;

    public DBCon() {
        this.conn = getConn();
    }

    /**
     * 加载配置文件
     */
    static {
        //获取配置文件的属性
        prop = new Properties();
        InputStream is = DBCon.class.getClassLoader().getResourceAsStream("db.properties");
        try {
            prop.load(is);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * 获取数据库连接
     *
     * @return
     */
    public Connection getConn() {
        try {
            Class.forName(prop.getProperty("driver"));
            conn = DriverManager.getConnection(prop.getProperty("url"), prop.getProperty("username"), prop.getProperty("password"));
        } catch (Exception e) {
            e.printStackTrace();
        }
        return conn;
    }

    /**
     * 封装查询
     *
     * @param sql
     * @param arrs
     * @return
     */
    public ResultSet doQuery(String sql, Object[] arrs) {
        //  rs = null;
        try {
            pstmt = conn.prepareStatement(sql);
            for (int i = 0; i < arrs.length; i++) {
                pstmt.setObject((i + 1), arrs[i]);
            }
            rs = pstmt.executeQuery();
        } catch (SQLException e) {
            e.printStackTrace();
        }
        return rs;
    }

    /**
     * 封装增删改
     *
     * @param sql
     * @param arrs
     * @return
     */
    public int doUpdate(String sql, Object[] arrs) {
        int res = 0;
        try {
            pstmt = conn.prepareStatement(sql);
            for (int i = 0; i < arrs.length; i++) {
                pstmt.setObject((i + 1), arrs[i]);
            }
            res = pstmt.executeUpdate();
        } catch (SQLException e) {
            e.printStackTrace();
        }
        return res;
    }

    /**
     * 开启事务
     */
    public void OpenTransaction() {
        try {
            //开启事务
            conn.setAutoCommit(false);
        } catch (SQLException e) {
            e.printStackTrace();
        }

    }

    /**
     * 提交事务
     */
    public void commit() {
        try {
            conn.commit();
        } catch (SQLException e) {
            e.printStackTrace();
        }

    }

    /**
     * 回滚事务
     */
    public void rollback() {
        try {
            conn.rollback();
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }

    /**
     * 关闭资源
     */
    public void close() {
        if (rs != null) {
            try {
                rs.close();
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
        if (pstmt != null) {
            try {
                pstmt.close();
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
        if (conn != null) {
            try {
                conn.close();
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
    }

}







欢迎光临 黑马程序员技术交流社区 (http://bbs.itheima.com/) 黑马程序员IT技术论坛 X3.2