黑马程序员技术交流社区

标题: 【上海校区】遗传算法解决TSP问题 Python实现【160行以内代... [打印本页]

作者: 不二晨    时间: 2018-12-24 11:20
标题: 【上海校区】遗传算法解决TSP问题 Python实现【160行以内代...
以下是基于我之前的 C++ 版本改写的 Python 版本代码:

import numpy as np
import random


def read_data(file='data.txt'):
    """Read a TSP distance matrix from a text file.

    File format: the first line holds the number of cities N; each of the
    next N lines holds N whitespace-separated integer distances.

    Args:
        file: path of the data file.

    Returns:
        (N, Mat): the city count and the matrix as a list of N int lists,
        where Mat[a][b] is the distance from city a to city b.
    """
    with open(file, 'r') as f:
        N = int(f.readline())  # int() tolerates surrounding whitespace/newline
        # One row per line; split() (no argument) tolerates repeated or
        # trailing spaces, which split(' ') would turn into '' and crash on.
        Mat = [list(map(int, f.readline().split())) for _ in range(N)]
    return N, Mat


def calpathValue(path, mat=None):
    """Return the length of the closed tour that starts and ends at city 0.

    Args:
        path: the cities visited after city 0, in order (city 0 itself is
            not included); a list or 1-D numpy array of ints.
        mat: distance matrix indexed as ``mat[a][b]``; defaults to the
            module-level ``Mat`` so existing callers are unaffected.

    Returns:
        mat[0][path[0]] + the sum of consecutive legs + mat[path[-1]][0].
        For an empty path this is mat[0][0].
    """
    if mat is None:
        mat = Mat
    # Close the tour explicitly so every leg, including the two that touch
    # city 0, is handled by the same pairwise sum.
    tour = [0, *path, 0]
    return sum(mat[a][b] for a, b in zip(tour, tour[1:]))


def initial():
    """Build the initial population of LEN random tours, sorted by length.

    Each tour is a random permutation of cities 1..N-1; city 0 is the fixed
    start/end point and is not stored in the tour. Reads the module-level
    N and LEN.

    Returns:
        (packValue, pack): packValue is a numpy array of tour lengths in
        ascending order; pack is a 2-D numpy array whose rows are the
        corresponding tours.
    """
    cities = list(range(1, N))
    pack = [None] * LEN
    packValue = [0] * LEN
    for i in range(LEN):
        random.shuffle(cities)
        # Copy: `cities` is reshuffled in place on the next iteration.
        pack[i] = cities.copy()
        packValue[i] = calpathValue(pack[i])
    indexes = np.argsort(packValue)
    pack = np.array(pack)[indexes]
    packValue = np.sort(packValue)
    return packValue, pack


def preserve(i):
    """Copy parent ``pack[i]`` and its length into the next offspring slot.

    Keeping every parent in the offspring buffer makes selection elitist:
    the best tour found so far can never be lost.

    Args:
        i: index of the parent in the current population ``pack``.
    """
    global tempIndex
    tempPackValue[tempIndex] = packValue[i]
    tempPack[tempIndex] = pack[i].copy()
    tempIndex += 1


def select():
    """Replace the population with the LEN shortest tours of this generation.

    Considers the first ``tempIndex`` entries of tempPack/tempPackValue and
    rebinds the globals ``pack`` (2-D numpy array of tours) and
    ``packValue`` (their lengths), both ordered by ascending length.
    """
    global pack, packValue
    scores = tempPackValue[:tempIndex]
    ranking = np.argsort(scores)
    best = ranking[:LEN]
    pack = np.array(tempPack[:tempIndex])[best]
    packValue = np.sort(scores)[:LEN]


def _crossover_child(base, donor, positions):
    """Return a copy of ``base`` with the cities ``donor`` holds at
    ``positions`` rewritten, left to right, in ``donor``'s order."""
    child = base.copy()
    donor_cities = donor[positions]
    wanted = set(donor_cities.tolist())  # O(1) membership instead of array scan
    count = 0
    for v in range(len(child)):
        if count >= len(donor_cities):
            break
        if child[v] in wanted:
            child[v] = donor_cities[count]
            count += 1
    return child


def crossover(i, j):
    """Append two order-based crossover children of pack[i] and pack[j].

    A random set of positions is chosen. The first child is pack[i] with
    the cities that pack[j] holds at those positions reordered to follow
    pack[j]; the second child is the symmetric result. Both children are
    still permutations of cities 1..N-1. Each child and its length go into
    the next free slot of tempPack/tempPackValue.
    """
    global tempIndex
    times = random.randint(1, max(1, N - 2))  # max() keeps N == 2 valid
    # Sorted so the donor's cities are taken in the donor's own order.
    positions = sorted(random.sample(range(N - 1), times))
    for base, donor in ((pack[i], pack[j]), (pack[j], pack[i])):
        child = _crossover_child(base, donor, positions)
        tempPack[tempIndex] = child
        tempPackValue[tempIndex] = calpathValue(child)
        tempIndex += 1


def mutation(i):
    """Append a mutated copy of parent ``pack[i]`` to the offspring buffer.

    Between 1 and max(1, N-2) tour positions are chosen at random, and the
    cities at those positions are randomly permuted among themselves, so the
    child is still a permutation of cities 1..N-1.
    """
    global tempIndex
    parent = pack[i]
    times = random.randint(1, max(1, N - 2))  # max() keeps N == 2 valid
    sources = random.sample(range(N - 1), times)
    targets = random.sample(sources, len(sources))  # random permutation
    child = parent.copy()
    child[targets] = parent[sources]
    tempPack[tempIndex] = child
    tempPackValue[tempIndex] = calpathValue(child)
    tempIndex += 1


if __name__ == '__main__':
    N, Mat = read_data()

    LEN = 25  # population size
    pc, pm = 0.7, 0.97  # crossover / mutation probabilities
    NOTMORETHANstayINGV = 10  # stop after this many generations without improvement

    packValue, pack = initial()

    # Worst-case offspring per generation: LEN preserved parents + LEN
    # mutants + 2 children for each of the LEN*(LEN-1)/2 pairs
    # = LEN*(LEN+1). LEN*LEN could overflow the buffer with an IndexError.
    tempLEN = LEN * (LEN + 1)
    tempPack = [None] * tempLEN  # every slot is assigned before it is read
    tempPackValue = [0] * tempLEN

    tempIndex = 0

    global_Value = packValue[0]
    stayinGV = 0

    while True:
        tempIndex = 0
        for i in range(LEN):
            preserve(i)
            if random.random() < pm:
                mutation(i)
            if i == LEN - 1: break
            for j in range(i + 1, LEN):
                if tempIndex >= tempLEN: break
                if random.random() < pc:
                    crossover(i, j)
        select()
        if packValue[0] < global_Value:
            global_Value = packValue[0]
            stayinGV = 0
        elif packValue[0] == global_Value:
            stayinGV += 1
        else:
            print("Something wrong")
            break
        if stayinGV == NOTMORETHANstayINGV:
            break

    print(global_Value)
    # Print the closed tour, which starts and ends at city 0.
    print('-->'.join(str(city) for city in [0, *pack[0], 0]))
所用的数据:
表示的是 10 个点之间的距离矩阵(保存为 data.txt:第一行为点数 N,其后 N 行为以空格分隔的距离矩阵):
10
0 58 82 89 17 50 26 48 70 19
58 0 74 46 70 2 70 49 87 60
82 74 0 58 76 98 37 97 34 67
89 46 58 0 15 17 28 69 46 79
17 70 76 15 0 98 60 69 97 89
50 2 98 17 98 0 81 14 43 47
26 70 37 28 60 81 0 43 73 56
48 49 97 69 69 14 43 0 39 0
70 87 34 46 97 43 73 39 0 53
19 60 67 79 89 47 56 0 53 0
随机生成数据并展示动图

黑色的点是起点。
我在知乎上看到类似的图,就试着改了下自己的代码,也实现了这个效果。生成 gif 的部分基于我之前写的代码。


import numpy as np
import random
import matplotlib.pyplot as plt
import os
import shutil
import imageio

def create_data(N, xu=10, yu=10, xd=-10, yd=-10):
    """Generate N random points and their Euclidean distance matrix.

    Args:
        N: number of points.
        xu, xd: upper / lower bound of the x coordinate.
        yu, yd: upper / lower bound of the y coordinate.

    Returns:
        (points, Mat): points is a list of N (x, y) tuples drawn uniformly
        from [xd, xu) x [yd, yu); Mat is an N x N numpy array with
        Mat[a][b] the distance between points a and b (symmetric, zero
        diagonal).
    """
    # x is drawn before y for each point, as in the original generator.
    points = [
        (random.random() * (xu - xd) + xd, random.random() * (yu - yd) + yd)
        for _ in range(N)
    ]
    coords = np.array(points, dtype=float).reshape(-1, 2)
    # Pairwise coordinate differences via broadcasting: diff[a, b] = p_a - p_b.
    diff = coords[:, None, :] - coords[None, :, :]
    Mat = np.hypot(diff[..., 0], diff[..., 1])
    return points, Mat


def calpathValue(path, mat=None):
    """Return the length of the closed tour that starts and ends at city 0.

    Args:
        path: the cities visited after city 0, in order (city 0 itself is
            not included); a list or 1-D numpy array of ints.
        mat: distance matrix indexed as ``mat[a][b]``; defaults to the
            module-level ``Mat`` so existing callers are unaffected.

    Returns:
        mat[0][path[0]] + the sum of consecutive legs + mat[path[-1]][0].
        For an empty path this is mat[0][0].
    """
    if mat is None:
        mat = Mat
    # Close the tour explicitly so every leg, including the two that touch
    # city 0, is handled by the same pairwise sum.
    tour = [0, *path, 0]
    return sum(mat[a][b] for a, b in zip(tour, tour[1:]))


def initial():
    """Build the initial population of LEN random tours, sorted by length.

    Each tour is a random permutation of cities 1..N-1; city 0 is the fixed
    start/end point and is not stored in the tour. Reads the module-level
    N and LEN.

    Returns:
        (packValue, pack): packValue is a numpy array of tour lengths in
        ascending order; pack is a 2-D numpy array whose rows are the
        corresponding tours.
    """
    cities = list(range(1, N))
    pack = [None] * LEN
    packValue = [0] * LEN
    for i in range(LEN):
        random.shuffle(cities)
        # Copy: `cities` is reshuffled in place on the next iteration.
        pack[i] = cities.copy()
        packValue[i] = calpathValue(pack[i])
    indexes = np.argsort(packValue)
    pack = np.array(pack)[indexes]
    packValue = np.sort(packValue)
    return packValue, pack


def preserve(i):
    """Copy parent ``pack[i]`` and its length into the next offspring slot.

    Keeping every parent in the offspring buffer makes selection elitist:
    the best tour found so far can never be lost.

    Args:
        i: index of the parent in the current population ``pack``.
    """
    global tempIndex
    tempPackValue[tempIndex] = packValue[i]
    tempPack[tempIndex] = pack[i].copy()
    tempIndex += 1


def select():
    """Replace the population with the LEN shortest tours of this generation.

    Considers the first ``tempIndex`` entries of tempPack/tempPackValue and
    rebinds the globals ``pack`` (2-D numpy array of tours) and
    ``packValue`` (their lengths), both ordered by ascending length.
    """
    global pack, packValue
    scores = tempPackValue[:tempIndex]
    ranking = np.argsort(scores)
    best = ranking[:LEN]
    pack = np.array(tempPack[:tempIndex])[best]
    packValue = np.sort(scores)[:LEN]


def _crossover_child(base, donor, positions):
    """Return a copy of ``base`` with the cities ``donor`` holds at
    ``positions`` rewritten, left to right, in ``donor``'s order."""
    child = base.copy()
    donor_cities = donor[positions]
    wanted = set(donor_cities.tolist())  # O(1) membership instead of array scan
    count = 0
    for v in range(len(child)):
        if count >= len(donor_cities):
            break
        if child[v] in wanted:
            child[v] = donor_cities[count]
            count += 1
    return child


def crossover(i, j):
    """Append two order-based crossover children of pack[i] and pack[j].

    A random set of positions is chosen. The first child is pack[i] with
    the cities that pack[j] holds at those positions reordered to follow
    pack[j]; the second child is the symmetric result. Both children are
    still permutations of cities 1..N-1. Each child and its length go into
    the next free slot of tempPack/tempPackValue.
    """
    global tempIndex
    times = random.randint(1, max(1, N - 2))  # max() keeps N == 2 valid
    # Sorted so the donor's cities are taken in the donor's own order.
    positions = sorted(random.sample(range(N - 1), times))
    for base, donor in ((pack[i], pack[j]), (pack[j], pack[i])):
        child = _crossover_child(base, donor, positions)
        tempPack[tempIndex] = child
        tempPackValue[tempIndex] = calpathValue(child)
        tempIndex += 1


def mutation(i):
    """Append a mutated copy of parent ``pack[i]`` to the offspring buffer.

    Between 1 and max(1, N-2) tour positions are chosen at random, and the
    cities at those positions are randomly permuted among themselves, so the
    child is still a permutation of cities 1..N-1.
    """
    global tempIndex
    parent = pack[i]
    times = random.randint(1, max(1, N - 2))  # max() keeps N == 2 valid
    sources = random.sample(range(N - 1), times)
    targets = random.sample(sources, len(sources))  # random permutation
    child = parent.copy()
    child[targets] = parent[sources]
    tempPack[tempIndex] = child
    tempPackValue[tempIndex] = calpathValue(child)
    tempIndex += 1


def draw(path, pv):
    """Plot a tour over all points and save it as the next PNG frame.

    Args:
        path: tour as city indices, excluding the start city 0.
        pv: tour length, shown in the plot title.

    Side effects: clears the current axes, writes '<PNGFILE>/<TIMESIT>.png',
    appends that file path to PNGLIST and increments TIMESIT.
    """
    global TIMESIT
    plt.cla()
    plt.title('cross=%.4f' % pv)
    xs = np.array([p[0] for p in points])
    ys = np.array([p[1] for p in points])
    plt.scatter(xs, ys, color='b')
    # Closed tour: leave city 0, visit `path`, return to city 0.
    tour = [0, *path, 0]
    plt.plot(xs[tour], ys[tour], color='r')
    plt.scatter(xs[0], ys[0], color='k', linewidth=10)  # start point in black
    png_path = '%s/%d.png' % (PNGFILE, TIMESIT)
    plt.savefig(png_path)
    PNGLIST.append(png_path)
    TIMESIT += 1


if __name__ == '__main__':
    TIMESIT = 0  # index of the next PNG frame written by draw()
    PNGFILE = './png/'
    PNGLIST = []
    # Start from an empty frame directory.
    if os.path.exists(PNGFILE):
        shutil.rmtree(PNGFILE)
    os.mkdir(PNGFILE)

    N = 20
    points, Mat = create_data(N)
    LEN = 40  # population size
    pc, pm = 0.7, 0.97  # crossover / mutation probabilities
    NOTMORETHANstayINGV = 30  # stop after this many generations without improvement

    packValue, pack = initial()

    # Worst-case offspring per generation: LEN preserved parents + LEN
    # mutants + 2 children for each of the LEN*(LEN-1)/2 pairs
    # = LEN*(LEN+1). LEN*LEN could overflow the buffer with an IndexError.
    tempLEN = LEN * (LEN + 1)
    tempPack = [None] * tempLEN  # every slot is assigned before it is read
    tempPackValue = [0] * tempLEN

    tempIndex = 0

    global_Value = packValue[0]
    draw(pack[0], global_Value)
    stayinGV = 0

    while True:
        tempIndex = 0
        for i in range(LEN):
            preserve(i)
            if random.random() < pm:
                mutation(i)
            if i == LEN - 1: break
            for j in range(i + 1, LEN):
                if tempIndex >= tempLEN: break
                if random.random() < pc:
                    crossover(i, j)
        select()
        if packValue[0] < global_Value:
            global_Value = packValue[0]
            draw(pack[0], global_Value)  # one frame per improvement
            stayinGV = 0
        elif packValue[0] == global_Value:
            stayinGV += 1
        else:
            print("Something wrong")
            break
        if stayinGV == NOTMORETHANstayINGV:
            break

    print(global_Value)
    # Print the closed tour, which starts and ends at city 0.
    print('-->'.join(str(city) for city in [0, *pack[0], 0]))

    generated_images = [imageio.imread(png_path) for png_path in PNGLIST]
    shutil.rmtree(PNGFILE)  # optional: removes the intermediate PNG frames
    # Repeat the final frame 5 more times so the result stays on screen.
    generated_images = generated_images + [generated_images[-1]] * 5
    imageio.mimsave('TSP-GA.gif', generated_images, 'GIF', duration=0.5)
---------------------
【转载】仅作分享,侵删
作者:肥宅_Sean
原文:https://blog.csdn.net/a19990412/article/details/84978612



作者: 不二晨    时间: 2018-12-26 10:13





欢迎光临 黑马程序员技术交流社区 (http://bbs.itheima.com/) 黑马程序员IT技术论坛 X3.2