2025-06-09 15:42:22 +02:00

178 lines
6.2 KiB
PHP

<?php
namespace App\Infrastructure\PostgreSQL\Repository\EmbeddingRepository;
use Doctrine\DBAL\Connection;
use App\Domain\Repository\EmbeddingRepository;
use App\Domain\Model\Embedding\Embedding;
use App\Domain\Model\Embedding\LargeEmbeddingVector;
use App\Domain\Model\Embedding\SmallEmbeddingVector;
use App\Domain\Model\EmbeddingCollection;
use App\Domain\Model\Value\Vector;
use App\Domain\Model\Id\EmbeddingId;
use App\Domain\Model\Id\EmbeddingIdCollection;
/**
 * Doctrine DBAL implementation of EmbeddingRepository backed by the `embeddings` table.
 *
 * Vectors are written as bracketed, comma-separated text literals (e.g. "[0.1,0.2]")
 * and read back by JSON-decoding that same textual form.
 *
 * NOTE(review): the similarity searches order by the `<=>` operator, which is
 * pgvector's cosine-distance operator; this presumes both vector columns are
 * pgvector `vector` columns - confirm against the schema.
 */
final class SqlEmbeddingRepository implements EmbeddingRepository
{
    public function __construct(
        private readonly Connection $connection,
    ) {}

    /**
     * Inserts the embedding. Because of ON CONFLICT DO NOTHING, a conflicting
     * existing row is left untouched rather than updated.
     */
    public function save(Embedding $embedding): void
    {
        $largeVector = $embedding->largeEmbeddingVector !== null
            ? $this->toVectorLiteral($embedding->largeEmbeddingVector->vector)
            : null;
        $smallVector = $embedding->smallEmbeddingVector !== null
            ? $this->toVectorLiteral($embedding->smallEmbeddingVector->vector)
            : null;

        $this->connection->executeStatement(
            'INSERT INTO embeddings (id, phrase_hash, phrase, large_embedding_vector, small_embedding_vector) VALUES (?, ?, ?, ?, ?) ON CONFLICT DO NOTHING',
            [
                $embedding->embeddingId->value,
                $embedding->phraseHash(),
                $embedding->phrase,
                $largeVector,
                $smallVector,
            ],
        );
    }

    /**
     * Saves every embedding of the collection, one INSERT per embedding.
     * The inserts are not wrapped in a transaction by this method.
     */
    public function saveAll(EmbeddingCollection $embeddingCollection): void
    {
        foreach ($embeddingCollection->array() as $embedding) {
            $this->save($embedding);
        }
    }

    /**
     * Deletes the row whose id matches the given embedding's id.
     */
    public function delete(Embedding $embedding): void
    {
        $this->connection->executeStatement(
            'DELETE FROM embeddings WHERE id = ?',
            [$embedding->embeddingId->value],
        );
    }

    /**
     * Returns the first row whose phrase equals $phrase exactly, or null when none exists.
     *
     * @throws \InvalidArgumentException when the matching row lacks a string id or phrase
     */
    public function findByPhrase(string $phrase): ?Embedding
    {
        $row = $this->connection
            ->executeQuery('SELECT * FROM embeddings WHERE phrase = ?', [$phrase])
            ->fetchAssociative();

        return $row === false ? null : $this->mapRowToEmbedding($row);
    }

    /**
     * Returns up to $limit embeddings having a large vector, ordered by ascending
     * `<=>` distance to the given vector.
     *
     * @throws \InvalidArgumentException when a returned row lacks a string id or phrase
     */
    public function searchByLargeEmbeddingVector(LargeEmbeddingVector $embeddingVector, int $limit = 20): EmbeddingCollection
    {
        $result = $this->connection->executeQuery(
            'SELECT *, large_embedding_vector <=> :embeddingVector AS distance
            FROM embeddings
            WHERE large_embedding_vector IS NOT NULL
            ORDER BY large_embedding_vector <=> :embeddingVector
            LIMIT :limit',
            [
                'embeddingVector' => $this->toVectorLiteral($embeddingVector->vector),
                'limit' => $limit,
            ],
        );

        return $this->hydrateRows($result->fetchAllAssociative());
    }

    /**
     * Returns up to $limit embeddings having a small vector, ordered by ascending
     * `<=>` distance to the given vector.
     *
     * @throws \InvalidArgumentException when a returned row lacks a string id or phrase
     */
    public function searchBySmallEmbeddingVector(SmallEmbeddingVector $smallEmbeddingVector, int $limit = 20): EmbeddingCollection
    {
        $result = $this->connection->executeQuery(
            'SELECT *
            FROM embeddings
            WHERE small_embedding_vector IS NOT NULL
            ORDER BY small_embedding_vector <=> ?
            LIMIT ?',
            [
                $this->toVectorLiteral($smallEmbeddingVector->vector),
                $limit,
            ],
        );

        return $this->hydrateRows($result->fetchAllAssociative());
    }

    /**
     * Loads all embeddings whose id is contained in the given collection.
     *
     * The ids are bound as query parameters instead of being spliced into the SQL
     * text, so string ids are valid and cannot alter the statement. An empty id
     * collection yields an empty result without touching the database, because
     * `IN ()` is not valid SQL.
     *
     * @throws \InvalidArgumentException when a returned row lacks a string id or phrase
     */
    public function findByEmbeddingIdCollection(EmbeddingIdCollection $embeddingIdCollection): EmbeddingCollection
    {
        // array_values() guarantees 0-based keys, as required for positional parameters.
        $ids = array_values(array_map(
            static fn (EmbeddingId $id) => $id->value,
            $embeddingIdCollection->array(),
        ));

        if ($ids === []) {
            return new EmbeddingCollection([]);
        }

        // One "?" per id; only placeholders are interpolated, never the id values.
        $placeholders = implode(',', array_fill(0, count($ids), '?'));

        $result = $this->connection->executeQuery(
            "SELECT * FROM embeddings WHERE id IN ({$placeholders})",
            $ids,
        );

        return $this->hydrateRows($result->fetchAllAssociative());
    }

    /**
     * Removes every row from the embeddings table.
     */
    public function deleteAll(): void
    {
        $this->connection->executeStatement('DELETE FROM embeddings');
    }

    /**
     * Serialises a vector into the bracketed text form bound to the vector columns,
     * e.g. "[0.1,0.2,0.3]".
     */
    private function toVectorLiteral(Vector $vector): string
    {
        return '[' . implode(',', $vector->values) . ']';
    }

    /**
     * Maps each fetched row to an Embedding and wraps the result in a collection.
     *
     * @param list<array<string, mixed>> $rows
     */
    private function hydrateRows(array $rows): EmbeddingCollection
    {
        $embeddings = [];
        foreach ($rows as $row) {
            $embeddings[] = $this->mapRowToEmbedding($row);
        }

        return new EmbeddingCollection($embeddings);
    }

    /**
     * Decodes the stored text form of a vector into floats.
     *
     * Returns null when the value is not a string, does not JSON-decode to an
     * array, or contains at least one non-numeric element; such a column is
     * then treated as "no vector" by the caller rather than raising an error.
     *
     * @return list<float>|null
     */
    private function parseVectorValues(mixed $raw): ?array
    {
        if (!is_string($raw)) {
            return null;
        }

        $decoded = json_decode($raw, true);
        if (!is_array($decoded)) {
            return null;
        }

        $floats = [];
        foreach ($decoded as $value) {
            if (!is_numeric($value)) {
                // A single invalid component invalidates the whole vector.
                return null;
            }
            $floats[] = (float) $value;
        }

        return $floats;
    }

    /**
     * Builds an Embedding from a database row.
     *
     * Vector columns that are missing or cannot be parsed become null.
     *
     * @param array<string, mixed> $row
     *
     * @throws \InvalidArgumentException when `id` or `phrase` is missing or not a string
     */
    private function mapRowToEmbedding(array $row): Embedding
    {
        $largeValues = $this->parseVectorValues($row['large_embedding_vector'] ?? null);
        $smallValues = $this->parseVectorValues($row['small_embedding_vector'] ?? null);

        return new Embedding(
            embeddingId: new EmbeddingId(is_string($row['id'] ?? null) ? $row['id'] : throw new \InvalidArgumentException('Embedding ID is required')),
            phrase: is_string($row['phrase'] ?? null) ? $row['phrase'] : throw new \InvalidArgumentException('Phrase is required'),
            largeEmbeddingVector: $largeValues !== null ? new LargeEmbeddingVector(new Vector($largeValues)) : null,
            smallEmbeddingVector: $smallValues !== null ? new SmallEmbeddingVector(new Vector($smallValues)) : null,
        );
    }
}