|
import express from 'express'; |
|
|
|
import { getPipeline } from '../transformers.js'; |
|
|
|
// Task identifier handed to getPipeline() to obtain the classifier.
const TASK = 'text-classification';

// Memoized classification results, keyed by the input text.
const cacheObject = new Map();

// Router that carries the classification endpoints defined below.
const router = express.Router();

export { router };
|
|
|
/**
 * POST /labels — lists the label names known to the text-classification model.
 * Responds with `{ labels }`, built from the keys of the model config's
 * `label2id` map. Responds with HTTP 500 if the pipeline cannot be loaded or
 * its config cannot be read.
 */
router.post('/labels', async (_req, res) => {
    try {
        const classifier = await getPipeline(TASK);
        const { label2id } = classifier.model.config;
        const labels = Object.keys(label2id);
        return res.json({ labels });
    } catch (err) {
        console.error(err);
        return res.sendStatus(500);
    }
});
|
|
|
/**
 * Upper bound on the number of distinct inputs memoized in `cacheObject`.
 * The cache is keyed by client-supplied text, so without a cap it would grow
 * for the whole lifetime of the process.
 */
const CLASSIFY_CACHE_LIMIT = 1000;

/**
 * Classifies `text` with the shared text-classification pipeline and memoizes
 * the result in `cacheObject`.
 *
 * @param {string} text - Input to classify; also used as the cache key.
 * @returns {Promise<Array<object>>} Pipeline predictions (requested with
 *   `topk: 5`), sorted by descending `score`. Cached arrays are shared
 *   between callers, so they must not be mutated.
 */
async function getClassification(text) {
    if (cacheObject.has(text)) {
        return cacheObject.get(text);
    }

    const pipe = await getPipeline(TASK);
    const result = await pipe(text, { topk: 5 });
    result.sort((a, b) => b.score - a.score);

    // A Map iterates its keys in insertion order, so the first key is the
    // oldest entry. Evict oldest entries until there is room for this one.
    while (cacheObject.size >= CLASSIFY_CACHE_LIMIT) {
        cacheObject.delete(cacheObject.keys().next().value);
    }
    cacheObject.set(text, result);

    return result;
}

/**
 * POST / — classifies `req.body.text`.
 * Responds with:
 * - `{ classification }` (see getClassification) on success;
 * - HTTP 400 when `text` is missing or not a string;
 * - HTTP 500 if classification fails.
 */
router.post('/', async (req, res) => {
    try {
        // req.body can be undefined when no body parser handled the request.
        const { text } = req.body ?? {};

        if (typeof text !== 'string') {
            return res.sendStatus(400);
        }

        console.debug('Classify input:', text);
        const result = await getClassification(text);
        console.debug('Classify output:', result);

        return res.json({ classification: result });
    } catch (error) {
        console.error(error);
        return res.sendStatus(500);
    }
});
|
|