Add File
This commit is contained in:
128
src/main/java/org/dromara/easyai/tools/Knn.java
Normal file
128
src/main/java/org/dromara/easyai/tools/Knn.java
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
package org.dromara.easyai.tools;
|
||||||
|
|
||||||
|
import org.dromara.easyai.matrixTools.Matrix;
|
||||||
|
import org.dromara.easyai.matrixTools.MatrixOperation;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
public class Knn extends MatrixOperation {//KNN分类器
|
||||||
|
private Map<Integer, List<Matrix>> featureMap = new HashMap<>();
|
||||||
|
private int length;//向量长度(需要返回)
|
||||||
|
private final int nub;//选择几个人投票
|
||||||
|
|
||||||
|
public Knn(int nub) {
|
||||||
|
this.nub = nub;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setFeatureMap(Map<Integer, List<Matrix>> featureMap) {
|
||||||
|
this.featureMap = featureMap;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<Integer, List<Matrix>> getFeatureMap() {
|
||||||
|
return featureMap;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void removeType(int type) {
|
||||||
|
featureMap.remove(type);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void revoke(int type, int nub) {//撤销一个类别最新的
|
||||||
|
List<Matrix> list = featureMap.get(type);
|
||||||
|
for (int i = 0; i < nub; i++) {
|
||||||
|
list.remove(list.size() - 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getNub(int type) {//获取该分类模型的数量
|
||||||
|
int nub = 0;
|
||||||
|
List<Matrix> list = featureMap.get(type);
|
||||||
|
if (list != null) {
|
||||||
|
nub = list.size();
|
||||||
|
}
|
||||||
|
return nub;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void insertMatrix(Matrix vector, int tag) throws Exception {
|
||||||
|
if (vector.isVector() && vector.isRowVector()) {
|
||||||
|
if (featureMap.size() == 0) {
|
||||||
|
List<Matrix> list = new ArrayList<>();
|
||||||
|
list.add(vector);
|
||||||
|
featureMap.put(tag, list);
|
||||||
|
length = vector.getY();
|
||||||
|
} else {
|
||||||
|
if (length == vector.getY()) {
|
||||||
|
if (featureMap.containsKey(tag)) {
|
||||||
|
featureMap.get(tag).add(vector);
|
||||||
|
} else {
|
||||||
|
List<Matrix> list = new ArrayList<>();
|
||||||
|
list.add(vector);
|
||||||
|
featureMap.put(tag, list);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw new Exception("vector length is different");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw new Exception("this matrix is not vector or rowVector");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void compare(float[] values, int[] types, float value, int type) {
|
||||||
|
for (int i = 0; i < values.length; i++) {
|
||||||
|
float val = values[i];
|
||||||
|
if (val < 0) {
|
||||||
|
values[i] = value;
|
||||||
|
types[i] = type;
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
if (value < val) {
|
||||||
|
for (int j = values.length - 2; j >= i; j--) {
|
||||||
|
values[j + 1] = values[j];
|
||||||
|
types[j + 1] = types[j];
|
||||||
|
}
|
||||||
|
values[i] = value;
|
||||||
|
types[i] = type;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getType(Matrix vector) throws Exception {//识别分类
|
||||||
|
int ty = 0;
|
||||||
|
float[] dists = new float[nub];
|
||||||
|
// System.out.println("测试:" + vector.getString());
|
||||||
|
int[] types = new int[nub];
|
||||||
|
for (int i = 0; i < nub; i++) {
|
||||||
|
dists[i] = -1;
|
||||||
|
}
|
||||||
|
for (Map.Entry<Integer, List<Matrix>> entry : featureMap.entrySet()) {
|
||||||
|
int type = entry.getKey();
|
||||||
|
List<Matrix> matrices = entry.getValue();
|
||||||
|
for (Matrix matrix : matrices) {
|
||||||
|
float dist = getEDist(matrix, vector);
|
||||||
|
compare(dists, types, dist, type);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
System.out.println(Arrays.toString(types));
|
||||||
|
Map<Integer, Integer> map = new HashMap<>();
|
||||||
|
for (int i = 0; i < nub; i++) {
|
||||||
|
int type = types[i];
|
||||||
|
if (map.containsKey(type)) {
|
||||||
|
map.put(type, map.get(type) + 1);
|
||||||
|
} else {
|
||||||
|
map.put(type, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int max = 0;
|
||||||
|
for (Map.Entry<Integer, Integer> entry : map.entrySet()) {
|
||||||
|
int value = entry.getValue();
|
||||||
|
int type = entry.getKey();
|
||||||
|
if (value > max) {
|
||||||
|
ty = type;
|
||||||
|
max = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ty;
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user