Machine Learning in Action, Chapter 9: Tree Regression
Published: 2019-06-12


The source code as printed contains two errors; I resolved both after searching online.
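For reference, the second fix is visible in the listing below, where the book's original line is kept as a comment. A rough sketch of both fixes follows; the broken forms are reconstructed from the first edition's regTrees.py listing and should be treated as assumptions:

# First fix, in binSplitDataSet: the book's listing (assumed) tacks an extra [0]
# onto each split result, which fails with current NumPy matrix indexing.
# mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :][0]   # original (assumed), broken
mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]        # fixed

# Second fix, in chooseBestSplit: a matrix column is not hashable, so set() rejects it;
# convert the column to a plain Python list first.
# for splitVal in set(dataSet[:, featIndex]):                     # original, broken
for splitVal in set(dataSet[:, featIndex].T.tolist()[0]):         # fixed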

 

from numpy import *
import matplotlib.pyplot as plt

def loadDataSet(fileName):
    dataSet = []
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = map(float, curLine)
        dataSet.append(list(fltLine))
    return dataSet

def binSplitDataSet(dataSet, feature, value):
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]    ## first fix
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1

def regLeaf(dataSet):  # returns the value used for each leaf
    return mean(dataSet[:, -1])

def regErr(dataSet):
    return var(dataSet[:, -1]) * shape(dataSet)[0]

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    tolS = ops[0]
    tolN = ops[1]
    # if all the target variables are the same value: quit and return value
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:  # exit cond 1
        return None, leafType(dataSet)
    m, n = shape(dataSet)
    # the choice of the best feature is driven by reduction in RSS error from the mean
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n - 1):
        # for splitVal in set(dataSet[:, featIndex]):    ## second fix (original line)
        for splitVal in set((dataSet[:, featIndex].T.tolist())[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # if the decrease (S - bestS) is less than a threshold, don't do the split
    if (S - bestS) < tolS:
        return None, leafType(dataSet)  # exit cond 2
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  # exit cond 3
        return None, leafType(dataSet)
    return bestIndex, bestValue  # returns the best feature to split on and the value used for that split

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    # assume dataSet is a NumPy mat so we can do array filtering
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)  # choose the best split
    if feat is None: return val  # if the splitting hit a stop condition, return the leaf value
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

############################################
# example: construct a simple regression tree
myDat1 = loadDataSet(r'train.txt')
myMat = mat(myDat1)
mytree = createTree(myMat)
mytree = createTree(myMat, ops=(1, 4))
print(mytree)
x = []; y = []
for a in myDat1:
    x.append(a[:][-2])
    y.append(a[:][-1])
plt.scatter(x, y)
plt.show()

#############################
# cut some branches (post-pruning)
def isTree(obj):
    return (type(obj).__name__ == 'dict')

def getMean(tree):
    if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0

def prune(tree, testData):
    if shape(testData)[0] == 0: return getMean(tree)  # if we have no test data, collapse the tree
    if (isTree(tree['right']) or isTree(tree['left'])):  # if either branch is a subtree, split the test data for recursion
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
    # if they are now both leaves, see if we can merge them
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
            sum(power(rSet[:, -1] - tree['right'], 2))
        treeMean = (tree['left'] + tree['right']) / 2.0
        errorMerge = sum(power(testData[:, -1] - treeMean, 2))
        if errorMerge < errorNoMerge:
            print("merging")
            return treeMean
        else: return tree
    else: return tree

mydatTest = loadDataSet('test.txt')
mymatTest = mat(mydatTest)
cut_tree = prune(mytree, mymatTest)
print(cut_tree)

################################
# model tree
def linearSolve(dataSet):  # helper function used in two places
    m, n = shape(dataSet)
    X = mat(ones((m, n))); Y = mat(ones((m, 1)))  # create a copy of the data with 1 in the 0th position
    X[:, 1:n] = dataSet[:, 0:n-1]; Y = dataSet[:, -1]  # and strip out Y
    xTx = X.T * X
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse; try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y

def modelLeaf(dataSet):
    ws, X, Y = linearSolve(dataSet)
    return ws

def modelErr(dataSet):
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat, 2))

def regTreeEval(model, inDat):
    return float(model)

def modelTreeEval(model, inDat):
    n = shape(inDat)[1]
    X = mat(ones((1, n + 1)))
    X[:, 1:n+1] = inDat
    return float(X * model)

def treeForeCast(tree, inData, modelEval=regTreeEval):
    if not isTree(tree): return modelEval(tree, inData)
    if inData[tree['spInd']] > tree['spVal']:
        if isTree(tree['left']):
            return treeForeCast(tree['left'], inData, modelEval)
        else:
            return modelEval(tree['left'], inData)
    else:
        if isTree(tree['right']):
            return treeForeCast(tree['right'], inData, modelEval)
        else:
            return modelEval(tree['right'], inData)

def createForeCast(tree, testData, modelEval=regTreeEval):
    m = len(testData)
    yHat = mat(zeros((m, 1)))
    for i in range(m):
        yHat[i, 0] = treeForeCast(tree, mat(testData[i]), modelEval)
    return yHat

trainmat = mat(loadDataSet('train.txt'))
testdat = loadDataSet('test.txt')
testmat = mat(testdat)
print(testmat[:, 1])
mytree = createTree(trainmat, ops=(1, 4))
yHat = createForeCast(mytree, testmat[:, 0])
# print(yHat)
co1 = corrcoef(yHat, testmat[:, 1], rowvar=0)[0, 1]
print(co1)
# mytree = createTree(trainmat, modelLeaf, modelErr, (0, 20))
# yHat = createForeCast(mytree, testmat[:, 0], modelTreeEval)
# co2 = corrcoef(yHat, testmat[:, 1], rowvar=0)[0, 1]
# print(co2)
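A usage note: the commented-out lines at the end switch from constant leaves (regression tree) to linear-model leaves (model tree). A minimal sketch of running both variants and comparing their correlation with the test targets, assuming train.txt and test.txt are the same two-column files loaded above:

# Regression tree: each leaf stores a constant value.
regtree = createTree(trainmat, ops=(1, 4))
yHatReg = createForeCast(regtree, testmat[:, 0])
print(corrcoef(yHatReg, testmat[:, 1], rowvar=0)[0, 1])

# Model tree: each leaf stores a linear model (the uncommented form of the lines above).
modtree = createTree(trainmat, modelLeaf, modelErr, (0, 20))
yHatMod = createForeCast(modtree, testmat[:, 0], modelTreeEval)
print(corrcoef(yHatMod, testmat[:, 1], rowvar=0)[0, 1])

Whichever variant prints the higher correlation coefficient fits the test data better.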

 

Reposted from: https://www.cnblogs.com/heifengli/p/7642634.html
