1 package chromem
2
3 import (
4 "bytes"
5 "compress/gzip"
6 "crypto/aes"
7 "crypto/cipher"
8 "crypto/rand"
9 "crypto/sha256"
10 "encoding/gob"
11 "encoding/hex"
12 "errors"
13 "fmt"
14 "io"
15 "io/fs"
16 "os"
17 "path/filepath"
18 )
19
// metadataFileName is a fixed file name consisting of eight zeros.
// NOTE(review): judging by the name, this presumably names the file that
// holds metadata alongside other persisted files — confirm at call sites.
const (
	metadataFileName = "00000000"
)
21
// hash2hex returns the lowercase hex encoding of the first 4 bytes of the
// SHA-256 digest of name, i.e. an 8-character hex string.
func hash2hex(name string) string {
	digest := sha256.Sum256([]byte(name))
	prefix := digest[:4]
	return hex.EncodeToString(prefix)
}
29
30
31
32
33
// persistToFile serializes obj into the file at filePath by delegating to
// persistToWriter (gob encoding, optionally gzip-compressed and/or
// AES-GCM-encrypted).
//
// If filePath does not exist, its parent directories are created with
// permission 0o700. An existing regular file is truncated and overwritten
// (os.Create). If filePath names a directory, an error is returned.
//
// encryptionKey must be empty (no encryption) or exactly 32 bytes long.
//
// The error from closing the file is returned when the write itself
// succeeded. For a file that was just written, a failed Close can mean the
// data never reached the file, so it must not be silently dropped.
func persistToFile(filePath string, obj any, compress bool, encryptionKey string) error {
	if filePath == "" {
		return fmt.Errorf("file path is empty")
	}
	// Validate the key before touching the filesystem, so a bad key neither
	// creates directories nor truncates an existing file.
	if encryptionKey != "" && len(encryptionKey) != 32 {
		return errors.New("encryption key must be 32 bytes long")
	}

	fi, err := os.Stat(filePath)
	switch {
	case err == nil:
		if fi.IsDir() {
			return fmt.Errorf("path is a directory: %s", filePath)
		}
	case errors.Is(err, fs.ErrNotExist):
		if mkErr := os.MkdirAll(filepath.Dir(filePath), 0o700); mkErr != nil {
			return fmt.Errorf("couldn't create parent directories to path: %w", mkErr)
		}
	default:
		return fmt.Errorf("couldn't get info about the path: %w", err)
	}

	f, err := os.Create(filePath)
	if err != nil {
		return fmt.Errorf("couldn't create file: %w", err)
	}

	// Close unconditionally, but prefer reporting the write error, since it is
	// the root cause when both fail.
	writeErr := persistToWriter(f, obj, compress, encryptionKey)
	closeErr := f.Close()
	if writeErr != nil {
		return writeErr
	}
	if closeErr != nil {
		return fmt.Errorf("couldn't close file: %w", closeErr)
	}
	return nil
}
71
72
73
74
75
// persistToWriter gob-encodes obj and writes the result to w.
//
// If compress is true, the gob stream is wrapped in gzip.
//
// If encryptionKey is non-empty, it must be exactly 32 bytes. In that case:
//   - the (possibly compressed) payload is first collected in memory;
//   - it is sealed with AES-GCM under a fresh random nonce;
//   - it is written to w as nonce||ciphertext in a single Write call.
//
// Without a key, the encoded data is streamed straight to w.
func persistToWriter(w io.Writer, obj any, compress bool, encryptionKey string) error {
	encrypt := encryptionKey != ""
	if encrypt && len(encryptionKey) != 32 {
		return errors.New("encryption key must be 32 bytes long")
	}

	// GCM seals the whole plaintext at once, so when encrypting, the encoded
	// output is staged in a buffer instead of going directly to w.
	var plain *bytes.Buffer
	dst := w
	if encrypt {
		plain = &bytes.Buffer{}
		dst = plain
	}

	var zw *gzip.Writer
	if compress {
		zw = gzip.NewWriter(dst)
		dst = zw
	}

	if err := gob.NewEncoder(dst).Encode(obj); err != nil {
		return fmt.Errorf("couldn't encode or write object: %w", err)
	}

	// Closing the gzip writer flushes pending data and writes the gzip footer.
	// It does not close the underlying writer.
	if zw != nil {
		if err := zw.Close(); err != nil {
			return fmt.Errorf("couldn't close gzip writer: %w", err)
		}
	}

	if !encrypt {
		return nil
	}

	blockCipher, err := aes.NewCipher([]byte(encryptionKey))
	if err != nil {
		return fmt.Errorf("couldn't create new AES cipher: %w", err)
	}
	aead, err := cipher.NewGCM(blockCipher)
	if err != nil {
		return fmt.Errorf("couldn't create GCM wrapper: %w", err)
	}
	nonce := make([]byte, aead.NonceSize())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return fmt.Errorf("couldn't read random bytes for nonce: %w", err)
	}

	// Seal appends the ciphertext to its first argument, so passing nonce
	// there yields nonce||ciphertext.
	sealed := aead.Seal(nonce, nonce, plain.Bytes(), nil)
	if _, err := w.Write(sealed); err != nil {
		return fmt.Errorf("couldn't write encrypted data: %w", err)
	}

	return nil
}
150
151
152
153
154
// readFromFile opens the file at filePath and decodes its contents into obj
// via readFromReader.
//
// encryptionKey must be empty (data not encrypted) or exactly 32 bytes long.
// obj is ultimately passed to gob's Decoder.Decode, so it should be a pointer.
// The file is closed before returning.
func readFromFile(filePath string, obj any, encryptionKey string) error {
	switch {
	case filePath == "":
		return fmt.Errorf("file path is empty")
	case encryptionKey != "" && len(encryptionKey) != 32:
		return errors.New("encryption key must be 32 bytes long")
	}

	file, err := os.Open(filePath)
	if err != nil {
		return fmt.Errorf("couldn't open file: %w", err)
	}
	defer file.Close()

	return readFromReader(file, obj, encryptionKey)
}
174
175
176
177
178
179
// readFromReader decodes a gob-encoded value from r into obj.
//
// Stream handling:
//   - If encryptionKey is non-empty, it must be exactly 32 bytes, and the
//     rest of r must be AES-GCM data laid out as nonce||ciphertext. It is
//     read fully into memory and decrypted first.
//   - The (decrypted) payload is treated as gzip-compressed when it starts
//     with the gzip magic number 0x1f 0x8b.
//
// obj is passed to gob's Decoder.Decode, so it should be a pointer.
//
// Decoding starts at r's current offset. The reader is only rewound to that
// offset after sniffing the magic number, not to absolute offset 0.
func readFromReader(r io.ReadSeeker, obj any, encryptionKey string) error {
	if encryptionKey != "" && len(encryptionKey) != 32 {
		return errors.New("encryption key must be 32 bytes long")
	}

	var src io.ReadSeeker = r
	if encryptionKey != "" {
		encrypted, err := io.ReadAll(r)
		if err != nil {
			return fmt.Errorf("couldn't read from reader: %w", err)
		}
		block, err := aes.NewCipher([]byte(encryptionKey))
		if err != nil {
			return fmt.Errorf("couldn't create AES cipher: %w", err)
		}
		gcm, err := cipher.NewGCM(block)
		if err != nil {
			return fmt.Errorf("couldn't create GCM wrapper: %w", err)
		}
		nonceSize := gcm.NonceSize()
		if len(encrypted) < nonceSize {
			return fmt.Errorf("encrypted data too short")
		}
		nonce, ciphertext := encrypted[:nonceSize], encrypted[nonceSize:]
		data, err := gcm.Open(nil, nonce, ciphertext, nil)
		if err != nil {
			return fmt.Errorf("couldn't decrypt data: %w", err)
		}
		src = bytes.NewReader(data)
	}

	// Remember where the payload starts, so the reader can be rewound to
	// exactly this position after sniffing. Seeking to 0 instead would decode
	// from the wrong place whenever r is not positioned at its beginning.
	start, err := src.Seek(0, io.SeekCurrent)
	if err != nil {
		return fmt.Errorf("couldn't determine reader position: %w", err)
	}

	// io.ReadFull (unlike a single Read) guarantees both bytes are read, even
	// from readers that return short reads or report io.EOF alongside data.
	magicNumber := make([]byte, 2)
	if _, err := io.ReadFull(src, magicNumber); err != nil {
		return fmt.Errorf("couldn't read magic number to determine whether the stream is compressed: %w", err)
	}
	compressed := magicNumber[0] == 0x1f && magicNumber[1] == 0x8b

	if _, err := src.Seek(start, io.SeekStart); err != nil {
		return fmt.Errorf("couldn't reset reader: %w", err)
	}

	var payload io.Reader = src
	if compressed {
		gzr, err := gzip.NewReader(src)
		if err != nil {
			return fmt.Errorf("couldn't create gzip reader: %w", err)
		}
		defer gzr.Close()
		payload = gzr
	}

	if err := gob.NewDecoder(payload).Decode(obj); err != nil {
		return fmt.Errorf("couldn't decode object: %w", err)
	}

	return nil
}
265
266
// removeFile deletes the file at filePath.
//
// A path that does not exist is not an error, so removing an already-removed
// file succeeds. An empty filePath, or any other failure reported by
// os.Remove (for example, a non-empty directory), is returned as an error.
func removeFile(filePath string) error {
	if filePath == "" {
		return fmt.Errorf("file path is empty")
	}

	switch err := os.Remove(filePath); {
	case err == nil, errors.Is(err, fs.ErrNotExist):
		return nil
	default:
		return fmt.Errorf("couldn't remove file %q: %w", filePath, err)
	}
}
281
View as plain text