* Format the request payload based on model type
(prompt: string)
| 442 | * Format the request payload based on model type |
| 443 | */ |
| 444 | formatPayload(prompt: string): string { |
| 445 | const maxTokens = this.config.maxTokens ?? getEnvInt('AWS_SAGEMAKER_MAX_TOKENS') ?? 1024; |
| 446 | const temperature = |
| 447 | typeof this.config.temperature === 'number' |
| 448 | ? this.config.temperature |
| 449 | : (getEnvFloat('AWS_SAGEMAKER_TEMPERATURE') ?? 0.7); |
| 450 | const topP = |
| 451 | typeof this.config.topP === 'number' |
| 452 | ? this.config.topP |
| 453 | : (getEnvFloat('AWS_SAGEMAKER_TOP_P') ?? 1.0); |
| 454 | const stopSequences = this.config.stopSequences || []; |
| 455 | |
| 456 | let payload: any; |
| 457 | |
| 458 | logger.debug(`Formatting payload for model type: ${this.modelType}`); |
| 459 | |
| 460 | switch (this.modelType) { |
| 461 | case 'openai': |
| 462 | try { |
| 463 | // Try to parse as JSON array of messages |
| 464 | const messages = JSON.parse(prompt); |
| 465 | if (Array.isArray(messages)) { |
| 466 | payload = { |
| 467 | messages, |
| 468 | max_tokens: maxTokens, |
| 469 | temperature, |
| 470 | top_p: topP, |
| 471 | stop: stopSequences.length > 0 ? stopSequences : undefined, |
| 472 | }; |
| 473 | } else { |
| 474 | throw new Error('Not valid messages format'); |
| 475 | } |
| 476 | } catch { |
| 477 | // Fall back to text completion format |
| 478 | payload = { |
| 479 | prompt, |
| 480 | max_tokens: maxTokens, |
| 481 | temperature, |
| 482 | top_p: topP, |
| 483 | stop: stopSequences.length > 0 ? stopSequences : undefined, |
| 484 | }; |
| 485 | } |
| 486 | break; |
| 487 | |
| 488 | case 'llama': |
| 489 | // TODO(Will): Can these be consolidated? |
| 490 | try { |
| 491 | const messages = JSON.parse(prompt); |
| 492 | if (Array.isArray(messages)) { |
| 493 | payload = { |
| 494 | inputs: messages, |
| 495 | parameters: { |
| 496 | max_new_tokens: maxTokens, |
| 497 | temperature, |
| 498 | top_p: topP, |
| 499 | stop: stopSequences.length > 0 ? stopSequences : undefined, |
| 500 | }, |
| 501 | }; |
no test coverage detected