【机器学习】 [代码篇] 30. KNN - sklearn 以及 自定义KNN 的实现

news/2025/2/27 6:37:43

KNN - sklearn 以及 自定义KNN 的实现

  • 前言
  • Github 链接
  • 使用SKlearn 库完成KNN的训练以及预测
    • 1. 导入需要的库
    • 2. 加载数据
      • 2.1. 输出数据信息
    • 3. 分割训练集和测试集
    • 4. 可视化
    • 5. 创建模型并预测
  • 2. 自定义KNN模型并预测

前言

前面写完了理论篇,接下来补充代码。

机器学习使用sklearn会很简单,因此重点看下如何自定义实现。

KNN理论链接跳转

Github 链接

Github链接跳转

使用SKlearn 库完成KNN的训练以及预测

1. 导入需要的库

from IPython.display import set_matplotlib_formats, display
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

2. 加载数据

from sklearn.datasets import load_iris
iris_dataset = load_iris()

2.1. 输出数据信息

print("Keys of iris_dataset:\n", iris_dataset.keys())
print(iris_dataset['DESCR'][:193] + "\n...")
print("Target names:", iris_dataset['target_names'])
print("Feature names:\n", iris_dataset['feature_names'])

3. 分割训练集和测试集

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
    iris_dataset['data'], iris_dataset['target'], random_state=0)

4. 可视化

# label the columns using the strings in iris_dataset.feature_names
iris_dataframe = pd.DataFrame(X_train, columns=iris_dataset.feature_names)
# Create a scatter matrix from the dataframe, color by y_train
pd.plotting.scatter_matrix(iris_dataframe, c=y_train, figsize=(16, 16),
                           marker='o', hist_kwds={'bins': 20}, s=60, alpha=.8);

在这里插入图片描述

5. 创建模型并预测

from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score
scaler = MinMaxScaler()#creating an object
scaler.fit(X_train)#calculate min and max value of the training data

X_train_norm = scaler.transform(X_train) #apply normalisation to the training set
X_test_norm = scaler.transform(X_test)

knn = KNeighborsClassifier(n_neighbors=40)
knn.fit(X_train_norm, y_train)
y_pred = knn.predict(X_test_norm) 
print("Accuracy on test set: {:.5f}".format(accuracy_score(y_pred, y_test)))

2. 自定义KNN模型并预测

import numpy as np
from collections import Counter

class KNN:
    def __init__(self, k=3, distance_metric='euclidean'):
        self.k = k
        self.distance_metric = distance_metric

    # define fit function
    def fit(self, X_train, y_train):
        self.X_train = np.array(X_train)
        self.y_train = np.array(y_train)

    # calculate distance
    def _compute_distance(self, x1, x2):
        if self.distance_metric == 'euclidean':
            return np.sqrt(np.sum((x1 - x2) ** 2))
        elif self.distance_metric == 'manhattan':
            return np.sum(np.abs(x1 - x2)) 
        else:
            raise ValueError("Unsupported distance metric")

    def predict(self, X_test):
        X_test = np.array(X_test)
        predictions = []

        for x in X_test:
            distances = [self._compute_distance(x, x_train) for x_train in self.X_train] # calculate all distance
            k_indices = np.argsort(distances)[:self.k]  # find the neaset  k points 
            k_nearest_labels = [self.y_train[i] for i in k_indices]  
            most_common = Counter(k_nearest_labels).most_common(1)[0][0]  # get the most common class
            predictions.append(most_common)

        return np.array(predictions)
    
    def score(self, X_test, y_test):
        y_pred = self.predict(X_test)
        return np.mean(y_pred == np.array(y_test))  # score

knn = KNN(k=40)
knn.fit(X_train_norm, y_train)
predictions = knn.predict(X_test_norm)
accuracy = knn.score(X_test_norm, y_test)

print(f"Predictions: {predictions}")
print(f"Accuracy: {accuracy}")

http://www.niftyadmin.cn/n/5869669.html

相关文章

应对现代生活的健康养生指南

在科技飞速发展的现代社会,人们的生活方式发生了巨大改变,随之而来的是一系列健康问题。快节奏的生活、高强度的工作以及电子产品的过度使用,让我们的身体承受着前所未有的压力。因此,掌握正确的健康养生方法迫在眉睫。 针对久坐不…

PCL源码分析:曲面法向量采样

文章目录 一、简介二、源码分析三、实现效果参考资料一、简介 曲面法向量点云采样,整个过程如下所述: 1、空间划分:使用递归方法将点云划分为更小的区域, 每次划分选择一个维度(X、Y 或 Z),将点云分为两部分,直到划分区域内的点少于我们指定的数量,开始进行区域随机采…

传递指针给函数的用法

在 C 语言中,将指针传递给函数是一种常见且重要的编程技巧,它可以让函数直接操作调用者提供的内存区域,实现数据的修改、避免数据的复制开销等。下面为你提供几个不同场景下传递指针给函数的例子。 1. 修改调用者的变量值 通过传递变量的指针…

归纳总结一下Tensorflow、PaddlePaddle、Pytorch构建神经网络基本流程,以及使用NCNN推理的流程

使用Tensorflow构建神经网络,这里使用keras API,采用Sequential方式快速构建 import tensorflow as tf from tensorflow.keras.datasets import mnist from tensorflow.keras.utils import to_categorical# 加载数据集 (train_images, train_labels), (t…

9、什么是野指针?如何避免?【中高频】

(1)什么是野指针 野指针是 一种未被初始化的指针,通常会指向一个随机的内存地址。这个地址不可预测的,所以可能会导致 程序和数据出现错误 (2)在什么情况下会产生野指针? 初始化指针时&#xf…

【K8S】Kubernetes 基本架构、节点类型及运行流程详解(附架构图及流程图)

Kubernetes 架构 k8s 集群 多个 master node 多个 work nodeMaster 节点(主节点):负责集群的管理任务,包括调度容器、维护集群状态、监控集群、管理服务发现等。Worker 节点(工作节点):实际运…

使用ZFile打造属于自己的私有云系统结合内网穿透实现安全远程访问

文章目录 前言1.关于ZFile2.本地部署ZFile3.ZFile本地访问测试4.ZFile的配置5.cpolar内网穿透工具安装6.创建远程连接公网地址7.固定ZFile公网地址 前言 在数字化的今天,我们每个人都是信息的小能手。无论是职场高手、摄影达人还是学习狂人,每天都在创造…

Google sheet 复制excel内容自动合并单元格问题

解决路径:file-import-upload 这样上传本地的excel源文件,就没有这个问题了