diff --git a/packages/modelfusion/src/model-function/generate-text/prompt-template/OpenChatPromptTemplate.test.ts b/packages/modelfusion/src/model-function/generate-text/prompt-template/OpenChatPromptTemplate.test.ts new file mode 100644 index 000000000..9462e91a3 --- /dev/null +++ b/packages/modelfusion/src/model-function/generate-text/prompt-template/OpenChatPromptTemplate.test.ts @@ -0,0 +1,72 @@ +import { chat, instruction, text } from "./OpenChatPromptTemplate"; + +describe("text prompt", () => { + it("should format prompt", () => { + const prompt = text().format("prompt"); + + expect(prompt).toMatchSnapshot(); + }); +}); + +describe("instruction prompt", () => { + it("should format prompt with instruction", () => { + const prompt = instruction().format({ + instruction: "instruction", + }); + + expect(prompt).toMatchSnapshot(); + }); + + it("should format prompt with system and instruction", () => { + const prompt = instruction().format({ + system: "system", + instruction: "instruction", + }); + + expect(prompt).toMatchSnapshot(); + }); + + it("should format prompt with instruction and response prefix", () => { + const prompt = instruction().format({ + instruction: "instruction", + responsePrefix: "response prefix", + }); + + expect(prompt).toMatchSnapshot(); + }); +}); + +describe("chat prompt", () => { + it("should format prompt with user message", () => { + const prompt = chat().format({ + messages: [{ role: "user", content: "user message" }], + }); + + expect(prompt).toMatchSnapshot(); + }); + + it("should format prompt with user-assistant-user messages", () => { + const prompt = chat().format({ + messages: [ + { role: "user", content: "1st user message" }, + { role: "assistant", content: "assistant message" }, + { role: "user", content: "2nd user message" }, + ], + }); + + expect(prompt).toMatchSnapshot(); + }); + + it("should format prompt with system message and user-assistant-user messages", () => { + const prompt = chat().format({ + system: "you are a chatbot", + messages: [ + { role: "user", content: "1st user message" }, + { role: "assistant", content: "assistant message" }, + { role: "user", content: "2nd user message" }, + ], + }); + + expect(prompt).toMatchSnapshot(); + }); +}); diff --git a/packages/modelfusion/src/model-function/generate-text/prompt-template/OpenChatPromptTemplate.ts b/packages/modelfusion/src/model-function/generate-text/prompt-template/OpenChatPromptTemplate.ts new file mode 100644 index 000000000..2609610e2 --- /dev/null +++ b/packages/modelfusion/src/model-function/generate-text/prompt-template/OpenChatPromptTemplate.ts @@ -0,0 +1,154 @@ +import { TextGenerationPromptTemplate } from "../TextGenerationPromptTemplate"; +import { ChatPrompt } from "./ChatPrompt"; +import { validateContentIsString } from "./ContentPart"; +import { InstructionPrompt } from "./InstructionPrompt"; +import { InvalidPromptError } from "./InvalidPromptError"; + +const END_OF_TURN = "<|end_of_turn|>"; +const USER_PREFIX = "GPT4 Correct User: "; +const ASSISTANT_PREFIX = "GPT4 Correct Assistant:"; + +function userSegment(text: string) { + return `${USER_PREFIX}${text}${END_OF_TURN}`; +} + +function assistantSegment(text: string) { + return `${ASSISTANT_PREFIX} ${text}${END_OF_TURN}`; +} + +function systemSegment(text: string | undefined) { + return text == null ? "" : `${text}${END_OF_TURN}`; +} + +/** + * Formats a text prompt as an OpenChat prompt. + * + * OpenChat prompt template: + * ``` + * GPT4 Correct User: { prompt }<|end_of_turn|>GPT4 Correct Assistant: + * ``` + * + * @see https://github.com/imoneoi/openchat#conversation-templates + */ +export function text(): TextGenerationPromptTemplate { + return { + stopSequences: [END_OF_TURN], + format(prompt) { + return `${userSegment(prompt)}${ASSISTANT_PREFIX}`; + }, + }; +} + +/** + * Formats an instruction prompt as an OpenChat prompt. + * + * OpenChat prompt template with a system prompt: + * ``` + * ${ system prompt }<|end_of_turn|>GPT4 Correct User: ${ instruction }<|end_of_turn|>GPT4 Correct Assistant:${ response prefix } + * ``` + * + * @see https://github.com/imoneoi/openchat#conversation-templates + */ +export function instruction(): TextGenerationPromptTemplate< + InstructionPrompt, + string +> { + return { + stopSequences: [END_OF_TURN], + format(prompt) { + const instruction = validateContentIsString(prompt.instruction, prompt); + + return ( + systemSegment(prompt.system) + + userSegment(instruction) + + ASSISTANT_PREFIX + + (prompt.responsePrefix ?? "") + ); + }, + }; +} + +/** + * Formats a chat prompt as an OpenChat prompt. + * + * OpenChat prompt template: + * ``` + * GPT4 Correct User: ${ user msg 1 }<|end_of_turn|>GPT4 Correct Assistant: ${ model response 1 }<|end_of_turn|>GPT4 Correct User: ${ user msg 2 }<|end_of_turn|>GPT4 Correct Assistant: + * ``` + * + * @see https://github.com/imoneoi/openchat#conversation-templates + */ +export function chat(): TextGenerationPromptTemplate { + return { + stopSequences: [END_OF_TURN], + format(prompt) { + validateOpenChatPrompt(prompt); + + let text = systemSegment(prompt.system); + + for (const { role, content } of prompt.messages) { + switch (role) { + case "user": { + text += userSegment(validateContentIsString(content, prompt)); + break; + } + case "assistant": { + text += assistantSegment(validateContentIsString(content, prompt)); + break; + } + case "tool": { + throw new InvalidPromptError( + "Tool messages are not supported.", + prompt + ); + } + default: { + const _exhaustiveCheck: never = role; + throw new Error(`Unsupported role: ${_exhaustiveCheck}`); + } + } + } + + return text + ASSISTANT_PREFIX; + }, + }; +} + +/** + * Checks if an OpenChat chat prompt is valid. + * + * - The first message of the chat must be a user message. + * - Then it must alternate between assistant and user messages. + * - The last message must be a user message when submitting to the model. + * + * @throws {@link InvalidPromptError} + */ +export function validateOpenChatPrompt(chatPrompt: ChatPrompt) { + const messages = chatPrompt.messages; + + if (messages.length < 1) { + throw new InvalidPromptError( + "ChatPrompt should have at least one message.", + chatPrompt + ); + } + + for (let i = 0; i < messages.length; i++) { + const expectedRole = i % 2 === 0 ? "user" : "assistant"; + const role = messages[i].role; + + if (role !== expectedRole) { + throw new InvalidPromptError( + `Message at index ${i} should have role '${expectedRole}', but has role '${role}'.`, + chatPrompt + ); + } + } + + if (messages.length % 2 === 0) { + throw new InvalidPromptError( + "The last message must be a user message.", + chatPrompt + ); + } +} diff --git a/packages/modelfusion/src/model-function/generate-text/prompt-template/__snapshots__/OpenChatPromptTemplate.test.ts.snap b/packages/modelfusion/src/model-function/generate-text/prompt-template/__snapshots__/OpenChatPromptTemplate.test.ts.snap new file mode 100644 index 000000000..0b2ec2ded --- /dev/null +++ b/packages/modelfusion/src/model-function/generate-text/prompt-template/__snapshots__/OpenChatPromptTemplate.test.ts.snap @@ -0,0 +1,15 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`chat prompt > should format prompt with system message and user-assistant-user messages 1`] = `"you are a chatbot<|end_of_turn|>GPT4 Correct User: 1st user message<|end_of_turn|>GPT4 Correct Assistant: assistant message<|end_of_turn|>GPT4 Correct User: 2nd user message<|end_of_turn|>GPT4 Correct Assistant:"`; + +exports[`chat prompt > should format prompt with user message 1`] = `"GPT4 Correct User: user message<|end_of_turn|>GPT4 Correct Assistant:"`; + +exports[`chat prompt > should format prompt with user-assistant-user messages 1`] = `"GPT4 Correct User: 1st user message<|end_of_turn|>GPT4 Correct Assistant: assistant message<|end_of_turn|>GPT4 Correct User: 2nd user message<|end_of_turn|>GPT4 Correct Assistant:"`; + +exports[`instruction prompt > should format prompt with instruction 1`] = `"GPT4 Correct User: instruction<|end_of_turn|>GPT4 Correct Assistant:"`; + +exports[`instruction prompt > should format prompt with instruction and response prefix 1`] = `"GPT4 Correct User: instruction<|end_of_turn|>GPT4 Correct Assistant:response prefix"`; + +exports[`instruction prompt > should format prompt with system and instruction 1`] = `"system<|end_of_turn|>GPT4 Correct User: instruction<|end_of_turn|>GPT4 Correct Assistant:"`; + +exports[`text prompt > should format prompt 1`] = `"GPT4 Correct User: prompt<|end_of_turn|>GPT4 Correct Assistant:"`; diff --git a/packages/modelfusion/src/model-function/generate-text/prompt-template/index.ts b/packages/modelfusion/src/model-function/generate-text/prompt-template/index.ts index e163d0272..6c543ffcc 100644 --- a/packages/modelfusion/src/model-function/generate-text/prompt-template/index.ts +++ b/packages/modelfusion/src/model-function/generate-text/prompt-template/index.ts @@ -7,6 +7,7 @@ export * from "./InvalidPromptError"; export * as Llama2Prompt from "./Llama2PromptTemplate"; export * as MistralInstructPrompt from "./MistralInstructPromptTemplate"; export * as NeuralChatPrompt from "./NeuralChatPromptTemplate"; +export * as OpenChatPrompt from "./OpenChatPromptTemplate"; export * from "./PromptTemplateProvider"; export * as SynthiaPrompt from "./SynthiaPromptTemplate"; export * from "./TextPrompt";