A股上市公司传智教育(股票代码 003032)旗下技术交流社区北京昌平校区

 找回密码
 加入黑马

QQ登录

只需一步,快速开始

以下是基于之前 C++ 版本改写的 Python 版本代码:

import numpy as np
import random


def read_data(file='data.txt'):
    """Read a TSP distance matrix from a whitespace-separated text file.

    The first line holds the number of points ``N``; each of the next ``N``
    lines holds one row of the N x N distance matrix as integers.

    Args:
        file: Path of the data file (default ``'data.txt'``).

    Returns:
        A tuple ``(N, Mat)`` where ``Mat`` is a list of ``N`` integer rows.
    """
    with open(file, 'r') as f:
        N = int(f.readline())
        # One row per point; ``split()`` also tolerates repeated spaces, tabs
        # and Windows line endings, unlike ``split(' ')``.
        Mat = [list(map(int, f.readline().split())) for _ in range(N)]
    return N, Mat


def calpathValue(path):
    """Return the length of the closed tour 0 -> path[0] -> ... -> path[-1] -> 0.

    ``path`` is the visiting order of the points other than 0; point 0 is the
    fixed start and end. Distances come from the module-level matrix ``Mat``.
    """
    global Mat
    # Leg out of the start point, every consecutive pair, then back home.
    total = Mat[0][path[0]]
    for a, b in zip(path[:-1], path[1:]):
        total += Mat[a][b]
    total += Mat[path[-1]][0]
    return total


def initial():
    """Build a random initial population of ``LEN`` tours, sorted by cost.

    Each individual is a random permutation of the points 1..N-1 (point 0 is
    the fixed start). Reads the module-level ``N`` and ``LEN``.

    Returns:
        ``(packValue, pack)``: a 1-D numpy array of tour costs in ascending
        order and a ``(LEN, N - 1)`` numpy array of the matching tours.
    """
    global N
    init = list(range(1, N))
    pack = [None] * LEN
    packValue = [0] * LEN
    for i in range(LEN):
        random.shuffle(init)
        # Copy: ``init`` is reshuffled in place on the next iteration.
        pack[i] = init.copy()
        packValue[i] = calpathValue(pack[i])
    indexes = np.argsort(packValue)
    pack = np.array(pack)[indexes]
    packValue = np.sort(packValue)
    return packValue, pack


def preserve(i):
    """Copy individual ``pack[i]`` and its cost into the candidate pool.

    Every current individual goes into the pool, so the best tour found so
    far always survives the next selection.

    Args:
        i: Index of the individual in the current population ``pack``.
    """
    global tempPack, tempPackValue, pack, packValue, tempIndex
    tempPackValue[tempIndex] = packValue[i]
    tempPack[tempIndex] = pack[i].copy()
    tempIndex += 1


def select():
    """Truncation selection over the candidate pool.

    Ranks the first ``tempIndex`` candidates in ``tempPack`` by their cost in
    ``tempPackValue`` and makes the ``LEN`` cheapest the new population:
    ``pack`` becomes a numpy array of tours and ``packValue`` the matching
    costs, both ordered cheapest first.
    """
    global N, pack, tempPack, tempPackValue, tempIndex, LEN, packValue

    pool_costs = tempPackValue[:tempIndex]
    ranking = np.argsort(pool_costs)
    survivors = ranking[:LEN]

    pack = np.array(tempPack[:tempIndex])[survivors]
    packValue = np.sort(pool_costs)[:LEN]


def crossover(i, j):
    """Order-based crossover of individuals ``pack[i]`` and ``pack[j]``.

    A random sorted subset of tour positions is drawn. For each child, the
    cities that the *other* parent holds at those positions are located in
    this parent's tour and overwritten, left to right, in the other parent's
    order, so the child stays a valid permutation. The two children and their
    costs are stored in ``tempPack`` / ``tempPackValue`` at ``tempIndex`` and
    ``tempIndex + 1``, and ``tempIndex`` advances by 2.

    Args:
        i: Index of the first parent in ``pack``.
        j: Index of the second parent in ``pack``.
    """
    global N, pack, tempPack, tempPackValue, tempIndex
    times = random.randint(1, N - 2)
    # A tour holds the N - 1 cities 1..N-1, so positions range over 0..N-2.
    indexes = sorted(random.sample(range(N - 1), times))

    for base, donor in ((pack[i], pack[j]), (pack[j], pack[i])):
        child = base.copy()
        donor_genes = donor[indexes]
        wanted = set(donor_genes.tolist())
        count = 0
        for v in range(N - 1):
            if count >= times:
                break
            if child[v] in wanted:
                child[v] = donor_genes[count]
                count += 1
        tempPack[tempIndex] = child
        tempPackValue[tempIndex] = calpathValue(child)
        tempIndex += 1


def mutation(i):
    """Scramble mutation of individual ``pack[i]``.

    Picks a random set of tour positions and permutes the cities at those
    positions among themselves. The mutant and its cost are stored in the
    pool at ``tempIndex``, which then advances by 1.

    Args:
        i: Index of the individual in the current population ``pack``.
    """
    global N, pack, tempPack, tempPackValue, tempIndex
    times = random.randint(1, N - 2)
    # A tour holds the N - 1 cities 1..N-1, so positions range over 0..N-2.
    positions = random.sample(range(N - 1), times)
    targets = positions.copy()
    random.shuffle(targets)

    parent = pack[i]
    child = parent.copy()
    for dst, src in zip(targets, positions):
        child[dst] = parent[src]
    tempPack[tempIndex] = child
    tempPackValue[tempIndex] = calpathValue(child)
    tempIndex += 1


if __name__ == '__main__':
    N, Mat = read_data()

    LEN = 25  # population size
    pc, pm = 0.7, 0.97  # crossover / mutation probabilities
    NOTMORETHANstayINGV = 10  # stop after this many generations without improvement

    packValue, pack = initial()

    # Worst-case pool size per generation: LEN preserved individuals, LEN
    # mutants and two children for each of the LEN * (LEN - 1) / 2 pairs,
    # i.e. LEN * (LEN + 1). Sizing for it means no bounds check is needed
    # (LEN * LEN was too small and could raise IndexError).
    tempLEN = LEN * (LEN + 1)
    # Placeholder slots; select() only reads slots filled this generation.
    tempPack = [None] * tempLEN
    tempPackValue = [0] * tempLEN

    tempIndex = 0

    global_Value = packValue[0]
    stayinGV = 0

    while True:
        tempIndex = 0
        for i in range(LEN):
            preserve(i)
            if random.random() < pm:
                mutation(i)
            for j in range(i + 1, LEN):
                if random.random() < pc:
                    crossover(i, j)
        select()
        if packValue[0] < global_Value:
            global_Value = packValue[0]
            stayinGV = 0
        elif packValue[0] == global_Value:
            stayinGV += 1
        else:
            # Every current individual is preserved into the pool, so the
            # best cost after select() can never get worse.
            print("Something wrong")
            break
        if stayinGV == NOTMORETHANstayINGV:
            break

    print(global_Value)
    # The tour starts and ends at point 0.
    print('-->'.join(str(city) for city in [0, *pack[0], 0]))
所用的数据(保存为 data.txt):
第一行是点的个数 N = 10,其后 10 行是这 10 个点之间的距离矩阵。
10
0 58 82 89 17 50 26 48 70 19
58 0 74 46 70 2 70 49 87 60
82 74 0 58 76 98 37 97 34 67
89 46 58 0 15 17 28 69 46 79
17 70 76 15 0 98 60 69 97 89
50 2 98 17 98 0 81 14 43 47
26 70 37 28 60 81 0 43 73 56
48 49 97 69 69 14 43 0 39 0
70 87 34 46 97 43 73 39 0 53
19 60 67 79 89 47 56 0 53 0
随机生成数据并展示动图

黑色的点是起点。
我在知乎上看到过类似的动图,于是试着修改了自己的代码,也实现了这个效果。生成 GIF 的部分基于我之前写的代码。


import numpy as np
import random
import matplotlib.pyplot as plt
import os
import shutil
import imageio

def create_data(N, xu=10, yu=10, xd=-10, yd=-10):
    """Generate ``N`` random points and their Euclidean distance matrix.

    Points are drawn uniformly from the rectangle ``[xd, xu) x [yd, yu)``.

    Args:
        N: Number of points.
        xu, yu: Upper bounds for x and y.
        xd, yd: Lower bounds for x and y.

    Returns:
        ``(points, Mat)``: a list of ``N`` ``(x, y)`` tuples and a symmetric
        ``N x N`` numpy array of pairwise distances with a zero diagonal.
    """
    points = [
        (random.random() * (xu - xd) + xd, random.random() * (yu - yd) + yd)
        for _ in range(N)
    ]
    Mat = np.zeros((N, N))
    for i in range(N):
        for j in range(i + 1, N):
            dv = np.hypot(points[i][0] - points[j][0], points[i][1] - points[j][1])
            Mat[i][j] = Mat[j][i] = dv
    return points, Mat


def calpathValue(path):
    """Return the length of the closed tour 0 -> path[0] -> ... -> path[-1] -> 0.

    ``path`` is the visiting order of the points other than 0; point 0 is the
    fixed start and end. Distances come from the module-level matrix ``Mat``.
    """
    global Mat
    # Leg out of the start point, every consecutive pair, then back home.
    total = Mat[0][path[0]]
    for a, b in zip(path[:-1], path[1:]):
        total += Mat[a][b]
    total += Mat[path[-1]][0]
    return total


def initial():
    """Build a random initial population of ``LEN`` tours, sorted by cost.

    Each individual is a random permutation of the points 1..N-1 (point 0 is
    the fixed start). Reads the module-level ``N`` and ``LEN``.

    Returns:
        ``(packValue, pack)``: a 1-D numpy array of tour costs in ascending
        order and a ``(LEN, N - 1)`` numpy array of the matching tours.
    """
    global N
    init = list(range(1, N))
    pack = [None] * LEN
    packValue = [0] * LEN
    for i in range(LEN):
        random.shuffle(init)
        # Copy: ``init`` is reshuffled in place on the next iteration.
        pack[i] = init.copy()
        packValue[i] = calpathValue(pack[i])
    indexes = np.argsort(packValue)
    pack = np.array(pack)[indexes]
    packValue = np.sort(packValue)
    return packValue, pack


def preserve(i):
    """Copy individual ``pack[i]`` and its cost into the candidate pool.

    Every current individual goes into the pool, so the best tour found so
    far always survives the next selection.

    Args:
        i: Index of the individual in the current population ``pack``.
    """
    global tempPack, tempPackValue, pack, packValue, tempIndex
    tempPackValue[tempIndex] = packValue[i]
    tempPack[tempIndex] = pack[i].copy()
    tempIndex += 1


def select():
    """Truncation selection over the candidate pool.

    Ranks the first ``tempIndex`` candidates in ``tempPack`` by their cost in
    ``tempPackValue`` and makes the ``LEN`` cheapest the new population:
    ``pack`` becomes a numpy array of tours and ``packValue`` the matching
    costs, both ordered cheapest first.
    """
    global N, pack, tempPack, tempPackValue, tempIndex, LEN, packValue

    pool_costs = tempPackValue[:tempIndex]
    ranking = np.argsort(pool_costs)
    survivors = ranking[:LEN]

    pack = np.array(tempPack[:tempIndex])[survivors]
    packValue = np.sort(pool_costs)[:LEN]


def crossover(i, j):
    """Order-based crossover of individuals ``pack[i]`` and ``pack[j]``.

    A random sorted subset of tour positions is drawn. For each child, the
    cities that the *other* parent holds at those positions are located in
    this parent's tour and overwritten, left to right, in the other parent's
    order, so the child stays a valid permutation. The two children and their
    costs are stored in ``tempPack`` / ``tempPackValue`` at ``tempIndex`` and
    ``tempIndex + 1``, and ``tempIndex`` advances by 2.

    Args:
        i: Index of the first parent in ``pack``.
        j: Index of the second parent in ``pack``.
    """
    global N, pack, tempPack, tempPackValue, tempIndex
    times = random.randint(1, N - 2)
    # A tour holds the N - 1 cities 1..N-1, so positions range over 0..N-2.
    indexes = sorted(random.sample(range(N - 1), times))

    for base, donor in ((pack[i], pack[j]), (pack[j], pack[i])):
        child = base.copy()
        donor_genes = donor[indexes]
        wanted = set(donor_genes.tolist())
        count = 0
        for v in range(N - 1):
            if count >= times:
                break
            if child[v] in wanted:
                child[v] = donor_genes[count]
                count += 1
        tempPack[tempIndex] = child
        tempPackValue[tempIndex] = calpathValue(child)
        tempIndex += 1


def mutation(i):
    """Scramble mutation of individual ``pack[i]``.

    Picks a random set of tour positions and permutes the cities at those
    positions among themselves. The mutant and its cost are stored in the
    pool at ``tempIndex``, which then advances by 1.

    Args:
        i: Index of the individual in the current population ``pack``.
    """
    global N, pack, tempPack, tempPackValue, tempIndex
    times = random.randint(1, N - 2)
    # A tour holds the N - 1 cities 1..N-1, so positions range over 0..N-2.
    positions = random.sample(range(N - 1), times)
    targets = positions.copy()
    random.shuffle(targets)

    parent = pack[i]
    child = parent.copy()
    for dst, src in zip(targets, positions):
        child[dst] = parent[src]
    tempPack[tempIndex] = child
    tempPackValue[tempIndex] = calpathValue(child)
    tempIndex += 1


def draw(path, pv):
    """Plot the closed tour ``path`` over all points and save it as a PNG frame.

    ``path`` is the visiting order of the points other than 0; point 0 (drawn
    in black) is the start and end. The frame is written to
    ``PNGFILE/<TIMESIT>.png``, its path is appended to ``PNGLIST``, and
    ``TIMESIT`` is incremented.

    Args:
        path: Visiting order of the points other than 0.
        pv: Tour length, shown in the plot title.
    """
    global points, N, TIMESIT, PNGFILE, PNGLIST
    plt.cla()
    plt.title('cross=%.4f' % pv)
    xs = np.array([p[0] for p in points])
    ys = np.array([p[1] for p in points])
    plt.scatter(xs, ys, color='b')
    # Close the loop: 0 -> path[0] -> ... -> path[-1] -> 0.
    tour = [0, *path, 0]
    for a, b in zip(tour[:-1], tour[1:]):
        plt.plot(xs[[a, b]], ys[[a, b]], color='r')
    plt.scatter(xs[0], ys[0], color='k', linewidth=10)
    frame = '%s/%d.png' % (PNGFILE, TIMESIT)
    plt.savefig(frame)
    PNGLIST.append(frame)
    TIMESIT += 1


if __name__ == '__main__':
    TIMESIT = 0  # number of the next PNG frame written by draw()
    PNGFILE = './png/'  # scratch directory holding the frames
    PNGLIST = []  # frame paths, in drawing order
    # Start from an empty frame directory.
    if os.path.exists(PNGFILE):
        shutil.rmtree(PNGFILE)
    os.mkdir(PNGFILE)

    N = 20  # number of random points
    points, Mat = create_data(N)
    LEN = 40  # population size
    pc, pm = 0.7, 0.97  # crossover / mutation probabilities
    NOTMORETHANstayINGV = 30  # stop after this many generations without improvement

    packValue, pack = initial()

    # Worst-case pool size per generation: LEN preserved individuals, LEN
    # mutants and two children for each of the LEN * (LEN - 1) / 2 pairs,
    # i.e. LEN * (LEN + 1). Sizing for it means no bounds check is needed
    # (LEN * LEN was too small and could raise IndexError).
    tempLEN = LEN * (LEN + 1)
    # Placeholder slots; select() only reads slots filled this generation.
    tempPack = [None] * tempLEN
    tempPackValue = [0] * tempLEN

    tempIndex = 0

    global_Value = packValue[0]
    draw(pack[0], global_Value)
    stayinGV = 0

    while True:
        tempIndex = 0
        for i in range(LEN):
            preserve(i)
            if random.random() < pm:
                mutation(i)
            for j in range(i + 1, LEN):
                if random.random() < pc:
                    crossover(i, j)
        select()
        if packValue[0] < global_Value:
            global_Value = packValue[0]
            draw(pack[0], global_Value)
            stayinGV = 0
        elif packValue[0] == global_Value:
            stayinGV += 1
        else:
            # Every current individual is preserved into the pool, so the
            # best cost after select() can never get worse.
            print("Something wrong")
            break
        if stayinGV == NOTMORETHANstayINGV:
            break

    print(global_Value)
    # The tour starts and ends at point 0.
    print('-->'.join(str(city) for city in [0, *pack[0], 0]))

    generated_images = [imageio.imread(png_path) for png_path in PNGLIST]
    shutil.rmtree(PNGFILE)  # optional: remove the intermediate frames
    # Repeat the final frame so the best tour lingers at the end of the GIF.
    generated_images = generated_images + [generated_images[-1]] * 5
    imageio.mimsave('TSP-GA.gif', generated_images, 'GIF', duration=0.5)
---------------------
【转载】仅作分享,侵删
作者:肥宅_Sean
原文:https://blog.csdn.net/a19990412/article/details/84978612


1 个回复

倒序浏览
回复 使用道具 举报
您需要登录后才可以回帖 登录 | 加入黑马