@apjs/tensor
Version:
A library with a set of functions to facilitate the use of the basic operations of linear algebra
47 lines (44 loc) • 1.69 kB
text/typescript
import { Tensor, TensorLike } from "../class"
import { ensureTensor } from "../utils"
import { reshape, tile } from "./reshape"
export const broadcastArgs = (shape: number[], shape2: number[]) => {
let oldShape = shape.slice(), oldShape2 = shape2.slice()
if (oldShape.length > oldShape2.length) {
while (oldShape2.length < oldShape.length) {
oldShape2.unshift(1)
}
} else {
while (oldShape.length < oldShape2.length) {
oldShape.unshift(1)
}
}
let resultShape = []
for (let index = 0; index < oldShape.length; index++) {
resultShape[index] = Math.max(oldShape[index], oldShape2[index])
}
return resultShape
}
export const broadcastTo = (tensor: Tensor | TensorLike, shape: number[]) => {
let internTensor = ensureTensor(tensor)
let oldShape = internTensor.shape.slice()
if (shape.length < internTensor.rank) {
throw console.error('Rank from shape ' + shape.length + ' is minor than tensor rank ' + internTensor.rank)
}
let result = internTensor
if (shape.length > internTensor.rank) {
let newShape = oldShape.slice();
while (newShape.length < shape.length) {
newShape.unshift(1)
}
result = reshape(internTensor, newShape)
}
let reps = shape.slice()
for (let index = 0; index < shape.length; index++) {
if (result.shape[index] === shape[index]) {
reps[index] = 1;
} else if (result.shape[index] !== 1) {
throw console.error(oldShape + ' cannot be broadcast to ' + shape)
}
}
return tile(result, reps)
}