diff --git a/backend/composer.json b/backend/composer.json index ffd0e25..d913c4b 100644 --- a/backend/composer.json +++ b/backend/composer.json @@ -19,6 +19,7 @@ "symfony/dotenv": "6.4.*", "symfony/flex": "^2", "symfony/framework-bundle": "6.4.*", + "symfony/http-client": "6.4.*", "symfony/property-access": "6.4.*", "symfony/property-info": "6.4.*", "symfony/runtime": "6.4.*", diff --git a/backend/composer.lock b/backend/composer.lock index 9b62630..2aa2743 100644 --- a/backend/composer.lock +++ b/backend/composer.lock @@ -4,7 +4,7 @@ "Read more about it at https://getcomposer.org/doc/01-basic-usage.md#installing-dependencies", "This file is @generated automatically" ], - "content-hash": "58ca9f6d53632372fae9dee2c6c72aa7", + "content-hash": "f41287711c3c1d476ebbca47f5b529b5", "packages": [ { "name": "doctrine/annotations", @@ -3123,6 +3123,177 @@ ], "time": "2025-03-23T16:46:24+00:00" }, + { + "name": "symfony/http-client", + "version": "v6.4.19", + "source": { + "type": "git", + "url": "https://github.com/symfony/http-client.git", + "reference": "3294a433fc9d12ae58128174896b5b1822c28dad" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/symfony/http-client/zipball/3294a433fc9d12ae58128174896b5b1822c28dad", + "reference": "3294a433fc9d12ae58128174896b5b1822c28dad", + "shasum": "" + }, + "require": { + "php": ">=8.1", + "psr/log": "^1|^2|^3", + "symfony/deprecation-contracts": "^2.5|^3", + "symfony/http-client-contracts": "~3.4.4|^3.5.2", + "symfony/service-contracts": "^2.5|^3" + }, + "conflict": { + "php-http/discovery": "<1.15", + "symfony/http-foundation": "<6.3" + }, + "provide": { + "php-http/async-client-implementation": "*", + "php-http/client-implementation": "*", + "psr/http-client-implementation": "1.0", + "symfony/http-client-implementation": "3.0" + }, + "require-dev": { + "amphp/amp": "^2.5", + "amphp/http-client": "^4.2.1", + "amphp/http-tunnel": "^1.0", + "amphp/socket": "^1.1", + "guzzlehttp/promises": "^1.4|^2.0", + "nyholm/psr7": "^1.0", + "php-http/httplug": "^1.0|^2.0", + "psr/http-client": "^1.0", + "symfony/dependency-injection": "^5.4|^6.0|^7.0", + "symfony/http-kernel": "^5.4|^6.0|^7.0", + "symfony/messenger": "^5.4|^6.0|^7.0", + "symfony/process": "^5.4|^6.0|^7.0", + "symfony/stopwatch": "^5.4|^6.0|^7.0" + }, + "type": "library", + "autoload": { + "psr-4": { + "Symfony\\Component\\HttpClient\\": "" + }, + "exclude-from-classmap": [ + "/Tests/" + ] + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "MIT" + ], + "authors": [ + { + "name": "Nicolas Grekas", + "email": "p@tchwork.com" + }, + { + "name": "Symfony Community", + "homepage": "https://symfony.com/contributors" + } + ], + "description": "Provides powerful methods to fetch HTTP resources synchronously or asynchronously", + "homepage": "https://symfony.com", + "keywords": [ + "http" + ], + "support": { + "source": "https://github.com/symfony/http-client/tree/v6.4.19" + }, + "funding": [ + { + "url": "https://symfony.com/sponsor", + "type": "custom" + }, + { + "url": "https://github.com/fabpot", + "type": "github" + }, + { + "url": "https://tidelift.com/funding/github/packagist/symfony/symfony", + "type": "tidelift" + } + ], + "time": "2025-02-13T09:55:13+00:00" + }, + { + "name": "symfony/http-client-contracts", + "version": "v3.5.2", + "source": { + "type": "git", + "url": "https://github.com/symfony/http-client-contracts.git", + "reference": "ee8d807ab20fcb51267fdace50fbe3494c31e645" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/symfony/http-client-contracts/zipball/ee8d807ab20fcb51267fdace50fbe3494c31e645", + "reference": "ee8d807ab20fcb51267fdace50fbe3494c31e645", + "shasum": "" + }, + "require": { + "php": ">=8.1" + }, + "type": "library", + "extra": { + "thanks": { + "url": "https://github.com/symfony/contracts", + "name": "symfony/contracts" + }, + "branch-alias": { + "dev-main": "3.5-dev" + } + }, + "autoload": { + "psr-4": { + "Symfony\\Contracts\\HttpClient\\": "" + }, + "exclude-from-classmap": [ + "/Test/" + ] + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "MIT" + ], + "authors": [ + { + "name": "Nicolas Grekas", + "email": "p@tchwork.com" + }, + { + "name": "Symfony Community", + "homepage": "https://symfony.com/contributors" + } + ], + "description": "Generic abstractions related to HTTP clients", + "homepage": "https://symfony.com", + "keywords": [ + "abstractions", + "contracts", + "decoupling", + "interfaces", + "interoperability", + "standards" + ], + "support": { + "source": "https://github.com/symfony/http-client-contracts/tree/v3.5.2" + }, + "funding": [ + { + "url": "https://symfony.com/sponsor", + "type": "custom" + }, + { + "url": "https://github.com/fabpot", + "type": "github" + }, + { + "url": "https://tidelift.com/funding/github/packagist/symfony/symfony", + "type": "tidelift" + } + ], + "time": "2024-12-07T08:49:48+00:00" + }, { "name": "symfony/http-foundation", "version": "v6.4.18", diff --git a/backend/src/Application/Command/TestChatSessionCommand.php b/backend/src/Application/Command/TestChatSessionCommand.php new file mode 100644 index 0000000..1a0b8ff --- /dev/null +++ b/backend/src/Application/Command/TestChatSessionCommand.php @@ -0,0 +1,104 @@ +addOption('message', 'm', InputOption::VALUE_REQUIRED, 'Initial message to send to the chat session', 'Hello, this is a test message.') + ->addOption('max-steps', null, InputOption::VALUE_REQUIRED, 'Maximum number of steps to run', 3) + ->addOption('system-prompt', 's', InputOption::VALUE_REQUIRED, 'System prompt to use', 'You are a helpful assistant. Keep your answers brief.'); + } + + protected function execute(InputInterface $input, OutputInterface $output): int + { + $io = new SymfonyStyle($input, $output); + $message = $input->getOption('message'); + $maxSteps = (int) $input->getOption('max-steps'); + $systemPrompt = $input->getOption('system-prompt'); + + $io->title('Testing Chat Session Interface'); + + try { + $chatSession = new ChatSession( + $this->chatProvider, + new ToolCollection([]) + ); + + // Add a chat listener for real-time messaging + $chatSession->addChatListener(function (MessageCollection $messages) use ($io) { + $lastMessage = $messages->getLastMessage(); + $role = $lastMessage->getRole(); + $content = $lastMessage->getContent(); + + if ($content !== null) { + $io->section(ucfirst($role) . ' Message'); + $io->writeln($content); + } + + $toolCalls = $lastMessage->getToolCalls(); + if (count($toolCalls) > 0) { + $io->section('Tool Calls'); + foreach ($toolCalls as $toolCall) { + $io->writeln('- ' . $toolCall->getName() . ': ' . json_encode($toolCall->getArguments())); + } + } + + $toolResult = $lastMessage->getToolResult(); + if ($toolResult !== null) { + $io->section('Tool Result'); + $io->writeln('Tool: ' . $toolResult->getToolName()); + $io->writeln('ID: ' . $toolResult->getToolCallId()); + } + }); + + // Set system prompt + $io->section('Setting System Prompt'); + $io->writeln($systemPrompt); + $chatSession->system($systemPrompt); + + // Send user message + $io->section('Sending User Message'); + $io->writeln($message); + $chatSession->user($message); + + // Commit the conversation for the specified number of steps + $io->section('Committing Conversation'); + $chatSession->commit($maxSteps); + + // Summary + $io->success('Chat session test completed successfully.'); + + return Command::SUCCESS; + } catch (\Exception $e) { + $io->error('Error during chat session test: ' . $e->getMessage()); + + return Command::FAILURE; + } + } +} \ No newline at end of file diff --git a/backend/src/Domain/Chat/ChatProviderInterface.php b/backend/src/Domain/Chat/ChatProviderInterface.php new file mode 100644 index 0000000..0db5543 --- /dev/null +++ b/backend/src/Domain/Chat/ChatProviderInterface.php @@ -0,0 +1,10 @@ + $choices + * @param list $messageHistory + */ + public function __construct( + private readonly array $choices = [], + private readonly array $messageHistory = [], + ) { + } + + /** + * @return list + */ + public function getChoices(): array + { + return $this->choices; + } + + /** + * @return list + */ + public function getMessageHistory(): array + { + return $this->messageHistory; + } + + public function getFirstChoice(): ?Choice + { + if (empty($this->choices)) { + return null; + } + + return $this->choices[0]; + } + + public function getContent(): ?string + { + $firstChoice = $this->getFirstChoice(); + if ($firstChoice === null) { + return null; + } + + return $firstChoice->getContent(); + } + + /** + * @return list + */ + public function getToolCalls(): array + { + $toolCalls = []; + foreach ($this->choices as $choice) { + if (!empty($choice->getToolCalls())) { + foreach ($choice->getToolCalls() as $toolCall) { + $toolCalls[] = $toolCall; + } + } + } + + return $toolCalls; + } +} diff --git a/backend/src/Domain/Chat/ChatSession.php b/backend/src/Domain/Chat/ChatSession.php new file mode 100644 index 0000000..3210ba4 --- /dev/null +++ b/backend/src/Domain/Chat/ChatSession.php @@ -0,0 +1,94 @@ + + */ + private array $chatListeners = []; + + public function __construct( + private readonly ChatProviderInterface $chatProvider, + private readonly ToolCollection $toolCollection = new ToolCollection([]), + ) { + $this->messages = new MessageCollection(); + } + + public function addChatListener(callable $listener): void + { + $this->chatListeners[] = $listener; + } + + public function system(string $message): void + { + $this->addMessage(Message::fromSystem($message)); + } + + public function user(string $message): void + { + $this->addMessage(Message::fromUser($message)); + } + + public function getMessages(): MessageCollection + { + return $this->messages; + } + + public function commit(int $maxSteps = 10, int $forcedToolCalls = 0, bool $reasoning = true): void + { + if ($maxSteps <= 0) { + throw new Exception('Max steps reached'); + } + + $result = $this->chatProvider->chat($this->messages, $this->toolCollection, $forcedToolCalls > 0, $reasoning); + + $choices = $result->getChoices(); + if (count($choices) === 0) { + throw new Exception('No choices found'); + } + + $this->addMessage(Message::fromAssistant($result->getContent(), $result->getToolCalls())); + + if (count($result->getToolCalls()) === 0) { + return; + } + + foreach ($result->getToolCalls() as $toolCall) { + $tool = $this->toolCollection->findTool($toolCall->getName()); + if ($tool === null) { + continue; + } + + $toolResult = $tool->execute($toolCall->getArguments(), []); + + $this->addMessage(Message::fromToolResult( + $toolResult, + $toolCall->getId(), + $toolCall->getName(), + )); + } + + $this->commit($maxSteps - 1, $forcedToolCalls - 1); + } + + private function notifyChatListeners(MessageCollection $messages): void + { + foreach ($this->chatListeners as $listener) { + $listener($messages); + } + } + + private function addMessage(Message $message): void + { + $this->messages->addMessage($message); + $this->notifyChatListeners($this->messages); + } +} diff --git a/backend/src/Domain/Chat/Choice.php b/backend/src/Domain/Chat/Choice.php new file mode 100644 index 0000000..9f4cd93 --- /dev/null +++ b/backend/src/Domain/Chat/Choice.php @@ -0,0 +1,124 @@ + $contentFilterResults + * @param array $toolCalls + * @param array|null $logprobs + * @param array|null $toolResult + */ + public function __construct( + private readonly array $contentFilterResults, + private readonly string $finishReason, + private readonly int $index, + private readonly ?array $logprobs, + private readonly ?string $content, + private readonly ?string $refusal, + private readonly string $role, + private readonly array $toolCalls, + private readonly ?array $toolResult = null, + ) { + } + + /** + * @return array + */ + public function getContentFilterResults(): array + { + return $this->contentFilterResults; + } + + public function getFinishReason(): string + { + return $this->finishReason; + } + + public function getIndex(): int + { + return $this->index; + } + + /** + * @return array|null + */ + public function getLogprobs(): ?array + { + return $this->logprobs; + } + + public function getContent(): ?string + { + return $this->content; + } + + public function getRefusal(): ?string + { + return $this->refusal; + } + + public function getRole(): string + { + return $this->role; + } + + /** + * @return array + */ + public function getToolCalls(): array + { + return $this->toolCalls; + } + + /** + * @return array|null + */ + public function getToolResult(): ?array + { + return $this->toolResult; + } + + /** + * @param array{ + * content_filter_results?: array, + * finish_reason?: string, + * index?: int, + * logprobs?: ?array, + * message: array{ + * content: ?string, + * refusal: ?string, + * role: string, + * tool_calls: array + * } + * } $data + */ + public static function fromArray(array $data): self + { + $toolCalls = array_map( + static fn (array $toolCall): ToolCall => ToolCall::fromArray($toolCall), + $data['message']['tool_calls'] ?? [] + ); + + $toolResult = $data['message']['tool_result'] ?? null; + + return new self( + contentFilterResults: $data['content_filter_results'] ?? [], + finishReason: $data['finish_reason'] ?? '', + index: $data['index'] ?? 0, + logprobs: $data['logprobs'] ?? null, + content: $data['message']['content'] ?? null, + refusal: $data['message']['refusal'] ?? null, + role: $data['message']['role'] ?? '', + toolCalls: $toolCalls, + toolResult: $toolResult, + ); + } +} diff --git a/backend/src/Domain/Chat/Message.php b/backend/src/Domain/Chat/Message.php new file mode 100644 index 0000000..f09ece2 --- /dev/null +++ b/backend/src/Domain/Chat/Message.php @@ -0,0 +1,138 @@ +|null $toolCalls + */ + public function __construct( + private readonly string $role, + private readonly ?string $content, + private readonly ?array $toolCalls = null, + private readonly ?ToolResult $toolResult = null, + private readonly DateTimeInterface $createdAt = new DateTimeImmutable(), + ) { + } + + public function getRole(): string + { + return $this->role; + } + + public function getContent(): ?string + { + return $this->content; + } + + public function getTwig(): ?string + { + $content = $this->getContent() ?? '```twig\n{}\n```'; + $matches = []; + if (preg_match('/```twig\s*([\s\S]*?)\s*```/m', $content, $matches) > 0) { + return $matches[1]; + } + + return null; + } + + /** + * @return array|null + */ + public function getJson(): ?array + { + $content = $this->getContent() ?? '```json\n{}\n```'; + $matches = []; + if (preg_match('/```json\s*([\s\S]*?)\s*```/m', $content, $matches) > 0) { + $decoded = json_decode($matches[1], true); + return is_array($decoded) ? $decoded : null; + } + + return null; + } + + public function getCreatedAt(): DateTimeInterface + { + return $this->createdAt; + } + + /** + * @return array + */ + public function getToolCalls(): array + { + return $this->toolCalls ?? []; + } + + public function getToolResult(): ?ToolResult + { + if (isset($this->toolResult)) { + return $this->toolResult; + } + + return null; + } + + /** + * @return array + */ + public function toArray(): array + { + $result = [ + 'role' => $this->role, + ]; + + if ($this->content !== null) { + $result['content'] = $this->content; + } + + if ($this->toolResult !== null) { + $result['tool_call_id'] = $this->toolResult->getToolCallId(); + } + + if ($this->toolCalls !== null && $this->toolCalls !== []) { + $result['tool_calls'] = array_map( + fn (ToolCall $toolCall) => [ + 'id' => $toolCall->getId(), + 'type' => $toolCall->getType(), + 'function' => [ + 'name' => $toolCall->getName(), + 'arguments' => json_encode($toolCall->getArguments()), + ], + ], + $this->toolCalls + ); + } + + return $result; + } + + public static function fromUser(string $content): self + { + return new self(role: 'user', content: $content); + } + + public static function fromSystem(string $content): self + { + return new self(role: 'system', content: $content); + } + + /** + * @param array|null $toolCalls + */ + public static function fromAssistant(?string $content = null, ?array $toolCalls = null): self + { + return new self(role: 'assistant', content: $content, toolCalls: $toolCalls); + } + + public static function fromToolResult(string $content, string $toolCallId, string $toolName): self + { + return new self(role: 'tool', content: $content, toolResult: new ToolResult($toolCallId, $toolName)); + } +} diff --git a/backend/src/Domain/Chat/MessageCollection.php b/backend/src/Domain/Chat/MessageCollection.php new file mode 100644 index 0000000..2854479 --- /dev/null +++ b/backend/src/Domain/Chat/MessageCollection.php @@ -0,0 +1,212 @@ + + */ +final class MessageCollection implements IteratorAggregate, Countable +{ + /** + * @var list + */ + private array $messages = []; + + /** + * @param list $messages + */ + public function __construct(array $messages = []) + { + $this->messages = $messages; + } + + /** + * @return list + */ + public function getMessagesSortedByCreatedAtDesc(): array + { + $messages = $this->messages; + usort($messages, fn (Message $a, Message $b) => $b->getCreatedAt()->getTimestamp() <=> $a->getCreatedAt()->getTimestamp()); + + return $messages; + } + + public function getLastMessage(): Message + { + $messages = $this->getMessagesSortedByCreatedAtDesc(); + + return $messages[0]; + } + + /** + * @return list + */ + public function getLastMessages(int $count): array + { + $messages = $this->getMessagesSortedByCreatedAtDesc(); + + return array_slice($messages, 0, $count); + } + + public function addMessage(Message $message): void + { + $this->messages[] = $message; + } + + /** + * @return ArrayIterator + */ + public function getIterator(): ArrayIterator + { + return new ArrayIterator($this->messages); + } + + /** + * @return list + */ + public function toArray(): array + { + return $this->messages; + } + + /** + * @return list + */ + public function getToolCalls(): array + { + $toolCalls = []; + foreach ($this->messages as $message) { + foreach ($message->getToolCalls() as $toolCall) { + $toolCalls[] = $toolCall; + } + } + + return $toolCalls; + } + + /** + * @return list + */ + public function getToolCallsByToolName(string $toolName): array + { + return array_values(array_filter($this->getToolCalls(), fn (ToolCall $toolCall) => $toolCall->getName() === $toolName)); + } + + public function getToolCallsById(string $toolCallId): ?ToolCall + { + foreach ($this->getToolCalls() as $toolCall) { + if ($toolCall->getId() === $toolCallId) { + return $toolCall; + } + } + + return null; + } + + /** + * @return array|null + */ + public function getToolResultById(string $toolCallId): ?array + { + foreach ($this->getToolCalls() as $toolCall) { + if ($toolCall->getId() === $toolCallId) { + return $toolCall->getArguments(); + } + } + + return null; + } + + /** + * @return list + */ + public function getMessagesWithToolCallByToolName(string $toolName): array + { + $result = []; + foreach ($this->getMessagesSortedByCreatedAtDesc() as $message) { + $toolCalls = $message->getToolCalls(); + foreach ($toolCalls as $toolCall) { + if ($toolCall->getName() === $toolName) { + $result[] = $message; + break; + } + } + } + return $result; + } + + /** + * @return list + */ + public function getMessagesWithToolResults(): array + { + $result = []; + foreach ($this->getMessagesSortedByCreatedAtDesc() as $message) { + if ($message->getToolResult() !== null) { + $result[] = $message; + } + } + return $result; + } + + /** + * @return list + */ + public function getMessagesWithToolResultByToolName(string $toolName): array + { + $result = []; + foreach ($this->getMessagesWithToolResults() as $message) { + if ($message->getToolResult()?->getToolName() === $toolName) { + $result[] = $message; + } + } + return $result; + } + + public function getLatestMessageWithToolResultByToolName(string $toolName): ?Message + { + $toolResults = $this->getMessagesWithToolResultByToolName($toolName); + + if (count($toolResults) === 0) { + return null; + } + + return $toolResults[0]; + } + + public function getLatestMessageWithToolCallByToolName(string $toolName): ?Message + { + $messagesWithToolCall = $this->getMessagesWithToolCallByToolName($toolName); + + if (count($messagesWithToolCall) === 0) { + return null; + } + + return $messagesWithToolCall[0]; + } + + public function getLatestToolCallByToolName(string $toolName): ?ToolCall + { + $messageWithToolCall = $this->getLatestMessageWithToolCallByToolName($toolName); + + if ($messageWithToolCall === null) { + return null; + } + + $toolCalls = array_filter($messageWithToolCall->getToolCalls(), fn (ToolCall $toolCall) => $toolCall->getName() === $toolName); + return !empty($toolCalls) ? array_values($toolCalls)[0] : null; + } + + public function count(): int + { + return count($this->messages); + } +} diff --git a/backend/src/Domain/Chat/ShouldStopResult.php b/backend/src/Domain/Chat/ShouldStopResult.php new file mode 100644 index 0000000..066c87b --- /dev/null +++ b/backend/src/Domain/Chat/ShouldStopResult.php @@ -0,0 +1,24 @@ +shouldStop; + } + + public function getErrorPrompt(): ?string + { + return $this->errorPrompt; + } +} diff --git a/backend/src/Domain/Chat/ToolCall.php b/backend/src/Domain/Chat/ToolCall.php new file mode 100644 index 0000000..b6b2482 --- /dev/null +++ b/backend/src/Domain/Chat/ToolCall.php @@ -0,0 +1,67 @@ + $arguments + */ + public function __construct( + private readonly string $id, + private readonly string $type, + private readonly string $name, + private readonly array $arguments, + ) { + } + + public function getId(): string + { + return $this->id; + } + + public function getType(): string + { + return $this->type; + } + + public function getName(): string + { + return $this->name; + } + + /** + * @return array + */ + public function getArguments(): array + { + return $this->arguments; + } + + /** + * @param array{ + * id: string, + * type: string, + * function: array{ + * name: string, + * arguments: string + * } + * } $data + */ + public static function fromArray(array $data): self + { + $decodedArgs = json_decode($data['function']['arguments'], true); + + /** @var array $arguments */ + $arguments = is_array($decodedArgs) ? $decodedArgs : []; + + return new self( + id: $data['id'], + type: $data['type'], + name: $data['function']['name'], + arguments: $arguments, + ); + } +} diff --git a/backend/src/Domain/Chat/ToolCollection.php b/backend/src/Domain/Chat/ToolCollection.php new file mode 100644 index 0000000..13b9a2f --- /dev/null +++ b/backend/src/Domain/Chat/ToolCollection.php @@ -0,0 +1,43 @@ + + */ +final class ToolCollection extends \ArrayObject +{ + /** + * @param list $tools + */ + public function __construct( + private readonly iterable $tools, + ) { + } + + /** + * @return iterable + */ + public function toArray(): array + { + return $this->tools; + } + + public function count(): int + { + return count($this->tools); + } + + public function findTool(string $name): ?ToolInterface + { + foreach ($this->tools as $tool) { + if ($tool->getName() === $name) { + return $tool; + } + } + + return null; + } +} diff --git a/backend/src/Domain/Chat/ToolInterface.php b/backend/src/Domain/Chat/ToolInterface.php new file mode 100644 index 0000000..2522db7 --- /dev/null +++ b/backend/src/Domain/Chat/ToolInterface.php @@ -0,0 +1,40 @@ +> + */ + public function getArguments(): array; + + /** + * Get the required arguments of the tool. + * + * @return array + */ + public function getRequiredArguments(): array; + + /** + * Execute the tool with the given arguments. + * + * @param array $arguments + * @param array $context + */ + public function execute(array $arguments, array $context = []): string; +} diff --git a/backend/src/Domain/Chat/ToolProvider.php b/backend/src/Domain/Chat/ToolProvider.php new file mode 100644 index 0000000..4dca085 --- /dev/null +++ b/backend/src/Domain/Chat/ToolProvider.php @@ -0,0 +1,38 @@ + $tools + */ + public function __construct( + #[AutowireIterator(tag: 'app.ai_proxy.tool')] + private readonly iterable $tools, + ) { + } + + public function getTools(): ToolCollection + { + $toolsArray = iterator_to_array($this->tools); + // Ensure it's a list (consecutive integer keys starting from 0) + $toolsList = array_values($toolsArray); + return new ToolCollection($toolsList); + } + + public function getToolByClass(string $class): ?ToolInterface + { + foreach ($this->tools as $tool) { + if ($tool instanceof $class) { + return $tool; + } + } + + return null; + } +} diff --git a/backend/src/Domain/Chat/ToolResult.php b/backend/src/Domain/Chat/ToolResult.php new file mode 100644 index 0000000..3b727f2 --- /dev/null +++ b/backend/src/Domain/Chat/ToolResult.php @@ -0,0 +1,24 @@ +toolCallId; + } + + public function getToolName(): string + { + return $this->toolName; + } +} diff --git a/backend/src/Infrastructure/Chat/OpenAIChatProvider.php b/backend/src/Infrastructure/Chat/OpenAIChatProvider.php new file mode 100644 index 0000000..60aa0e3 --- /dev/null +++ b/backend/src/Infrastructure/Chat/OpenAIChatProvider.php @@ -0,0 +1,98 @@ + $reasoning ? 'o4-mini' : 'gpt-4o-mini', + 'messages' => array_map(function (Message $message) { + return $message->toArray(); + }, $messages->toArray()), + ]; + + if ($tools->count() > 0) { + $payload['tool_choice'] = $forceToolCalls ? 'required' : 'auto'; + $payload['tools'] = []; + /** @var ToolInterface $tool */ + foreach ($tools->toArray() as $tool) { + $payload['tools'][] = [ + 'type' => 'function', + 'function' => [ + 'name' => $tool->getName(), + 'description' => $tool->getDescription(), + 'parameters' => [ + 'type' => 'object', + 'properties' => $tool->getArguments(), + 'required' => $tool->getRequiredArguments(), + ], + ], + ]; + } + } + + try { + /** @var ResponseInterface $response */ + $response = $this->httpClient->request('POST', "https://api.openai.com/v1/chat/completions", [ + 'headers' => [ + 'Content-Type' => 'application/json', + 'Authorization' => 'Bearer ' . $this->openAiApiKey, + ], + 'json' => $payload, + 'timeout' => 60 * 10, + 'max_duration' => 60 * 10, + ]); + + $responseArray = $response->toArray(false); + + if (isset($responseArray['choices']) && is_array($responseArray['choices']) && count($responseArray['choices']) > 0) { + $choices = []; + foreach ($responseArray['choices'] as $choice) { + $choices[] = Choice::fromArray($choice); + } + + return new ChatResult( + choices: $choices, + messageHistory: $messages->toArray(), + ); + } + + throw new Exception('Error: ' . $response->getContent(false)); + } catch (ClientExceptionInterface | DecodingExceptionInterface | RedirectionExceptionInterface | + ServerExceptionInterface | TransportExceptionInterface $e) { + throw new Exception('API Error: ' . $e->getMessage(), 0, $e); + } + } +}