模型整体架构设计思路
- 数据预处理模块:
- 图像归一化:将手写数字图像的尺寸统一,例如调整为固定大小(如28x28像素),并将像素值归一化到[0, 1]范围,以减少不同图像间因尺寸和亮度差异带来的影响。
- 灰度化:如果图像是彩色的,将其转换为灰度图像,简化数据同时保留关键特征,因为手写数字识别主要关注形状,颜色信息作用不大。
- 特征提取层:
- 卷积神经网络(CNN):可使用简单的卷积层和池化层组合。卷积层通过卷积核在图像上滑动,提取局部特征,不同的卷积核可以捕捉不同类型的特征,如边缘、角点等。池化层则用于降低数据维度,减少计算量,同时保留主要特征,例如使用最大池化或平均池化。
- 分类层:
- 全连接层:将经过特征提取的特征向量通过全连接层,将其映射到输出类别空间,输出每个类别的概率。例如对于手写数字识别,输出10个类别的概率,对应数字0 - 9。
- Softmax激活函数:在全连接层输出后使用Softmax函数,将输出值转换为概率分布,使得所有类别的概率之和为1,便于进行分类决策。
关键模块的Objective - C实现大致代码结构和技术选型
- 数据预处理模块:
- 技术选型:使用Core Image框架来处理图像。
- 代码结构示例:
#import <CoreImage/CoreImage.h>
// 图像归一化
- (CIImage *)normalizeImage:(CIImage *)image toSize:(CGSize)size {
CGAffineTransform transform = CGAffineTransformMakeScale(size.width / image.extent.size.width, size.height / image.extent.size.height);
CIImage *scaledImage = [image imageByApplyingTransform:transform];
CIImage *croppedImage = [scaledImage croppedToRect:CGRectMake(0, 0, size.width, size.height)];
// 像素值归一化
CIVector *extentVector = [CIVector vectorWithCGRect:croppedImage.extent];
CIFilter *normalizeFilter = [CIFilter filterWithName:@"CINormalize"
keysAndValues:kCIInputImageKey, croppedImage,
kCIInputExtentKey, extentVector, nil];
return [normalizeFilter outputImage];
}
// 灰度化
- (CIImage *)grayscaleImage:(CIImage *)image {
CIFilter *grayscaleFilter = [CIFilter filterWithName:@"CIColorControls"
keysAndValues:kCIInputImageKey, image,
@"inputSaturation", @0.0, nil];
return [grayscaleFilter outputImage];
}
- 特征提取层(以简单CNN为例,使用Accelerate框架辅助计算):
- 技术选型:结合Accelerate框架进行高效的矩阵运算。
- 代码结构示例:
#import <Accelerate/Accelerate.h>
// 简单卷积层示例(假设卷积核大小为3x3)
- (float *)convolutionLayer:(float *)inputImage width:(NSUInteger)width height:(NSUInteger)height {
float kernel[9] = {
-1, -1, -1,
-1, 8, -1,
-1, -1, -1
};
float *outputImage = (float *)malloc(width * height * sizeof(float));
vImage_Buffer inBuffer, outBuffer;
inBuffer.data = inputImage;
inBuffer.width = width;
inBuffer.height = height;
inBuffer.rowBytes = width * sizeof(float);
outBuffer.data = outputImage;
outBuffer.width = width;
outBuffer.height = height;
outBuffer.rowBytes = width * sizeof(float);
vImageConv_F32(&inBuffer, &outBuffer, NULL, 0, 0, kernel, 3, 3, 0, kvImageEdgeExtend);
return outputImage;
}
// 最大池化层示例(假设池化大小为2x2)
- (float *)maxPoolingLayer:(float *)inputImage width:(NSUInteger)width height:(NSUInteger)height {
NSUInteger newWidth = width / 2;
NSUInteger newHeight = height / 2;
float *outputImage = (float *)malloc(newWidth * newHeight * sizeof(float));
for (NSUInteger y = 0; y < newHeight; y++) {
for (NSUInteger x = 0; x < newWidth; x++) {
float maxValue = -FLT_MAX;
for (NSUInteger i = 0; i < 2; i++) {
for (NSUInteger j = 0; j < 2; j++) {
NSUInteger index = (y * 2 + i) * width + x * 2 + j;
maxValue = fmaxf(maxValue, inputImage[index]);
}
}
outputImage[y * newWidth + x] = maxValue;
}
}
return outputImage;
}
- 分类层:
- 技术选型:使用基本的矩阵运算和Softmax函数实现。
- 代码结构示例:
// 全连接层
- (float *)fullyConnectedLayer:(float *)inputFeatures numFeatures:(NSUInteger)numFeatures numClasses:(NSUInteger)numClasses {
float *weights = (float *)malloc(numFeatures * numClasses * sizeof(float));
// 初始化权重(这里简单示例,实际应随机初始化并训练)
for (NSUInteger i = 0; i < numFeatures * numClasses; i++) {
weights[i] = 1.0;
}
float *biases = (float *)malloc(numClasses * sizeof(float));
// 初始化偏置(简单示例)
for (NSUInteger i = 0; i < numClasses; i++) {
biases[i] = 0.0;
}
float *output = (float *)malloc(numClasses * sizeof(float));
for (NSUInteger c = 0; c < numClasses; c++) {
float sum = 0.0;
for (NSUInteger f = 0; f < numFeatures; f++) {
sum += inputFeatures[f] * weights[c * numFeatures + f];
}
output[c] = sum + biases[c];
}
free(weights);
free(biases);
return output;
}
// Softmax函数
- (float *)softmax:(float *)input numClasses:(NSUInteger)numClasses {
float maxValue = -FLT_MAX;
for (NSUInteger i = 0; i < numClasses; i++) {
maxValue = fmaxf(maxValue, input[i]);
}
float sum = 0.0;
float *output = (float *)malloc(numClasses * sizeof(float));
for (NSUInteger i = 0; i < numClasses; i++) {
output[i] = exp(input[i] - maxValue);
sum += output[i];
}
for (NSUInteger i = 0; i < numClasses; i++) {
output[i] /= sum;
}
return output;
}