|
const express = require('express'); |
|
// Express router holding this file's POST '/' handler; it is exported at the end of the file.
const router = express.Router();
|
const crypto = require('crypto'); |
|
const { GoogleClient } = require('../../../app'); |
|
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models'); |
|
const { handleError, sendMessage, createOnProgress } = require('../../utils'); |
|
const { getUserKey, checkUserKeyExpiry } = require('../../services/UserService'); |
|
const { setHeaders } = require('../../middleware'); |
|
|
|
/**
 * POST / — validate a PaLM2 (Google) chat request, build its endpoint options,
 * and delegate to `ask`, which writes the reply to `res`.
 *
 * Responds through `handleError` (and does not call `ask`) when:
 *   - the prompt `text` is missing or empty,
 *   - `endpoint` is not `'google'`,
 *   - the requested model is not in the supported list.
 */
router.post('/', setHeaders, async (req, res) => {
  const { endpoint, text, parentMessageId, conversationId: oldConversationId } = req.body;

  // `!text` also rejects a missing field; reading `.length` of undefined would throw
  // before any handleError response could be sent.
  if (!text || text.length === 0) {
    return handleError(res, { text: 'Prompt empty or too short' });
  }

  if (endpoint !== 'google') {
    return handleError(res, { text: 'Illegal request' });
  }

  // Use defaults only for options the client omitted (null/undefined), so explicit
  // falsy values such as temperature 0 are preserved.
  const endpointOption = {
    examples: req.body.examples ?? [{ input: { content: '' }, output: { content: '' } }],
    promptPrefix: req.body.promptPrefix ?? null,
    key: req.body.key ?? null,
    modelOptions: {
      model: req.body.model ?? 'chat-bison',
      modelLabel: req.body.modelLabel ?? null,
      temperature: req.body.temperature ?? 0.2,
      maxOutputTokens: req.body.maxOutputTokens ?? 1024,
      topP: req.body.topP ?? 0.95,
      topK: req.body.topK ?? 40,
    },
  };

  const availableModels = ['chat-bison', 'text-bison', 'codechat-bison'];
  if (!availableModels.includes(endpointOption.modelOptions.model)) {
    return handleError(res, { text: 'Illegal request: model' });
  }

  // Continue the client's conversation when one was given, otherwise start a new one.
  const conversationId = oldConversationId || crypto.randomUUID();

  return await ask({
    text,
    endpointOption,
    conversationId,
    parentMessageId,
    req,
    res,
  });
});
|
|
|
/**
 * Resolve the Google service-account key used to build the GoogleClient.
 *
 * When `PALM_KEY` is `'user_provided'` and the request carries `endpointOption.key`,
 * the key's expiry is checked (via checkUserKeyExpiry). The user's stored key is then
 * loaded and JSON-parsed, and `endpointOption.key` is deleted so it is not forwarded in
 * the client options. Only when no user key was resolved does this fall back to
 * `api/data/auth.json`, so a present `auth.json` can no longer override a user's key.
 *
 * @param {object} endpointOption - Options built by the route handler; `key` may be deleted.
 * @param {string} user - Id of the requesting user.
 * @returns {Promise<object|undefined>} The parsed key, or undefined when none was found.
 */
const getServiceAccountKey = async (endpointOption, user) => {
  const isUserProvided = process.env.PALM_KEY === 'user_provided';

  if (endpointOption.key && isUserProvided) {
    checkUserKeyExpiry(
      endpointOption.key,
      'Your GOOGLE_TOKEN has expired. Please provide your token again.',
    );
    const userKey = JSON.parse(await getUserKey({ userId: user, name: 'google' }));
    delete endpointOption.key;
    console.log('Using service account key provided by User for PaLM models');
    return userKey;
  }

  try {
    return require('../../../data/auth.json');
  } catch (e) {
    console.log('No \'auth.json\' file (service account key) found in /api/data/ for PaLM models');
    return undefined;
  }
};

/**
 * Send `text` to a PaLM2 model through GoogleClient, forwarding progress to the
 * client and persisting the conversation, partial responses and the final response.
 *
 * On any failure, an error message (error: true) is saved and reported through
 * handleError instead of being thrown.
 *
 * @param {object} params
 * @param {string} params.text - The user's prompt.
 * @param {object} params.endpointOption - Prompt/model options built by the route handler.
 * @param {string|null} [params.parentMessageId=null] - Parent of the new user message.
 * @param {string} params.conversationId - Conversation to append to.
 * @param {object} params.req - Express request; reads `req.user.id` and `req.body.overrideParentMessageId`.
 * @param {object} params.res - Express response written to via sendMessage/handleError.
 */
const ask = async ({ text, endpointOption, parentMessageId = null, conversationId, req, res }) => {
  let userMessage;
  let userMessageId;
  let responseMessageId;
  let lastSavedTimestamp = 0;
  const { overrideParentMessageId = null } = req.body;
  const user = req.user.id;

  try {
    // Passed to client.sendMessage; records the ids reported back through it and
    // echoes the user message to the client.
    const getReqData = (data = {}) => {
      for (const [key, value] of Object.entries(data)) {
        if (key === 'userMessage') {
          userMessage = value;
          userMessageId = value.messageId;
        } else if (key === 'responseMessageId') {
          responseMessageId = value;
        } else if (!conversationId && key === 'conversationId') {
          conversationId = value;
        }
      }

      sendMessage(res, { message: userMessage, created: true });
    };

    const { onProgress: progressCallback } = createOnProgress({
      onProgress: ({ text: partialText }) => {
        const currentTimestamp = Date.now();
        // Persist a partial response only if more than 500 ms passed since the last one.
        if (currentTimestamp - lastSavedTimestamp <= 500) {
          return;
        }
        lastSavedTimestamp = currentTimestamp;
        // Not awaited (this runs while streaming), but a failed save must not become
        // an unhandled promise rejection.
        Promise.resolve(
          saveMessage({
            messageId: responseMessageId,
            sender: 'PaLM2',
            conversationId,
            parentMessageId: overrideParentMessageId || userMessageId,
            text: partialText,
            unfinished: true,
            cancelled: false,
            error: false,
            user,
          }),
        ).catch((err) => console.error('Failed to save partial PaLM2 message', err));
      },
    });

    const abortController = new AbortController();
    const key = await getServiceAccountKey(endpointOption, user);

    const clientOptions = {
      reverseProxyUrl: process.env.GOOGLE_REVERSE_PROXY || null,
      proxy: process.env.PROXY || null,
      ...endpointOption,
    };

    const client = new GoogleClient(key, clientOptions);

    const response = await client.sendMessage(text, {
      getReqData,
      user,
      conversationId,
      parentMessageId,
      overrideParentMessageId,
      onProgress: progressCallback({
        res,
        text,
        parentMessageId: overrideParentMessageId || userMessageId,
      }),
      abortController,
    });

    if (overrideParentMessageId) {
      response.parentMessageId = overrideParentMessageId;
    }

    await saveConvo(user, {
      ...endpointOption,
      ...endpointOption.modelOptions,
      conversationId,
      endpoint: 'google',
    });

    await saveMessage({ ...response, user });
    sendMessage(res, {
      title: await getConvoTitle(user, conversationId),
      final: true,
      conversation: await getConvo(user, conversationId),
      requestMessage: userMessage,
      responseMessage: response,
    });
    res.end();
  } catch (error) {
    console.error(error);
    const errorMessage = {
      messageId: responseMessageId,
      sender: 'PaLM2',
      conversationId,
      // Use the same parent as the partial saves in onProgress, which share this
      // messageId. Fall back to the request's parent only if no user message id
      // was ever reported.
      parentMessageId: overrideParentMessageId || userMessageId || parentMessageId,
      unfinished: false,
      cancelled: false,
      error: true,
      text: error.message,
    };
    try {
      await saveMessage({ ...errorMessage, user });
    } catch (saveError) {
      // Still report the original failure to the client even if persisting it fails.
      console.error(saveError);
    }
    handleError(res, errorMessage);
  }
};
|
|
|
// Expose this file's Express router (with its POST '/' handler) to whoever requires this module.
module.exports = router;
|
|