178 lines
6.2 KiB
PHP
178 lines
6.2 KiB
PHP
<?php
|
|
|
|
namespace App\Infrastructure\PostgreSQL\Repository\EmbeddingRepository;
|
|
|
|
use Doctrine\DBAL\Connection;
|
|
use App\Domain\Repository\EmbeddingRepository;
|
|
use App\Domain\Model\Embedding\Embedding;
|
|
use App\Domain\Model\Embedding\LargeEmbeddingVector;
|
|
use App\Domain\Model\Embedding\SmallEmbeddingVector;
|
|
use App\Domain\Model\EmbeddingCollection;
|
|
use App\Domain\Model\Value\Vector;
|
|
use App\Domain\Model\Id\EmbeddingId;
|
|
use App\Domain\Model\Id\EmbeddingIdCollection;
|
|
|
|
/**
 * Doctrine DBAL implementation of EmbeddingRepository backed by the
 * `embeddings` table.
 *
 * The two optional vectors are stored in `large_embedding_vector` and
 * `small_embedding_vector` and are exchanged with the database in pgvector's
 * text form, e.g. "[0.1,0.2,0.3]".
 */
final class SqlEmbeddingRepository implements EmbeddingRepository
{
    public function __construct(
        private readonly Connection $connection,
    ) {}

    /**
     * Inserts a single embedding.
     *
     * Uses ON CONFLICT DO NOTHING, so a row that would violate a uniqueness
     * constraint is skipped silently instead of raising an error.
     */
    public function save(Embedding $embedding): void
    {
        $this->connection->executeStatement(
            'INSERT INTO embeddings (id, phrase_hash, phrase, large_embedding_vector, small_embedding_vector) VALUES (?, ?, ?, ?, ?) ON CONFLICT DO NOTHING',
            [
                $embedding->embeddingId->value,
                $embedding->phraseHash(),
                $embedding->phrase,
                $embedding->largeEmbeddingVector !== null
                    ? $this->encodeVector($embedding->largeEmbeddingVector->vector)
                    : null,
                $embedding->smallEmbeddingVector !== null
                    ? $this->encodeVector($embedding->smallEmbeddingVector->vector)
                    : null,
            ],
        );
    }

    /**
     * Saves every embedding in the collection by calling save() once per item.
     *
     * This method opens no transaction of its own, so if a statement fails
     * part-way, the rows inserted before it remain.
     */
    public function saveAll(EmbeddingCollection $embeddingCollection): void
    {
        foreach ($embeddingCollection->array() as $embedding) {
            $this->save($embedding);
        }
    }

    /**
     * Deletes the row whose id matches the embedding's id.
     *
     * Deleting a non-existent id is not treated as an error.
     */
    public function delete(Embedding $embedding): void
    {
        $this->connection->executeStatement(
            'DELETE FROM embeddings WHERE id = ?',
            [$embedding->embeddingId->value],
        );
    }

    /**
     * Looks up an embedding by exact phrase match.
     *
     * @return Embedding|null The first matching row, or null when none exists.
     */
    public function findByPhrase(string $phrase): ?Embedding
    {
        $row = $this->connection
            ->executeQuery('SELECT * FROM embeddings WHERE phrase = ?', [$phrase])
            ->fetchAssociative();

        return $row === false ? null : $this->mapRowToEmbedding($row);
    }

    /**
     * Returns up to $limit embeddings that have a large vector, nearest first,
     * ordered by pgvector's `<=>` (cosine distance) operator.
     */
    public function searchByLargeEmbeddingVector(LargeEmbeddingVector $embeddingVector, int $limit = 20): EmbeddingCollection
    {
        return $this->searchByVectorColumn('large_embedding_vector', $embeddingVector->vector, $limit);
    }

    /**
     * Returns up to $limit embeddings that have a small vector, nearest first,
     * ordered by pgvector's `<=>` (cosine distance) operator.
     */
    public function searchBySmallEmbeddingVector(SmallEmbeddingVector $smallEmbeddingVector, int $limit = 20): EmbeddingCollection
    {
        return $this->searchByVectorColumn('small_embedding_vector', $smallEmbeddingVector->vector, $limit);
    }

    /**
     * Fetches all embeddings whose id is in the given collection.
     *
     * IDs are passed as bound parameters rather than interpolated into the SQL.
     * Interpolation both broke string IDs (they were left unquoted) and allowed
     * SQL injection. An empty collection returns an empty result without
     * querying, because `IN ()` is invalid SQL.
     *
     * Rows are returned in database order, not in the order of the input IDs.
     * IDs with no matching row are simply absent from the result.
     */
    public function findByEmbeddingIdCollection(EmbeddingIdCollection $embeddingIdCollection): EmbeddingCollection
    {
        $ids = array_values(array_map(
            static fn (EmbeddingId $id) => $id->value,
            $embeddingIdCollection->array(),
        ));

        if ($ids === []) {
            return new EmbeddingCollection([]);
        }

        $placeholders = implode(', ', array_fill(0, count($ids), '?'));

        $result = $this->connection->executeQuery(
            "SELECT * FROM embeddings WHERE id IN ({$placeholders})",
            $ids,
        );

        return $this->mapRowsToCollection($result->fetchAllAssociative());
    }

    /**
     * Removes every row from the embeddings table.
     */
    public function deleteAll(): void
    {
        $this->connection->executeStatement('DELETE FROM embeddings');
    }

    /**
     * Shared nearest-neighbour query for both vector columns.
     *
     * $column is only ever one of the literal column names passed by the public
     * search methods above, never caller input, so interpolating it is safe.
     */
    private function searchByVectorColumn(string $column, Vector $vector, int $limit): EmbeddingCollection
    {
        $result = $this->connection->executeQuery(
            sprintf(
                'SELECT * FROM embeddings WHERE %1$s IS NOT NULL ORDER BY %1$s <=> ? LIMIT ?',
                $column,
            ),
            [$this->encodeVector($vector), $limit],
        );

        return $this->mapRowsToCollection($result->fetchAllAssociative());
    }

    /**
     * Serialises a vector to pgvector's text literal form: "[v1,v2,...]".
     */
    private function encodeVector(Vector $vector): string
    {
        return '[' . implode(',', $vector->values) . ']';
    }

    /**
     * Parses a pgvector text literal (which is also valid JSON) into a Vector.
     *
     * Returns null when the value is not a string, does not decode to an array,
     * or contains any non-numeric component. A malformed vector is treated as
     * absent rather than partially reconstructed.
     */
    private function decodeVector(mixed $raw): ?Vector
    {
        if (!is_string($raw)) {
            return null;
        }

        $decoded = json_decode($raw, true);
        if (!is_array($decoded)) {
            return null;
        }

        $values = [];
        foreach ($decoded as $component) {
            if (!is_numeric($component)) {
                return null;
            }
            $values[] = (float) $component;
        }

        return new Vector($values);
    }

    /**
     * Maps a list of fetched rows to an EmbeddingCollection, preserving row order.
     *
     * @param list<array<string, mixed>> $rows
     */
    private function mapRowsToCollection(array $rows): EmbeddingCollection
    {
        return new EmbeddingCollection(array_map($this->mapRowToEmbedding(...), $rows));
    }

    /**
     * Builds an Embedding from one `embeddings` row.
     *
     * Vector columns that are NULL or malformed map to a null vector.
     *
     * @param array<string, mixed> $row
     *
     * @throws \InvalidArgumentException When `id` or `phrase` is missing or not a string.
     */
    private function mapRowToEmbedding(array $row): Embedding
    {
        $largeVector = $this->decodeVector($row['large_embedding_vector'] ?? null);
        $largeEmbeddingVector = $largeVector === null ? null : new LargeEmbeddingVector($largeVector);

        $smallVector = $this->decodeVector($row['small_embedding_vector'] ?? null);
        $smallEmbeddingVector = $smallVector === null ? null : new SmallEmbeddingVector($smallVector);

        $id = $row['id'] ?? null;
        if (!is_string($id)) {
            throw new \InvalidArgumentException('Embedding ID is required');
        }
        $embeddingId = new EmbeddingId($id);

        $phrase = $row['phrase'] ?? null;
        if (!is_string($phrase)) {
            throw new \InvalidArgumentException('Phrase is required');
        }

        return new Embedding(
            embeddingId: $embeddingId,
            phrase: $phrase,
            largeEmbeddingVector: $largeEmbeddingVector,
            smallEmbeddingVector: $smallEmbeddingVector,
        );
    }
}