当前位置:Gxlcms > 数据库问题 > Python实现DBScan

Python实现DBScan

时间:2021-07-01 10:21:17 帮助过:89人阅读

from:https://www.cnblogs.com/wsine/p/5180778.html

运行环境

  • Pyhton3
  • numpy(科学计算包)
  • matplotlib(画图所需,不画图可不必)

计算过程

st=>start: 开始
e=>end: 结束
op1=>operation: 读入数据
cond=>condition: 是否还有未分类数据
op2=>operation: 找一未分类点扩散
op3=>operation: 输出结果

st->op1->op2->cond
cond(yes)->op2
cond(no)->op3->e

输入样例

/* 788points.txt */
15.55,28.65
14.9,27.55
14.45,28.35
14.15,28.8
13.75,28.05
13.35,28.45
13,29.15
13.45,27.5
13.6,26.5
12.8,27.35
12.4,27.85
12.3,28.4
12.2,28.65
13.4,25.1
12.95,25.95

788points.txt完整文件:下载

代码实现

# -*- coding: utf-8 -*-
__author__ = ‘Wsine‘

import numpy as np
import matplotlib.pyplot as plt
import math
import time

UNCLASSIFIED = False
NOISE = 0

def loadDataSet(fileName, splitChar=‘\t‘):
    """
    输入:文件名
    输出:数据集
    描述:从文件读入数据集
    """
    dataSet = []
    with open(fileName) as fr:
        for line in fr.readlines():
            curline = line.strip().split(splitChar)
            fltline = list(map(float, curline))
            dataSet.append(fltline)
    return dataSet

def dist(a, b):
    """
    输入:向量A, 向量B
    输出:两个向量的欧式距离
    """
    return math.sqrt(np.power(a - b, 2).sum())

def eps_neighbor(a, b, eps):
    """
    输入:向量A, 向量B
    输出:是否在eps范围内
    """
    return dist(a, b) < eps

def region_query(data, pointId, eps):
    """
    输入:数据集, 查询点id, 半径大小
    输出:在eps范围内的点的id
    """
    nPoints = data.shape[1]
    seeds = []
    for i in range(nPoints):
        if eps_neighbor(data[:, pointId], data[:, i], eps):
            seeds.append(i)
    return seeds

def expand_cluster(data, clusterResult, pointId, clusterId, eps, minPts):
    """
    输入:数据集, 分类结果, 待分类点id, 簇id, 半径大小, 最小点个数
    输出:能否成功分类
    """
    seeds = region_query(data, pointId, eps)
    if len(seeds) < minPts: # 不满足minPts条件的为噪声点
        clusterResult[pointId] = NOISE
        return False
    else:
        clusterResult[pointId] = clusterId # 划分到该簇
        for seedId in seeds:
            clusterResult[seedId] = clusterId

        while len(seeds) > 0: # 持续扩张
            currentPoint = seeds[0]
            queryResults = region_query(data, currentPoint, eps)
            if len(queryResults) >= minPts:
                for i in range(len(queryResults)):
                    resultPoint = queryResults[i]
                    if clusterResult[resultPoint] == UNCLASSIFIED:
                        seeds.append(resultPoint)
                        clusterResult[resultPoint] = clusterId
                    elif clusterResult[resultPoint] == NOISE:
                        clusterResult[resultPoint] = clusterId
            seeds = seeds[1:]
        return True

def dbscan(data, eps, minPts):
    """
    输入:数据集, 半径大小, 最小点个数
    输出:分类簇id
    """
    clusterId = 1
    nPoints = data.shape[1]
    clusterResult = [UNCLASSIFIED] * nPoints
    for pointId in range(nPoints):
        point = data[:, pointId]
        if clusterResult[pointId] == UNCLASSIFIED:
            if expand_cluster(data, clusterResult, pointId, clusterId, eps, minPts):
                clusterId = clusterId + 1
    return clusterResult, clusterId - 1

def plotFeature(data, clusters, clusterNum):
    nPoints = data.shape[1]
    matClusters = np.mat(clusters).transpose()
    fig = plt.figure()
    scatterColors = [‘black‘, ‘blue‘, ‘green‘, ‘yellow‘, ‘red‘, ‘purple‘, ‘orange‘, ‘brown‘]
    ax = fig.add_subplot(111)
    for i in range(clusterNum + 1):
        colorSytle = scatterColors[i % len(scatterColors)]
        subCluster = data[:, np.nonzero(matClusters[:, 0].A == i)]
        ax.scatter(subCluster[0, :].flatten().A[0], subCluster[1, :].flatten().A[0], c=colorSytle, s=50)

def main():
    dataSet = loadDataSet(‘788points.txt‘, splitChar=‘,‘)
    dataSet = np.mat(dataSet).transpose()
    # print(dataSet)
    clusters, clusterNum = dbscan(dataSet, 2, 15)
    print("cluster Numbers = ", clusterNum)
    # print(clusters)
    plotFeature(dataSet, clusters, clusterNum)

if __name__ == ‘__main__‘:
    start = time.clock()
    main()
    end = time.clock()
    print(‘finish all in %s‘ % str(end - start))
    plt.show()

输出样例

cluster Numbers =  7
finish all in 32.712135628590794

 技术分享图片

 

Python实现DBScan

标签:ber   html   image   距离   scan   baidu   utf-8   region   下载   

人气教程排行