// @zh-keyboard/recognizer
// Version:
// 130 lines (127 loc) • 4.28 kB
// JavaScript
;
//#region rolldown:runtime
// Cached references to the built-in reflection helpers used by the interop
// shims below.
var __create = Object.create;
var __defProp = Object.defineProperty;
var __getOwnPropDesc = Object.getOwnPropertyDescriptor;
var __getOwnPropNames = Object.getOwnPropertyNames;
var __getProtoOf = Object.getPrototypeOf;
var __hasOwnProp = Object.prototype.hasOwnProperty;

/**
 * Mirrors the own string-keyed properties of `from` onto `to` as live getters.
 *
 * Properties that `to` already owns, and the property named `except`, are
 * skipped. Each mirrored property keeps the enumerability of its source.
 * When `from` is neither an object nor a function, `to` is returned untouched.
 *
 * @param {object} to - Object that receives the getters.
 * @param {*} from - Value whose own properties are mirrored.
 * @param {string} [except] - Property name that must not be mirrored.
 * @param {PropertyDescriptor} [desc] - Scratch slot for the current descriptor.
 * @returns {object} `to`.
 */
var __copyProps = (to, from, except, desc) => {
  const isCopyable = (from && typeof from === "object") || typeof from === "function";
  if (!isCopyable) {
    return to;
  }
  for (const key of __getOwnPropNames(from)) {
    if (key === except || __hasOwnProp.call(to, key)) {
      continue;
    }
    desc = __getOwnPropDesc(from, key);
    __defProp(to, key, {
      // Read through to the source on every access so later changes stay visible.
      get: () => from[key],
      enumerable: !desc || desc.enumerable
    });
  }
  return to;
};

/**
 * Wraps a required module in an object shaped like an ES module namespace.
 *
 * The wrapper inherits from the module's prototype and exposes every own
 * property of the module as a getter. A `default` property pointing at the
 * module itself is added when `isNodeMode` is truthy, when the module is
 * falsy, or when it does not carry an `__esModule` marker.
 *
 * @param {*} mod - The value returned by `require`.
 * @param {*} [isNodeMode] - Truthy to always define `default` as the module.
 * @param {object} [target] - Scratch slot for the wrapper being built.
 * @returns {object} The namespace-like wrapper.
 */
var __toESM = (mod, isNodeMode, target) => {
  target = mod != null ? __create(__getProtoOf(mod)) : {};
  const needsDefault = isNodeMode || !mod || !mod.__esModule;
  if (needsDefault) {
    __defProp(target, "default", {
      value: mod,
      enumerable: true
    });
  }
  return __copyProps(target, mod);
};
//#endregion
const __tensorflow_tfjs_converter = __toESM(require("@tensorflow/tfjs-converter"));
const __tensorflow_tfjs_core = __toESM(require("@tensorflow/tfjs-core"));
require("@tensorflow/tfjs-backend-cpu");
//#region src/index.ts
/**
 * Handwritten-character recognizer backed by a TensorFlow.js graph model.
 *
 * Strokes are rasterized onto a private 64x64 canvas, the bitmap is fed to
 * the model, and the highest-scoring output positions are translated to text
 * through a one-entry-per-line dictionary file.
 */
var ZhkRecognizer = class {
  /** Loaded graph model; undefined until initialize() succeeds and after close(). */
  model;
  /** Dictionary entries, indexed by model output position. */
  dict = [];
  /** Detached canvas element the strokes are rendered onto. */
  canvas;
  /** 2D drawing context of `canvas`. */
  ctx;
  /** Location handed to `loadGraphModel`. */
  modelPath;
  /** URL the dictionary text is fetched from. */
  dictPath;
  /** Requested backend name; anything other than "webgl" selects "cpu". */
  backend;

  /**
   * @param {object} options
   * @param {string} options.modelPath - Location of the graph model.
   * @param {string} options.dictPath - URL of the dictionary text file.
   * @param {string} [options.backend] - "webgl" or "cpu"; defaults to "cpu".
   */
  constructor(options) {
    this.modelPath = options.modelPath;
    this.dictPath = options.dictPath;
    this.backend = options.backend || "cpu";
    this.canvas = document.createElement("canvas");
    this.canvas.width = this.canvas.height = 64;
    this.ctx = this.canvas.getContext("2d", { willReadFrequently: true });
  }

  /**
   * Downloads the dictionary, loads the model and selects the backend.
   * With the "webgl" backend a throw-away recognition is run afterwards.
   *
   * @param {object} [options]
   * @param {Function} [options.onProgress] - Forwarded to `loadGraphModel`.
   * @returns {Promise<boolean>} Resolves to `true` once everything is ready.
   * @throws {Error} When the dictionary request answers with a non-2xx status.
   */
  async initialize(options) {
    const response = await fetch(this.dictPath);
    // Without this check an HTTP error page would be parsed as the dictionary.
    if (!response.ok) {
      throw new Error(`Failed to load dictionary from ${this.dictPath}: HTTP ${response.status}`);
    }
    const text = await response.text();
    // Accept both LF and CRLF files so entries never carry a trailing "\r".
    // Empty lines are kept on purpose: removing them would shift the indices.
    this.dict = text.split(/\r?\n/);
    this.model = await (0, __tensorflow_tfjs_converter.loadGraphModel)(this.modelPath, {
      streamWeights: true,
      onProgress: options?.onProgress
    });
    if (this.backend === "webgl") {
      await __tensorflow_tfjs_core.setBackend("webgl");
      await __tensorflow_tfjs_core.ready();
      // Warm-up pass: one short two-point stroke; the result is discarded.
      await this.recognize([10, 10, 0, 20, 20, 1]);
    } else {
      await __tensorflow_tfjs_core.setBackend("cpu");
    }
    return true;
  }

  /**
   * Recognizes the character drawn by the given strokes.
   *
   * @param {ArrayLike<number>} strokeData - Flat list of `x, y, end` triples;
   *   `end === 1` marks the last point of a stroke. A trailing incomplete
   *   triple is ignored.
   * @returns {Promise<string[]>} Up to 10 non-empty dictionary entries, best
   *   match first. Empty when `strokeData` contains no complete point.
   * @throws {Error} When the model has not been initialized.
   */
  async recognize(strokeData) {
    if (!this.model) throw new Error("Model not initialized");
    const { canvas, ctx, model, dict } = this;
    ctx.fillStyle = "white";
    ctx.fillRect(0, 0, canvas.width, canvas.height);

    const pointCount = Math.floor(strokeData.length / 3);
    // No points means no bounding box (the scale would degenerate) and nothing
    // to classify; running the model on a blank canvas only yields noise.
    if (pointCount === 0) return [];

    const strokes = Array.from({ length: pointCount }, (_, i) => ({
      x: strokeData[3 * i],
      y: strokeData[3 * i + 1],
      isEnd: strokeData[3 * i + 2] === 1
    }));

    // Bounding box of all points, used to centre and scale the drawing.
    let minX = Infinity;
    let minY = Infinity;
    let maxX = -Infinity;
    let maxY = -Infinity;
    for (const { x, y } of strokes) {
      if (x < minX) minX = x;
      if (x > maxX) maxX = x;
      if (y < minY) minY = y;
      if (y > maxY) maxY = y;
    }
    // A zero extent (single point, or a perfectly straight line) falls back to 1
    // to avoid dividing by zero.
    const w = maxX - minX || 1;
    const h = maxY - minY || 1;
    const cx = (minX + maxX) / 2;
    const cy = (minY + maxY) / 2;
    // Uniform scale so the drawing fills 90% of the canvas along its larger side.
    const scale = Math.min(canvas.width * .9 / w, canvas.height * .9 / h);
    const toCanvasX = (x) => canvas.width / 2 + (x - cx) * scale;
    const toCanvasY = (y) => canvas.height / 2 + (y - cy) * scale;

    ctx.strokeStyle = "black";
    ctx.lineWidth = 2;
    ctx.lineCap = "round";
    ctx.lineJoin = "round";
    let last = null;
    for (const s of strokes) {
      // Connect to the previous point unless that point closed its stroke.
      if (last && !last.isEnd) {
        ctx.beginPath();
        ctx.moveTo(toCanvasX(last.x), toCanvasY(last.y));
        ctx.lineTo(toCanvasX(s.x), toCanvasY(s.y));
        ctx.stroke();
      }
      last = s;
    }

    return __tensorflow_tfjs_core.tidy(() => {
      const image = __tensorflow_tfjs_core.browser.fromPixels(canvas, 3);
      const floatImage = __tensorflow_tfjs_core.cast(image, "float32");
      const normalizedImage = __tensorflow_tfjs_core.div(floatImage, 255);
      const batchedImage = __tensorflow_tfjs_core.expandDims(normalizedImage, 0);
      const probs = model.predict(batchedImage).dataSync();
      // Indices of the 10 highest scores, best first.
      const idxs = Array.from(probs.keys()).sort((a, b) => probs[b] - probs[a]).slice(0, 10);
      // Indices beyond the dictionary and empty entries are dropped.
      return idxs.map((i) => i < dict.length ? dict[i] : "").filter(Boolean);
    });
  }

  /**
   * Disposes the loaded model, if any. recognize() throws afterwards until
   * initialize() is called again.
   *
   * @returns {Promise<void>}
   */
  async close() {
    this.model?.dispose();
    this.model = void 0;
  }
};
//#endregion
// CommonJS export: makes the recognizer class available as `ZhkRecognizer`.
exports.ZhkRecognizer = ZhkRecognizer
//# sourceMappingURL=index.js.map