This commit is contained in:
2025-09-04 14:08:51 +08:00
parent 6c00a7e351
commit b6d35219f5

View File

@@ -0,0 +1,578 @@
package org.dromara.easyai.matrixTools;
import org.dromara.easyai.entity.ThreeChannelMatrix;
import java.util.ArrayList;
import java.util.List;
/**
* 矩阵
**/
public class Matrix extends MatrixOperation {
private float[] matrix;//矩阵本体(列主序)
private int x;//矩阵的行数
private int y;//矩阵的列数
private boolean isRowVector = false;//是否是单行矩阵
private boolean isVector = false;//是否是向量
private boolean isZero = false;//是否是单元素矩阵
/**
* 获取Cuda列主序一维数组
*
* @return 获取Cuda列主序一维数组
*/
public float[] getCudaMatrix() {//获取cudaMatrix
return matrix;
}
public Float[] getMatrixModel() {
Float[] matrixModel = new Float[matrix.length];
for (int i = 0; i < matrix.length; i++) {
matrixModel[i] = matrix[i];
}
return matrixModel;
}
public void insertMatrixModel(Float[] matrixModel) {
matrix = new float[matrixModel.length];
for (int i = 0; i < matrix.length; i++) {
matrix[i] = matrixModel[i];
}
}
/**
* 注入Cuda一维数组(列主序)
*
* @param cudaMatrix 注入数组本体
* @param x 行数
* @param y 列数
*/
public void setCudaMatrix(float[] cudaMatrix, int x, int y) {
this.matrix = cudaMatrix;
this.x = x;
this.y = y;
}
/**
* 获取行数
*
* @return 获取行数
*/
public int getX() {//获取行数
return x;
}
/**
* 获取列数
*
* @return 获取列数
*/
public int getY() {//获取列数
return y;
}
/**
* 初始化矩阵
*
* @param x 行数
* @param y 列数
*/
public Matrix(int x, int y) {
matrix = new float[x * y];
this.x = x;
this.y = y;
setState(x, y);
}
/**
* 设置矩阵属性
*
* @param x 行数
* @param y 列数
*/
private void setState(int x, int y) {
if (x == 1 && y == 1) {
isZero = true;
isVector = true;
} else if (x == 1 || y == 1) {
isVector = true;
isRowVector = x == 1;
}
}
public float getSigma() throws Exception {
float sigma = 0;
for (int i = 0; i < x; i++) {
for (int j = 0; j < y; j++) {
sigma = sigma + getNumber(i, j);
}
}
return sigma;
}
public Matrix scale(boolean scaleX, float size) throws Exception {//缩放矩阵
float value;
if (!scaleX) {//将宽度等比缩放至指定尺寸
value = y / size;
} else {//将高度等比缩放至指定尺寸
value = x / size;
}
int narrowX = (int) (x / value);
int narrowY = (int) (y / value);
if (!scaleX) {
narrowY = (int) size;
} else {
narrowX = (int) size;
}
Matrix matrix = new Matrix(narrowX, narrowY);
for (int i = 0; i < narrowX; i++) {
for (int j = 0; j < narrowY; j++) {
int indexX = (int) (i * value);
int indexY = (int) (j * value);
matrix.setNub(i, j, getNumber(indexX, indexY));
}
}
return matrix;
}
/**
* 计算全矩阵元素平均值
*
* @return 返回当前矩阵全部元素的平均值
*/
public float getAVG() throws Exception {
float sigma = 0;
int s = x * y;
for (int i = 0; i < x; i++) {
for (int j = 0; j < y; j++) {
sigma = sigma + getNumber(i, j);
}
}
sigma = sigma / s;
return sigma;
}
public float[][] getMatrix() throws Exception {
float[][] matrix = new float[x][y];
for (int i = 0; i < x; i++) {
for (int j = 0; j < y; j++) {
matrix[i][j] = getNumber(i, j);
}
}
return matrix;
}
/**
* 是否为单行
*
* @return true表示此矩阵为一个单行矩阵
*/
public boolean isRowVector() {
return isRowVector;
}
/**
* 是否是一个向量矩阵
* 单行和单列矩阵都是向量矩阵
*
* @return true表示此矩阵为一个向量矩阵
*/
public boolean isVector() {
return isVector;
}
/**
* 是否是一个单元素矩阵
*
* @return true表示是里面只有一个元素
*/
public boolean isZero() {
return isZero;
}
/**
* 清除矩阵数据
**/
public void clear() {
matrix = new float[x * y];
}
/**
* 初始化矩阵
*
* @param x 行数
* @param y 列数
* @param matr 数据
* @throws Exception
*/
public Matrix(int x, int y, String matr) throws Exception {
matrix = new float[x * y];
this.x = x;
this.y = y;
setState(x, y);
setAll(matr);
}
class Coordinate {//保存行数列数的实体类
Coordinate(Coordinate father, int x, int y) {
this.x = x;
this.y = y;
this.father = father;
coordinateList = new ArrayList<>();
}
Coordinate father;
List<Coordinate> coordinateList;
int x;//路径
int y;//深度
}
private List<Coordinate> coordinateRoot;
private float defNub = 0;//行列式计算结果
private boolean isDo(Coordinate coordinates, int i, int j) {
boolean isOk = false;
if (coordinates != null) {
for (Coordinate coordinate : coordinates.coordinateList) {
if (coordinate.x == i && coordinate.y == j) {
isOk = true;
break;
}
}
}
return isOk;
}
private boolean findRout(Coordinate coordinate, int j, int initi, boolean isDown) {
for (int i = 0; i < x; i++) {//层数
if (!isDown) {
break;
}
int row = i;
if (coordinate == null) {
row = initi;
}
boolean isOk = isNext(coordinate, row, true) && !isDo(coordinate, row, j);
if (isOk) {
Coordinate coordinateNext = new Coordinate(coordinate, row, j);
if (coordinate != null) {
coordinate.coordinateList.add(coordinateNext);
} else {
coordinateRoot.add(coordinateNext);
}
if (coordinateNext.y < (y - 1)) {//深入
j++;
isDown = findRout(coordinateNext, j, initi, isDown);
} else if (coordinate != null && coordinateNext.y > 1 && coordinateNext.x == (x - 1)) {
//缩头
j--;
isDown = findRout(coordinate.father, j, initi, isDown);
} else if (coordinateNext.y == 1) {
isDown = false;
break;
}
} else {
if (i == (x - 1) && j > 1) {//行已经到极限了 缩头
j--;
isDown = findRout(coordinate.father, j, initi, isDown);
} else if (j == 1 && i == (x - 1)) {//跳出
isDown = false;
break;
}
}
}
return isDown;
}
private boolean isNext(Coordinate coordinate, int i, boolean isOk) {
if (coordinate == null) {//此路可走
return true;
}
if (isOk) {
if (coordinate.x != i) {
isOk = isNext(coordinate.father, i, true);
} else {//此路不通
return false;
}
}
return isOk;
}
private void defCalculation(List<Coordinate> coordinates) throws Exception {
for (Coordinate coordinate : coordinates) {
if (!coordinate.coordinateList.isEmpty()) {//继续向丛林深处进发
defCalculation(coordinate.coordinateList);
} else {//到道路的尽头了,进行核算
mulFather(coordinate, 1, new ArrayList<>());
}
}
}
private float mulFather(Coordinate coordinate, float element, List<Coordinate> div) throws Exception {
div.add(coordinate);
element = getNumber(coordinate.x, coordinate.y) * element;
if (coordinate.father != null) {
element = mulFather(coordinate.father, element, div);
} else {//道路尽头
if (parity(div)) {//偶排列
defNub = defNub + element;
} else {//奇排列
defNub = defNub - element;
}
div.clear();
element = 1;
}
return element;
}
/**
* 求矩阵的行列式 递归算法
*
* @return 计算后的值
* @throws Exception 如果矩阵不是一个方阵抛出异常
*/
public float getDet() throws Exception {//求矩阵的行列式
if (x == y) {
coordinateRoot = new ArrayList<>();
for (int i = 0; i < x; i++) {
findRout(null, 0, i, true);
}
defCalculation(coordinateRoot);
} else {
throw new Exception("Matrix is not Square");
}
return defNub;
}
private boolean parity(List<Coordinate> list) {//获取排列奇偶性
boolean parity = true;//默认是偶排列
float[] row = new float[list.size()];
float[] clo = new float[list.size()];
for (int i = 0; i < list.size(); i++) {
row[i] = list.get(i).x + 1;
clo[i] = list.get(i).y + 1;
}
int rowInv = inverseNumber(row);
int cloInv = inverseNumber(clo);
int inverserNumber = rowInv + cloInv;
if (inverserNumber % 2 != 0) {//奇排列
parity = false;
}
return parity;
}
/**
* 给矩阵设置数据
*
* @param messages 数据
* @throws Exception 给出的数据不正确时候会抛出异常
*/
public void setAll(String messages) throws Exception {//全设置矩阵
String[] message = messages.split("#");
if (x == message.length) {
for (int i = 0; i < message.length; i++) {
String mes = message[i];
String[] me = mes.substring(1, mes.length() - 1).split(",");
if (y == me.length) {
y = me.length;
for (int j = 0; j < y; j++) {
setNub(i, j, Float.parseFloat(me[j]));
}
} else {
matrix = null;
throw new Exception("matrix column is not equals");
}
}
} else {
throw new Exception("matrix row is not equals");
}
}
/**
* 将矩阵分块
*
* @param x 要分块的x坐标
* @param y 要分块的y坐标
* @param xSize 分块矩阵的宽度
* @param ySize 分块矩阵的长度
* @return 返回分块后的矩阵
*/
public Matrix getSonOfMatrix(int x, int y, int xSize, int ySize) {
Matrix myMatrix = new Matrix(xSize, ySize);
int xr = 0;
int yr = 0;
try {
for (int i = 0; i < xSize; i++) {
xr = i + x;
for (int j = 0; j < ySize; j++) {
yr = j + y;
if (this.x > xr && this.y > yr) {
myMatrix.setNub(i, j, getNumber(xr, yr));
} else {
throw new Exception("xr:" + xr + ",yr:" + yr + ",x:" + this.x + ",y:" + this.y + ",xSize:" + xSize + ",ySize:" + ySize + ",x:" + x + ",y:" + y);
}
}
}
} catch (Exception e) {
System.out.println("xr:" + xr + ",yr:" + yr);
e.printStackTrace();
}
return myMatrix;
}
/**
* 获取行向量
*
* @param x 你要指定的行数
* @return 返回一个一行的矩阵
* @throws Exception 超出矩阵范围抛出异常
*/
public Matrix getRow(int x) throws Exception {
Matrix myMatrix = new Matrix(1, y);
for (int i = 0; i < y; i++) {
myMatrix.setNub(0, i, getNumber(x, i));
}
return myMatrix;
}
/**
* 获取列向量
*
* @param y 要制定的列数
* @return 返回一个一列的矩阵
* @throws Exception 超出矩阵范围抛出异常
*/
public Matrix getColumn(int y) throws Exception {//获取列向量
Matrix myMatrix = new Matrix(x, 1);
for (int i = 0; i < x; i++) {
myMatrix.setNub(i, 0, getNumber(i, y));
}
return myMatrix;
}
/**
* 返回一个矩阵字符串
*
* @return 返回一个矩阵字符串
*/
public String getString() throws Exception {//矩阵输出字符串
StringBuilder builder = new StringBuilder();
for (int i = 0; i < x; i++) {
builder.append(i + ":[");
for (int j = 0; j < y; j++) {
float number = getNumber(i, j);
if (j == 0) {
builder.append(number);
} else {
builder.append("," + number);
}
}
builder.append("]\r\n");
}
return builder.toString();
}
/**
* 返回一个带坐标的矩阵字符串
*
* @return 返回一个带坐标的矩阵字符串
*/
public String getPositionString() throws Exception {//矩阵输出字符串
StringBuilder builder = new StringBuilder();
for (int i = 0; i < x; i++) {
builder.append(i + ":[");
for (int j = 0; j < y; j++) {
float number = getNumber(i, j);
if (j == 0) {
builder.append(number);
} else {
builder.append("," + j + ":" + number);
}
}
builder.append("]\r\n");
}
return builder.toString();
}
@Override
public String toString() {
try {
return getString();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
/**
* 给矩阵设置值
*
* @param x x坐标
* @param y y坐标
* @param number 要设置的值
* @throws Exception 超出矩阵范围抛出
*/
public void setNub(int x, int y, float number) throws Exception {
if (this.x > x && this.y > y && x >= 0 && y >= 0) {
matrix[y * this.x + x] = number;
} else {
throw new Exception("setNub matrix length too little x:" + x + ",y:" + y);
}
}
public Matrix copy() throws Exception {//复制一个矩阵
Matrix myMatrix = new Matrix(this.x, this.y);
for (int i = 0; i < this.x; i++) {
for (int j = 0; j < this.y; j++) {
myMatrix.setNub(i, j, getNumber(i, j));
}
}
return myMatrix;
}
/**
* 取矩阵的数值
*
* @param x x坐标
* @param y y坐标
* @return 返回指定坐标的数值
* @throws Exception 超出矩阵范围抛出
*/
public float getNumber(int x, int y) throws Exception {//从矩阵中拿值
if (this.x > x && this.y > y && x >= 0 && y >= 0) {
return matrix[y * this.x + x];
} else {
System.out.println("x==" + x + ",y==" + y + ",maxX:" + this.x + ",maxY:" + this.y);
throw new Exception("getNumber matrix length too little x:" + x + ",y:" + y);
}
}
/**
* 计算矩阵中某一行向量或者列向量所有元素的和
*
* @param isRow 是否取行向量
* @param index 索取向量在矩阵当中的下标
* @return 返回指定向量所有元素的和
* @throws Exception 超出矩阵范围抛出
*/
public float getSigmaByVector(boolean isRow, int index) throws Exception {
float sigma = 0;
if (index >= 0 && ((isRow && x > index) || (!isRow && y > index))) {
if (isRow) {//取行向量
for (int i = 0; i < y; i++) {
sigma = getNumber(index, i) + sigma;
}
} else {
for (int i = 0; i < x; i++) {
sigma = getNumber(i, index) + sigma;
}
}
} else {
throw new Exception("index 数值下标溢出:" + index);
}
return sigma;
}
}