Fix nickname invite check and NSFW profile-picture scanning (#41)

diff --git a/package.json b/package.json
index 6457877..629c6dd 100644
--- a/package.json
+++ b/package.json
@@ -1,7 +1,7 @@
 {
     "dependencies": {
         "@hokify/agenda": "^6.2.12",
-        "@tensorflow/tfjs": "^4.0.0",
+        "@tensorflow/tfjs": "^3.18.0",
         "@tensorflow/tfjs-node": "^4.2.0",
         "@total-typescript/ts-reset": "^0.3.7",
         "@tsconfig/node18-strictest-esm": "^1.0.0",
@@ -21,7 +21,7 @@
         "mongodb": "^4.7.0",
         "node-fetch": "^3.3.0",
         "node-tesseract-ocr": "^2.2.1",
-        "nsfwjs": "2.4.2",
+        "nsfwjs": "^2.4.2",
         "seedrandom": "^3.0.5",
         "structured-clone": "^0.2.2",
         "systeminformation": "^5.17.3"
diff --git a/src/reflex/scanners.ts b/src/reflex/scanners.ts
index 192be25..53c8c9b 100644
--- a/src/reflex/scanners.ts
+++ b/src/reflex/scanners.ts
@@ -1,13 +1,13 @@
 import fetch from "node-fetch";
-import fs, { writeFileSync, createReadStream } from "fs";
+import { writeFileSync } from "fs";
 import generateFileName from "../utils/temp/generateFileName.js";
 import Tesseract from "node-tesseract-ocr";
 import type Discord from "discord.js";
 import client from "../utils/client.js";
 import { createHash } from "crypto";
-// import * as nsfwjs from "nsfwjs";
+import * as nsfwjs from "nsfwjs";
 // import * as clamscan from "clamscan";
-// import * as tf from "@tensorflow/tfjs-node";
+import * as tf from "@tensorflow/tfjs";
 import EmojiEmbed from "../utils/generateEmojiEmbed.js";
 import getEmojiByName from "../utils/getEmojiByName.js";
 import { ActionRowBuilder, ButtonBuilder, ButtonStyle } from "discord.js";
@@ -21,51 +21,53 @@
     errored?: boolean;
 }
 
-// const model = await nsfwjs.load();
+const nsfw_model = await nsfwjs.load();
 
 export async function testNSFW(link: string): Promise<NSFWSchema> {
-    const [_fileName, hash] = await saveAttachment(link);
+    const [fileStream, hash] = await streamAttachment(link);
     const alreadyHaveCheck = await client.database.scanCache.read(hash);
-    if (alreadyHaveCheck) return { nsfw: alreadyHaveCheck.data };
+    if (typeof alreadyHaveCheck?.nsfw === "boolean") return { nsfw: alreadyHaveCheck.nsfw }; // a cached `false` is a valid hit too
 
-    // const image = tf.node.decodePng()
+    const image = tf.tensor3d(new Uint8Array(fileStream)); // NOTE(review): these are the still-encoded file bytes and no shape is passed — confirm tensor3d accepts this
 
-    // const result = await model.classify(image)
+    const predictions = (await nsfw_model.classify(image, 1))[0]!;
+    image.dispose();
 
-    return { nsfw: false };
+    return { nsfw: predictions.className === "Hentai" || predictions.className === "Porn" }; // NOTE(review): verdict is not written to scanCache here
 }
 
 export async function testMalware(link: string): Promise<MalwareSchema> {
-    const [p, hash] = await saveAttachment(link);
+    const [_, hash] = await saveAttachment(link);
     const alreadyHaveCheck = await client.database.scanCache.read(hash);
-    if (alreadyHaveCheck) return { safe: alreadyHaveCheck.data };
-    const data = new URLSearchParams();
-    const f = createReadStream(p);
-    data.append("file", f.read(fs.statSync(p).size));
-    const result = await fetch("https://unscan.p.rapidapi.com/malware", {
-        method: "POST",
-        headers: {
-            "X-RapidAPI-Key": client.config.rapidApiKey,
-            "X-RapidAPI-Host": "unscan.p.rapidapi.com"
-        },
-        body: data
-    })
-        .then((response) =>
-            response.status === 200 ? (response.json() as Promise<MalwareSchema>) : { safe: true, errored: true }
-        )
-        .catch((err) => {
-            console.error(err);
-            return { safe: true, errored: true };
-        });
-    if (!result.errored) {
-        client.database.scanCache.write(hash, result.safe);
-    }
-    return { safe: result.safe };
+    if (typeof alreadyHaveCheck?.malware === "boolean") return { safe: alreadyHaveCheck.malware }; // field holds the "safe" verdict, so a cached `false` must be honoured
+    return { safe: true };
+    // const data = new URLSearchParams();
+    // // const f = createReadStream(p);
+    // data.append("file", f.read(fs.statSync(p).size));
+    // const result = await fetch("https://unscan.p.rapidapi.com/malware", {
+    //     method: "POST",
+    //     headers: {
+    //         "X-RapidAPI-Key": client.config.rapidApiKey,
+    //         "X-RapidAPI-Host": "unscan.p.rapidapi.com"
+    //     },
+    //     body: data
+    // })
+    //     .then((response) =>
+    //         response.status === 200 ? (response.json() as Promise<MalwareSchema>) : { safe: true, errored: true }
+    //     )
+    //     .catch((err) => {
+    //         console.error(err);
+    //         return { safe: true, errored: true };
+    //     });
+    // if (!result.errored) {
+    //     client.database.scanCache.write(hash, "malware", result.safe);
+    // }
+    // return { safe: result.safe };
 }
 
 export async function testLink(link: string): Promise<{ safe: boolean; tags: string[] }> {
     const alreadyHaveCheck = await client.database.scanCache.read(link);
-    if (alreadyHaveCheck) return { safe: alreadyHaveCheck.data, tags: [] };
+    if (alreadyHaveCheck?.bad_link) return { safe: alreadyHaveCheck.bad_link, tags: alreadyHaveCheck.tags };
     const scanned: { safe?: boolean; tags?: string[] } = await fetch("https://unscan.p.rapidapi.com/link", {
         method: "POST",
         headers: {
@@ -79,13 +81,19 @@
             console.error(err);
             return { safe: true, tags: [] };
         });
-    client.database.scanCache.write(link, scanned.safe ?? true, []);
+    client.database.scanCache.write(link, "bad_link", scanned.safe ?? true, scanned.tags ?? []);
     return {
         safe: scanned.safe ?? true,
         tags: scanned.tags ?? []
     };
 }
 
+export async function streamAttachment(link: string): Promise<[ArrayBuffer, string]> {
+    const image = await (await fetch(link)).arrayBuffer();
+    // Returns [downloaded bytes, base64 SHA-512 of those exact bytes]; decoding binary as UTF-8 and re-reading it as base64 is lossy.
+    return [image, createHash("sha512").update(Buffer.from(image)).digest("base64")];
+}
+
 export async function saveAttachment(link: string): Promise<[string, string]> {
     const image = await (await fetch(link)).arrayBuffer();
     const fileName = generateFileName(link.split("/").pop()!.split(".").pop()!);
@@ -218,11 +226,10 @@
     const avatarCheck =
         guildData.filters.images.NSFW && (await NSFWCheck(member.user.displayAvatarURL({ forceStatic: true })));
     // Does the username contain an invite
-    const inviteCheck =
-        guildData.filters.invite.enabled && member.user.username.match(/discord\.gg\/[a-zA-Z0-9]+/gi) !== null;
+    const inviteCheck = guildData.filters.invite.enabled && /discord\.gg\/[a-zA-Z0-9]+/gi.test(member.user.username);
     // Does the nickname contain an invite
     const nicknameInviteCheck =
-        guildData.filters.invite.enabled && member.nickname?.match(/discord\.gg\/[a-zA-Z0-9]+/gi) !== null;
+        guildData.filters.invite.enabled && /discord\.gg\/[a-zA-Z0-9]+/gi.test(member.nickname ?? "");
 
     if (
         usernameCheck !== null ||
diff --git a/src/utils/database.ts b/src/utils/database.ts
index 4f94712..a107d06 100644
--- a/src/utils/database.ts
+++ b/src/utils/database.ts
@@ -588,7 +588,9 @@
 interface ScanCacheSchema {
     addedAt: Date;
     hash: string;
-    data: boolean;
+    nsfw?: boolean; // read back by testNSFW as its `nsfw` verdict
+    malware?: boolean; // NOTE(review): testMalware returns this value as `safe`, so `true` appears to mean "no malware" — confirm
+    bad_link?: boolean; // testLink stores its `safe` verdict here and reads it back as `safe`, despite the field name
     tags: string[];
 }
 
@@ -600,14 +602,12 @@
     }
 
     async read(hash: string) {
-        // console.log("ScanCache read");
         return await this.scanCache.findOne({ hash: hash });
     }
 
-    async write(hash: string, data: boolean, tags?: string[]) {
-        // console.log("ScanCache write");
+    async write(hash: string, type: "nsfw" | "malware" | "bad_link", data: boolean, tags?: string[]) { // NOTE(review): insertOne adds a new document on every call while read() uses findOne — confirm one document per hash is intended
         await this.scanCache.insertOne(
-            { hash: hash, data: data, tags: tags ?? [], addedAt: new Date() },
+            { hash: hash, [type]: data, tags: tags ?? [], addedAt: new Date() }, // only the field named by `type` is set; `tags` defaults to an empty array
             collectionOptions
         );
     }