1/*
2 * Copyright (C) 2019 Yaoyuan <ibireme@gmail.com>.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *
16 * Copyright 2018-2023 The simdjson authors
17 *
18 * Licensed under the Apache License, Version 2.0 (the "License");
19 * you may not use this file except in compliance with the License.
20 * You may obtain a copy of the License at
21
22 * http://www.apache.org/licenses/LICENSE-2.0
23
24 * Unless required by applicable law or agreed to in writing, software
25 * distributed under the License is distributed on an "AS IS" BASIS,
26 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27 * See the License for the specific language governing permissions and
28 * limitations under the License.
29 *
30 * This file may have been modified by ByteDance authors. All ByteDance
31 * Modifications are Copyright 2022 ByteDance Authors.
32 */
33
34#pragma once
35
36#include "native.h"
37#include "utils.h"
38#include "test/xassert.h"
39#include "test/xprintf.h"
40
/*
 * valid_utf8_4byte - length of the multi-byte UTF-8 sequence that starts a
 * 4-byte window.
 *
 * `ubin` holds the next (up to) four input bytes with the FIRST byte in the
 * least-significant 8 bits, i.e. the value of a little-endian 32-bit load.
 * Bytes beyond the end of the input must be zero: 0x00 never matches the
 * continuation pattern 10xxxxxx, so a truncated sequence is rejected.
 *
 * Returns 2, 3 or 4 for a well-formed sequence (shortest form, not a
 * surrogate, not above U+10FFFF) and 0 otherwise. An ASCII first byte also
 * yields 0; callers handle ASCII before calling this.
 */
static inline ssize_t valid_utf8_4byte(uint32_t ubin) {
    /*
    Bit layouts below are written first-byte-first; the hex in parentheses is
    per byte in the same order. The uint32_t constants hold those bytes in
    little-endian order (first byte = lowest 8 bits).
    ---------------------------------------------------
    1 byte
    unicode range [U+0000, U+007F]
    unicode min   [.......0]
    unicode max   [.1111111]
    bit pattern   [0.......]
    ---------------------------------------------------
    2 byte
    unicode range [U+0080, U+07FF]
    unicode min   [......10 ..000000]
    unicode max   [...11111 ..111111]
    bit require   [...xxxx. ........] (1E 00)
    bit mask      [xxx..... xx......] (E0 C0)
    bit pattern   [110..... 10......] (C0 80)
    ---------------------------------------------------
    3 byte
    unicode range [U+0800, U+FFFF]
    unicode min   [........ ..100000 ..000000]
    unicode max   [....1111 ..111111 ..111111]
    bit require   [....xxxx ..x..... ........] (0F 20 00)
    bit mask      [xxxx.... xx...... xx......] (F0 C0 C0)
    bit pattern   [1110.... 10...... 10......] (E0 80 80)
    ---------------------------------------------------
    3 byte invalid (reserved for surrogate halves)
    unicode range [U+D800, U+DFFF]
    unicode min   [....1101 ..100000 ..000000]
    unicode max   [....1101 ..111111 ..111111]
    bit mask      [....xxxx ..x..... ........] (0F 20 00)
    bit pattern   [....1101 ..1..... ........] (0D 20 00)
    ---------------------------------------------------
    4 byte
    unicode range [U+10000, U+10FFFF]
    unicode min   [........ ...10000 ..000000 ..000000]
    unicode max   [.....100 ..001111 ..111111 ..111111]
    bit err0      [.....100 ........ ........ ........] (04 00 00 00)
    bit err1      [.....011 ..110000 ........ ........] (03 30 00 00)
    bit require   [.....xxx ..xx.... ........ ........] (07 30 00 00)
    bit mask      [xxxxx... xx...... xx...... xx......] (F8 C0 C0 C0)
    bit pattern   [11110... 10...... 10...... 10......] (F0 80 80 80)
    ---------------------------------------------------
    */
    const uint32_t b2_mask = 0x0000C0E0UL;
    const uint32_t b2_patt = 0x000080C0UL;
    const uint32_t b2_requ = 0x0000001EUL;
    const uint32_t b3_mask = 0x00C0C0F0UL;
    const uint32_t b3_patt = 0x008080E0UL;
    const uint32_t b3_requ = 0x0000200FUL;
    const uint32_t b3_erro = 0x0000200DUL;
    const uint32_t b4_mask = 0xC0C0C0F8UL;
    const uint32_t b4_patt = 0x808080F0UL;
    const uint32_t b4_requ = 0x00003007UL;
    const uint32_t b4_err0 = 0x00000004UL;
    const uint32_t b4_err1 = 0x00003003UL;

    /* The three lead-byte patterns are mutually exclusive, so the order of the
     * checks only matters for speed. */

    /* 3 bytes: not overlong (some "require" bit set) and not a surrogate. */
    if ((ubin & b3_mask) == b3_patt) {
        const uint32_t bits = ubin & b3_requ;
        if (bits != 0 && bits != b3_erro) {
            return 3;
        }
    }

    /* 2 bytes: lead must not be 0xC0/0xC1 (overlong). */
    if ((ubin & b2_mask) == b2_patt && (ubin & b2_requ) != 0) {
        return 2;
    }

    /* 4 bytes: not overlong, and at most U+10FFFF, i.e. the lead is F0..F3,
     * or F4 followed by 0x80..0x8F. */
    if ((ubin & b4_mask) == b4_patt) {
        const uint32_t bits = ubin & b4_requ;
        if (bits != 0 && ((bits & b4_err0) == 0 || (bits & b4_err1) == 0)) {
            return 4;
        }
    }

    return 0;
}
125
126static always_inline long write_error(int pos, StateMachine *m, size_t msize) {
127 if (m->sp >= msize) {
128 return -1;
129 }
130 m->vt[m->sp++] = pos;
131 return 0;
132}
133
134// scalar code, error position should excesss 4096
135static always_inline long validate_utf8_with_errors(const char *src, long len, long *p, StateMachine *m) {
136 const char* start = src + *p;
137 const char* end = src + len;
138 while (start < end - 3) {
139 uint32_t u = (*(uint32_t*)(start));
140 if ((unsigned)(*start) < 0x80) {
141 start += 1;
142 continue;
143 }
144 size_t n = valid_utf8_4byte(u);
145 if (n != 0) { // valid utf
146 start += n;
147 continue;
148 }
149 long err = write_error(start - src, m, MAX_RECURSE);
150 if (err) {
151 *p = start - src;
152 return err;
153 }
154 start += 1;
155 }
156 while (start < end) {
157 if ((unsigned)(*start) < 0x80) {
158 start += 1;
159 continue;
160 }
161 uint32_t u = 0;
162 memcpy_p4(&u, start, end - start);
163 size_t n = valid_utf8_4byte(u);
164 if (n != 0) { // valid utf
165 start += n;
166 continue;
167 }
168 long err = write_error(start - src, m, MAX_RECURSE);
169 if (err) {
170 *p = start - src;
171 return err;
172 }
173 start += 1;
174 }
175 *p = start - src;
176 return 0;
177}
178
179// validate_utf8_errors returns zero if valid, otherwise, the error position.
180static always_inline long validate_utf8_errors(const GoString* s) {
181 const char* start = s->buf;
182 const char* end = s->buf + s->len;
183 while (start < end - 3) {
184 uint32_t u = (*(uint32_t*)(start));
185 if ((unsigned)(*start) < 0x80) {
186 start += 1;
187 continue;
188 }
189 size_t n = valid_utf8_4byte(u);
190 if (n == 0) { // invalid utf
191 return -(start - s->buf) - 1;
192 }
193 start += n;
194 }
195 while (start < end) {
196 if ((unsigned)(*start) < 0x80) {
197 start += 1;
198 continue;
199 }
200 uint32_t u = 0;
201 memcpy_p4(&u, start, end - start);
202 size_t n = valid_utf8_4byte(u);
203 if (n == 0) { // invalid utf
204 return -(start - s->buf) - 1;
205 }
206 start += n;
207 }
208 return 0;
209}
210
211// SIMD implementation
212#if USE_AVX2
213
214 static always_inline __m256i simd256_shr(const __m256i input, const int shift) {
215 __m256i shifted = _mm256_srli_epi16(input, shift);
216 __m256i mask = _mm256_set1_epi8(0xFFu >> shift);
217 return _mm256_and_si256(shifted, mask);
218 }
219
/*
 * simd256_prev(input, prev, N) - the byte stream shifted so that position i
 * holds the byte N positions before input[i]. The first N positions come from
 * the tail of `prev`, the 32-byte block that precedes `input`.
 * permute2x128(prev, input, 0x21) forms [prev.high | input.low], so vpalignr
 * can shift across the 128-bit lane boundary. N must be a compile-time
 * constant (vpalignr takes an immediate).
 * Expression macro: no trailing semicolon, and arguments are parenthesized.
 */
#define simd256_prev(input, prev, N) \
    _mm256_alignr_epi8((input), _mm256_permute2x128_si256((prev), (input), 0x21), 16 - (N))
221
222 static always_inline __m256i must_be_2_3_continuation(const __m256i prev2, const __m256i prev3) {
223 __m256i is_third_byte = _mm256_subs_epu8(prev2, _mm256_set1_epi8(0b11100000u-1)); // Only 111_____ will be > 0
224 __m256i is_fourth_byte = _mm256_subs_epu8(prev3, _mm256_set1_epi8(0b11110000u-1)); // Only 1111____ will be > 0
225 // Caller requires a bool (all 1's). All values resulting from the subtraction will be <= 64, so signed comparison is fine.
226 __m256i or = _mm256_or_si256(is_third_byte, is_fourth_byte);
227 return _mm256_cmpgt_epi8(or, _mm256_set1_epi8(0));;
228 }
229
230 static always_inline __m256i simd256_lookup16(const __m256i input, const uint8_t* table) {
231 return _mm256_shuffle_epi8(_mm256_setr_epi8(table[0], table[1], table[2], table[3], table[4], table[5], table[6], table[7], table[8], table[9], table[10], table[11], table[12], table[13], table[14], table[15], table[0], table[1], table[2], table[3], table[4], table[5], table[6], table[7], table[8], table[9], table[10], table[11], table[12], table[13], table[14], table[15]), input);
232 }
233
234 //
235 // Return nonzero if there are incomplete multibyte characters at the end of the block:
236 // e.g. if there is a 4-byte character, but it's 3 bytes from the end.
237 //
238 static always_inline __m256i is_incomplete(const __m256i input) {
239 // If the previous input's last 3 bytes match this, they're too short (they ended at EOF):
240 // ... 1111____ 111_____ 11______
241 const uint8_t tab[32] = {
242 255, 255, 255, 255, 255, 255, 255, 255,
243 255, 255, 255, 255, 255, 255, 255, 255,
244 255, 255, 255, 255, 255, 255, 255, 255,
245 255, 255, 255, 255, 255, 0b11110000u-1, 0b11100000u-1, 0b11000000u-1};
246 const __m256i max_value = _mm256_loadu_si256((const __m256i_u *)(&tab[0]));
247 return _mm256_subs_epu8(input, max_value);
248 }
249
250 static always_inline __m256i check_special_cases(const __m256i input, const __m256i prev1) {
251 // Bit 0 = Too Short (lead byte/ASCII followed by lead byte/ASCII)
252 // Bit 1 = Too Long (ASCII followed by continuation)
253 // Bit 2 = Overlong 3-byte
254 // Bit 4 = Surrogate
255 // Bit 5 = Overlong 2-byte
256 // Bit 7 = Two Continuations
257 const uint8_t TOO_SHORT = 1<<0; // 11______ 0_______
258 // 11______ 11______
259 const uint8_t TOO_LONG = 1<<1; // 0_______ 10______
260 const uint8_t OVERLONG_3 = 1<<2; // 11100000 100_____
261 const uint8_t SURROGATE = 1<<4; // 11101101 101_____
262 const uint8_t OVERLONG_2 = 1<<5; // 1100000_ 10______
263 const uint8_t TWO_CONTS = 1<<7; // 10______ 10______
264 const uint8_t TOO_LARGE = 1<<3; // 11110100 1001____
265 // 11110100 101_____
266 // 11110101 1001____
267 // 11110101 101_____
268 // 1111011_ 1001____
269 // 1111011_ 101_____
270 // 11111___ 1001____
271 // 11111___ 101_____
272 const uint8_t TOO_LARGE_1000 = 1<<6;
273 // 11110101 1000____
274 // 1111011_ 1000____
275 // 11111___ 1000____
276 const uint8_t OVERLONG_4 = 1<<6; // 11110000 1000____
277
278 const __m256i prev1_shr4 = simd256_shr(prev1, 4);
279 static const uint8_t tab1[16] = {
280 // 0_______ ________ <ASCII in byte 1>
281 TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG,
282 TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG,
283 // 10______ ________ <continuation in byte 1>
284 TWO_CONTS, TWO_CONTS, TWO_CONTS, TWO_CONTS,
285 // 1100____ ________ <two byte lead in byte 1>
286 TOO_SHORT | OVERLONG_2,
287 // 1101____ ________ <two byte lead in byte 1>
288 TOO_SHORT,
289 // 1110____ ________ <three byte lead in byte 1>
290 TOO_SHORT | OVERLONG_3 | SURROGATE,
291 // 1111____ ________ <four+ byte lead in byte 1>
292 TOO_SHORT | TOO_LARGE | TOO_LARGE_1000 | OVERLONG_4,
293 };
294 __m256i byte_1_high = simd256_lookup16(prev1_shr4, tab1);
295
296
297 const uint8_t CARRY = TOO_SHORT | TOO_LONG | TWO_CONTS; // These all have ____ in byte 1 .
298 __m256i prev1_low = _mm256_and_si256(prev1, _mm256_set1_epi8(0x0F));
299 static const uint8_t tab2[16] = {
300 // ____0000 ________
301 CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4,
302 // ____0001 ________
303 CARRY | OVERLONG_2,
304 // ____001_ ________
305 CARRY,
306 CARRY,
307
308 // ____0100 ________
309 CARRY | TOO_LARGE,
310 // ____0101 ________
311 CARRY | TOO_LARGE | TOO_LARGE_1000,
312 // ____011_ ________
313 CARRY | TOO_LARGE | TOO_LARGE_1000,
314 CARRY | TOO_LARGE | TOO_LARGE_1000,
315
316 // ____1___ ________
317 CARRY | TOO_LARGE | TOO_LARGE_1000,
318 CARRY | TOO_LARGE | TOO_LARGE_1000,
319 CARRY | TOO_LARGE | TOO_LARGE_1000,
320 CARRY | TOO_LARGE | TOO_LARGE_1000,
321 CARRY | TOO_LARGE | TOO_LARGE_1000,
322 // ____1101 ________
323 CARRY | TOO_LARGE | TOO_LARGE_1000 | SURROGATE,
324 CARRY | TOO_LARGE | TOO_LARGE_1000,
325 CARRY | TOO_LARGE | TOO_LARGE_1000
326 };
327 __m256i byte_1_low = simd256_lookup16(prev1_low, tab2);
328
329
330 const __m256i input_shr4 = simd256_shr(input, 4);
331 static const uint8_t tab3[16] = {
332 // ________ 0_______ <ASCII in byte 2>
333 TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT,
334 TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT,
335
336 // ________ 1000____
337 TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE_1000 | OVERLONG_4,
338 // ________ 1001____
339 TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE,
340 // ________ 101_____
341 TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE | TOO_LARGE,
342 TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE | TOO_LARGE,
343
344 // ________ 11______
345 TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT
346 };
347 __m256i byte_2_high = simd256_lookup16(input_shr4, tab3);
348
349
350 return _mm256_and_si256(_mm256_and_si256(byte_1_high, byte_1_low), byte_2_high);
351 }
352
353 static always_inline __m256i check_multibyte_lengths(const __m256i input, const __m256i prev_input, const __m256i sc) {
354 __m256i prev2 = simd256_prev(input, prev_input, 2);
355 __m256i prev3 = simd256_prev(input, prev_input, 3);
356
357
358 __m256i must23 = must_be_2_3_continuation(prev2, prev3);
359
360 __m256i must23_80 = _mm256_and_si256(must23, _mm256_set1_epi8(0x80));
361
362 return _mm256_xor_si256(must23_80, sc);
363 }
364
365
366 // Check whether the current bytes are valid UTF-8.
367 static always_inline __m256i check_utf8_bytes(const __m256i input, const __m256i prev_input) {
368 // Flip prev1...prev3 so we can easily determine if they are 2+, 3+ or 4+ lead bytes
369 // (2, 3, 4-byte leads become large positive numbers instead of small negative numbers)
370 __m256i prev1 = simd256_prev(input, prev_input, 1);
371 __m256i sc = check_special_cases(input, prev1);
372 __m256i ret = check_multibyte_lengths(input, prev_input, sc);
373 return ret;
374 }
375
376 static always_inline bool is_ascii(const __m256i input) {
377 return _mm256_movemask_epi8(input) == 0;
378 }
379
/* State carried across blocks by the AVX2 UTF-8 validator. */
typedef struct {
    // Accumulated (OR-ed) error bits; non-zero once any UTF-8 error has been seen.
    __m256i error;
    // The last 32-byte block that went through the full (non-ASCII) check. Its
    // tail supplies the bytes preceding the next checked block.
    __m256i prev_input_block;
    // is_incomplete() of prev_input_block: a vector that is non-zero when that
    // block ended with a truncated multi-byte sequence. It is OR-ed into `error`
    // when an all-ASCII chunk or the end of input follows.
    __m256i prev_incomplete;
} utf8_checker;
388
389 static always_inline void utf8_checker_init(utf8_checker* checker) {
390 checker->error = _mm256_setzero_si256();
391 checker->prev_input_block = _mm256_setzero_si256();
392 checker->prev_incomplete = _mm256_setzero_si256();
393 }
394
395 static always_inline bool check_error(utf8_checker* checker) {
396 return !_mm256_testz_si256(checker->error, checker->error);
397 }
398
399 static always_inline void check64_utf(utf8_checker* checker, const uint8_t* start) {
400 __m256i input = _mm256_loadu_si256((__m256i*)start);
401 __m256i input2 = _mm256_loadu_si256((__m256i*)(start + 32));
402 // check utf-8 chars
403 __m256i error1 = check_utf8_bytes(input, checker->prev_input_block);
404 __m256i error2 = check_utf8_bytes(input2, input);
405 checker->error = _mm256_or_si256(checker->error, _mm256_or_si256(error1, error2));
406 checker->prev_input_block = input2;
407 checker->prev_incomplete = is_incomplete(input2);
408 }
409
410 static always_inline void check64(utf8_checker* checker, const uint8_t* start) {
411 // fast path for contiguous ASCII
412 __m256i input = _mm256_loadu_si256((__m256i*)start);
413 __m256i input2 = _mm256_loadu_si256((__m256i*)(start + 32));
414 __m256i reducer = _mm256_or_si256(input, input2);
415 // check utf-8
416 if (likely(is_ascii(reducer))) {
417 checker->error = _mm256_or_si256(checker->error, checker->prev_incomplete);
418 return;
419 }
420 check64_utf(checker, start);
421 }
422
423 static always_inline void check128(utf8_checker* checker, const uint8_t* start) {
424 // fast path for contiguous ASCII
425 __m256i input = _mm256_loadu_si256((__m256i*)start);
426 __m256i input2 = _mm256_loadu_si256((__m256i*)(start + 32));
427 __m256i input3 = _mm256_loadu_si256((__m256i*)(start + 64));
428 __m256i input4 = _mm256_loadu_si256((__m256i*)(start + 96));
429
430 __m256i reducer1 = _mm256_or_si256(input, input2);
431 __m256i reducer2 = _mm256_or_si256(input3, input4);
432 __m256i reducer = _mm256_or_si256(reducer1, reducer2);
433
434 // full 128 bytes are ascii
435 if (likely(is_ascii(reducer))) {
436 checker->error = _mm256_or_si256(checker->error, checker->prev_incomplete);
437 return;
438 }
439
440 // frist 64 bytes is ascii, next 64 bytes must be utf8
441 if (likely(is_ascii(reducer1))) {
442 checker->error = _mm256_or_si256(checker->error, checker->prev_incomplete);
443 check64_utf(checker, start + 64);
444 return;
445 }
446
447 // frist 64 bytes has utf8, next 64 bytes
448 check64_utf(checker, start);
449 if (unlikely(is_ascii(reducer2))) {
450 checker->error = _mm256_or_si256(checker->error, checker->prev_incomplete);
451 } else {
452 check64_utf(checker, start + 64);
453 }
454 }
455
456 static always_inline void check_eof(utf8_checker* checker) {
457 checker->error = _mm256_or_si256(checker->error, checker->prev_incomplete);
458 }
459
460 static always_inline void check_remain(utf8_checker* checker, const uint8_t* start, const uint8_t* end) {
461 uint8_t buffer[64] = {0};
462 int i = 0;
463 while (start < end) {
464 buffer[i++] = *(start++);
465 };
466 check64(checker, buffer);
467 check_eof(checker);
468 }
469
470 static always_inline long validate_utf8_avx2(const GoString* s) {
471 xassert(s->buf != NULL || s->len != 0);
472 const uint8_t* start = (const uint8_t*)(s->buf);
473 const uint8_t* end = (const uint8_t*)(s->buf + s->len);
474 /* check eof */
475 if (s->len == 0) {
476 return 0;
477 }
478 utf8_checker checker;
479 utf8_checker_init(&checker);
480 while (start < (end - 128)) {
481 check128(&checker, start);
482 if (check_error(&checker)) {
483 }
484 start += 128;
485 };
486 while (start < end - 64) {
487 check64(&checker, start);
488 start += 64;
489 }
490 check_remain(&checker, start, end);
491 return check_error(&checker) ? -1 : 0;
492 }
493#endif
View as plain text