...

Text file src/github.com/bytedance/sonic/native/utf8.h

Documentation: github.com/bytedance/sonic/native

     1/*
     2 * Copyright (C) 2019 Yaoyuan <ibireme@gmail.com>.
     3 *
     4 * Licensed under the Apache License, Version 2.0 (the "License");
     5 * you may not use this file except in compliance with the License.
     6 * You may obtain a copy of the License at
     7 *
     8 *     http://www.apache.org/licenses/LICENSE-2.0
     9 *
    10 * Unless required by applicable law or agreed to in writing, software
    11 * distributed under the License is distributed on an "AS IS" BASIS,
    12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13 * See the License for the specific language governing permissions and
    14 * limitations under the License.
    15 *
    16 * Copyright 2018-2023 The simdjson authors
    17 *
    18 * Licensed under the Apache License, Version 2.0 (the "License");
    19 * you may not use this file except in compliance with the License.
    20 * You may obtain a copy of the License at
    21
    22 *     http://www.apache.org/licenses/LICENSE-2.0
    23
    24 * Unless required by applicable law or agreed to in writing, software
    25 * distributed under the License is distributed on an "AS IS" BASIS,
    26 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    27 * See the License for the specific language governing permissions and
    28 * limitations under the License.
    29 * 
    30 * This file may have been modified by ByteDance authors. All ByteDance
    31 * Modifications are Copyright 2022 ByteDance Authors.
    32 */
    33
    34#pragma once
    35
    36#include "native.h"
    37#include "utils.h"
    38#include "test/xassert.h"
    39#include "test/xprintf.h"
    40
    41static inline ssize_t valid_utf8_4byte(uint32_t ubin) {
    42    /*
    43     Each unicode code point is encoded as 1 to 4 bytes in UTF-8 encoding,
    44     we use 4-byte mask and pattern value to validate UTF-8 byte sequence,
    45     this requires the input data to have 4-byte zero padding.
    46     ---------------------------------------------------
    47     1 byte
    48     unicode range [U+0000, U+007F]
    49     unicode min   [.......0]
    50     unicode max   [.1111111]
    51     bit pattern   [0.......]
    52     ---------------------------------------------------
    53     2 byte
    54     unicode range [U+0080, U+07FF]
    55     unicode min   [......10 ..000000]
    56     unicode max   [...11111 ..111111]
    57     bit require   [...xxxx. ........] (1E 00)
    58     bit mask      [xxx..... xx......] (E0 C0)
    59     bit pattern   [110..... 10......] (C0 80)
    60     // 1101 0100 10110000
    61     // 0001 1110
    62     ---------------------------------------------------
    63     3 byte
    64     unicode range [U+0800, U+FFFF]
    65     unicode min   [........ ..100000 ..000000]
    66     unicode max   [....1111 ..111111 ..111111]
    67     bit require   [....xxxx ..x..... ........] (0F 20 00)
    68     bit mask      [xxxx.... xx...... xx......] (F0 C0 C0)
    69     bit pattern   [1110.... 10...... 10......] (E0 80 80)
    70     ---------------------------------------------------
    71     3 byte invalid (reserved for surrogate halves)
    72     unicode range [U+D800, U+DFFF]
    73     unicode min   [....1101 ..100000 ..000000]
    74     unicode max   [....1101 ..111111 ..111111]
    75     bit mask      [....xxxx ..x..... ........] (0F 20 00)
    76     bit pattern   [....1101 ..1..... ........] (0D 20 00)
    77     ---------------------------------------------------
    78     4 byte
    79     unicode range [U+10000, U+10FFFF]
    80     unicode min   [........ ...10000 ..000000 ..000000]
    81     unicode max   [.....100 ..001111 ..111111 ..111111]
    82     bit err0      [.....100 ........ ........ ........] (04 00 00 00)
    83     bit err1      [.....011 ..110000 ........ ........] (03 30 00 00)
    84     bit require   [.....xxx ..xx.... ........ ........] (07 30 00 00)
    85     bit mask      [xxxxx... xx...... xx...... xx......] (F8 C0 C0 C0)
    86     bit pattern   [11110... 10...... 10...... 10......] (F0 80 80 80)
    87     ---------------------------------------------------
    88     */
    89    const uint32_t b2_mask = 0x0000C0E0UL;
    90    const uint32_t b2_patt = 0x000080C0UL;
    91    const uint32_t b2_requ = 0x0000001EUL;
    92    const uint32_t b3_mask = 0x00C0C0F0UL;
    93    const uint32_t b3_patt = 0x008080E0UL;
    94    const uint32_t b3_requ = 0x0000200FUL;
    95    const uint32_t b3_erro = 0x0000200DUL;
    96    const uint32_t b4_mask = 0xC0C0C0F8UL;
    97    const uint32_t b4_patt = 0x808080F0UL;
    98    const uint32_t b4_requ = 0x00003007UL;
    99    const uint32_t b4_err0 = 0x00000004UL;
   100    const uint32_t b4_err1 = 0x00003003UL;
   101
   102#define is_valid_seq_2(uni) ( \
   103    ((uni & b2_mask) == b2_patt) && \
   104    ((uni & b2_requ)) \
   105)
   106    
   107#define is_valid_seq_3(uni) ( \
   108    ((uni & b3_mask) == b3_patt) && \
   109    ((tmp = (uni & b3_requ))) && \
   110    ((tmp != b3_erro)) \
   111)
   112    
   113#define is_valid_seq_4(uni) ( \
   114    ((uni & b4_mask) == b4_patt) && \
   115    ((tmp = (uni & b4_requ))) && \
   116    ((tmp & b4_err0) == 0 || (tmp & b4_err1) == 0) \
   117)
   118    uint32_t tmp = 0;
   119   
   120    if (is_valid_seq_3(ubin)) return 3;
   121    if (is_valid_seq_2(ubin)) return 2;
   122    if (is_valid_seq_4(ubin)) return 4;
   123    return 0;
   124}
   125
   126static always_inline long write_error(int pos, StateMachine *m, size_t msize) {
   127    if (m->sp >= msize) {
   128        return -1;
   129    }
   130    m->vt[m->sp++] = pos;
   131    return 0;
   132}
   133
   134// scalar code, error position should excesss 4096
   135static always_inline long validate_utf8_with_errors(const char *src, long len, long *p, StateMachine *m) {
   136    const char* start = src + *p;
   137    const char* end = src + len;
   138    while (start < end - 3) {
   139        uint32_t u = (*(uint32_t*)(start));
   140        if ((unsigned)(*start) < 0x80) {
   141            start += 1;
   142            continue;
   143        }
   144        size_t n = valid_utf8_4byte(u);
   145        if (n != 0) { // valid utf
   146            start += n;
   147            continue;
   148        }
   149        long err = write_error(start - src, m, MAX_RECURSE);
   150        if (err) {
   151            *p = start - src;
   152            return err;
   153        }
   154        start += 1;
   155    }
   156    while (start < end) {
   157        if ((unsigned)(*start) < 0x80) {
   158            start += 1;
   159            continue;
   160        }
   161        uint32_t u = 0;
   162        memcpy_p4(&u, start, end - start);
   163        size_t n = valid_utf8_4byte(u);
   164        if (n != 0) { // valid utf
   165            start += n;
   166            continue;
   167        }
   168        long err = write_error(start - src, m, MAX_RECURSE);
   169        if (err) {
   170            *p = start - src;
   171            return err;
   172        }
   173        start += 1;
   174    }
   175    *p = start - src;
   176    return 0;
   177}
   178
   179// validate_utf8_errors returns zero if valid, otherwise, the error position.
   180static always_inline long validate_utf8_errors(const GoString* s) {
   181    const char* start = s->buf;
   182    const char* end = s->buf + s->len;
   183    while (start < end - 3) {
   184        uint32_t u = (*(uint32_t*)(start));
   185        if ((unsigned)(*start) < 0x80) {
   186            start += 1;
   187            continue;
   188        }
   189        size_t n = valid_utf8_4byte(u);
   190        if (n == 0) { // invalid utf
   191            return -(start - s->buf) - 1;
   192        }
   193        start += n;
   194    }
   195    while (start < end) {
   196        if ((unsigned)(*start) < 0x80) {
   197            start += 1;
   198            continue;
   199        }
   200        uint32_t u = 0;
   201        memcpy_p4(&u, start, end - start);
   202        size_t n = valid_utf8_4byte(u);
   203        if (n == 0) { // invalid utf
   204            return -(start - s->buf) - 1;
   205        }
   206        start += n;
   207    }
   208    return 0;
   209}
   210
   211// SIMD implementation
   212#if USE_AVX2
   213
   214    static always_inline __m256i simd256_shr(const __m256i input, const int shift) {
   215        __m256i shifted = _mm256_srli_epi16(input, shift);
   216        __m256i mask = _mm256_set1_epi8(0xFFu >> shift);
   217        return _mm256_and_si256(shifted, mask);
   218    }
   219
   220#define simd256_prev(input, prev, N) _mm256_alignr_epi8(input, _mm256_permute2x128_si256(prev, input, 0x21), 16 - (N));
   221
   222    static always_inline __m256i must_be_2_3_continuation(const __m256i prev2, const __m256i prev3) {
   223        __m256i is_third_byte  = _mm256_subs_epu8(prev2, _mm256_set1_epi8(0b11100000u-1)); // Only 111_____ will be > 0
   224        __m256i is_fourth_byte = _mm256_subs_epu8(prev3, _mm256_set1_epi8(0b11110000u-1)); // Only 1111____ will be > 0
   225        // Caller requires a bool (all 1's). All values resulting from the subtraction will be <= 64, so signed comparison is fine.
   226        __m256i or = _mm256_or_si256(is_third_byte, is_fourth_byte);
   227        return _mm256_cmpgt_epi8(or, _mm256_set1_epi8(0));;
   228    }
   229
   230    static always_inline __m256i simd256_lookup16(const __m256i input, const uint8_t* table) {
   231        return _mm256_shuffle_epi8(_mm256_setr_epi8(table[0], table[1], table[2], table[3], table[4], table[5], table[6], table[7], table[8], table[9], table[10], table[11], table[12], table[13], table[14], table[15], table[0], table[1], table[2], table[3], table[4], table[5], table[6], table[7], table[8], table[9], table[10], table[11], table[12], table[13], table[14], table[15]), input);
   232    }
   233
   234  //
   235  // Return nonzero if there are incomplete multibyte characters at the end of the block:
   236  // e.g. if there is a 4-byte character, but it's 3 bytes from the end.
   237  //
   238      static always_inline  __m256i is_incomplete(const __m256i input) {
   239    // If the previous input's last 3 bytes match this, they're too short (they ended at EOF):
   240    // ... 1111____ 111_____ 11______
   241      const uint8_t tab[32] = {
   242      255, 255, 255, 255, 255, 255, 255, 255,
   243      255, 255, 255, 255, 255, 255, 255, 255,
   244      255, 255, 255, 255, 255, 255, 255, 255,
   245      255, 255, 255, 255, 255, 0b11110000u-1, 0b11100000u-1, 0b11000000u-1};
   246        const __m256i max_value = _mm256_loadu_si256((const __m256i_u *)(&tab[0]));
   247        return _mm256_subs_epu8(input, max_value);
   248    }
   249
   250  static always_inline __m256i check_special_cases(const __m256i input, const __m256i prev1) {
   251    // Bit 0 = Too Short (lead byte/ASCII followed by lead byte/ASCII)
   252    // Bit 1 = Too Long (ASCII followed by continuation)
   253    // Bit 2 = Overlong 3-byte
   254    // Bit 4 = Surrogate
   255    // Bit 5 = Overlong 2-byte
   256    // Bit 7 = Two Continuations
   257     const uint8_t TOO_SHORT   = 1<<0; // 11______ 0_______
   258                                                // 11______ 11______
   259     const uint8_t TOO_LONG    = 1<<1; // 0_______ 10______
   260     const uint8_t OVERLONG_3  = 1<<2; // 11100000 100_____
   261     const uint8_t SURROGATE   = 1<<4; // 11101101 101_____
   262     const uint8_t OVERLONG_2  = 1<<5; // 1100000_ 10______
   263     const uint8_t TWO_CONTS   = 1<<7; // 10______ 10______
   264     const uint8_t TOO_LARGE   = 1<<3; // 11110100 1001____
   265                                                // 11110100 101_____
   266                                                // 11110101 1001____
   267                                                // 11110101 101_____
   268                                                // 1111011_ 1001____
   269                                                // 1111011_ 101_____
   270                                                // 11111___ 1001____
   271                                                // 11111___ 101_____
   272     const uint8_t TOO_LARGE_1000 = 1<<6;
   273                                                // 11110101 1000____
   274                                                // 1111011_ 1000____
   275                                                // 11111___ 1000____
   276     const uint8_t OVERLONG_4  = 1<<6; // 11110000 1000____
   277
   278    const __m256i prev1_shr4 = simd256_shr(prev1, 4);
   279    static const uint8_t tab1[16] = {
   280              // 0_______ ________ <ASCII in byte 1>
   281      TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG,
   282      TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG,
   283      // 10______ ________ <continuation in byte 1>
   284      TWO_CONTS, TWO_CONTS, TWO_CONTS, TWO_CONTS,
   285      // 1100____ ________ <two byte lead in byte 1>
   286      TOO_SHORT | OVERLONG_2,
   287      // 1101____ ________ <two byte lead in byte 1>
   288      TOO_SHORT,
   289      // 1110____ ________ <three byte lead in byte 1>
   290      TOO_SHORT | OVERLONG_3 | SURROGATE,
   291      // 1111____ ________ <four+ byte lead in byte 1>
   292      TOO_SHORT | TOO_LARGE | TOO_LARGE_1000 | OVERLONG_4,
   293    };
   294    __m256i byte_1_high = simd256_lookup16(prev1_shr4, tab1);
   295    
   296
   297    const uint8_t CARRY = TOO_SHORT | TOO_LONG | TWO_CONTS; // These all have ____ in byte 1 .
   298    __m256i prev1_low = _mm256_and_si256(prev1, _mm256_set1_epi8(0x0F));
   299    static const uint8_t tab2[16] = {
   300      // ____0000 ________
   301      CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4,
   302      // ____0001 ________
   303      CARRY | OVERLONG_2,
   304      // ____001_ ________
   305      CARRY,
   306      CARRY,
   307
   308      // ____0100 ________
   309      CARRY | TOO_LARGE,
   310      // ____0101 ________
   311      CARRY | TOO_LARGE | TOO_LARGE_1000,
   312      // ____011_ ________
   313      CARRY | TOO_LARGE | TOO_LARGE_1000,
   314      CARRY | TOO_LARGE | TOO_LARGE_1000,
   315
   316      // ____1___ ________
   317      CARRY | TOO_LARGE | TOO_LARGE_1000,
   318      CARRY | TOO_LARGE | TOO_LARGE_1000,
   319      CARRY | TOO_LARGE | TOO_LARGE_1000,
   320      CARRY | TOO_LARGE | TOO_LARGE_1000,
   321      CARRY | TOO_LARGE | TOO_LARGE_1000,
   322      // ____1101 ________
   323      CARRY | TOO_LARGE | TOO_LARGE_1000 | SURROGATE,
   324      CARRY | TOO_LARGE | TOO_LARGE_1000,
   325      CARRY | TOO_LARGE | TOO_LARGE_1000
   326    };
   327    __m256i byte_1_low = simd256_lookup16(prev1_low, tab2);
   328    
   329
   330    const __m256i input_shr4 = simd256_shr(input, 4);
   331    static const uint8_t tab3[16] = {
   332      // ________ 0_______ <ASCII in byte 2>
   333      TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT,
   334      TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT,
   335
   336      // ________ 1000____
   337      TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE_1000 | OVERLONG_4,
   338      // ________ 1001____
   339      TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE,
   340      // ________ 101_____
   341      TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE  | TOO_LARGE,
   342      TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE  | TOO_LARGE,
   343
   344      // ________ 11______
   345      TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT
   346    };
   347    __m256i byte_2_high = simd256_lookup16(input_shr4, tab3);
   348     
   349
   350    return _mm256_and_si256(_mm256_and_si256(byte_1_high, byte_1_low), byte_2_high);
   351  }
   352
   353    static always_inline __m256i check_multibyte_lengths(const __m256i input, const __m256i prev_input, const __m256i sc) {
   354    __m256i prev2 = simd256_prev(input, prev_input, 2);
   355    __m256i prev3 = simd256_prev(input, prev_input, 3);
   356    
   357    
   358    __m256i must23 = must_be_2_3_continuation(prev2, prev3);
   359    
   360    __m256i must23_80 = _mm256_and_si256(must23, _mm256_set1_epi8(0x80));
   361    
   362    return _mm256_xor_si256(must23_80, sc);
   363  }
   364
   365
   366    // Check whether the current bytes are valid UTF-8.
   367    static always_inline __m256i check_utf8_bytes(const __m256i input, const __m256i prev_input) {
   368        // Flip prev1...prev3 so we can easily determine if they are 2+, 3+ or 4+ lead bytes
   369        // (2, 3, 4-byte leads become large positive numbers instead of small negative numbers)
   370        __m256i prev1 = simd256_prev(input, prev_input, 1);
   371        __m256i sc    = check_special_cases(input, prev1);
   372        __m256i ret  = check_multibyte_lengths(input, prev_input, sc);
   373        return ret;
   374    }
   375
   376    static always_inline bool is_ascii(const __m256i input) {
   377      return _mm256_movemask_epi8(input) == 0;
   378    }
   379
   380    typedef struct {
   381        // If this is nonzero, there has been a UTF-8 error.
   382        __m256i error;
   383        // The last input we received
   384        __m256i prev_input_block;
   385        // Whether the last input we received was incomplete (used for ASCII fast path)
   386        __m256i prev_incomplete;
   387    } utf8_checker;
   388
   389    static always_inline void utf8_checker_init(utf8_checker* checker) {
   390        checker->error = _mm256_setzero_si256();
   391        checker->prev_input_block = _mm256_setzero_si256();
   392        checker->prev_incomplete = _mm256_setzero_si256();
   393    }
   394    
   395    static always_inline bool check_error(utf8_checker* checker) {
   396        return !_mm256_testz_si256(checker->error, checker->error);
   397    }
   398
   399    static always_inline void check64_utf(utf8_checker* checker, const uint8_t* start) {
   400        __m256i input = _mm256_loadu_si256((__m256i*)start);
   401        __m256i input2 = _mm256_loadu_si256((__m256i*)(start + 32));
   402        // check utf-8 chars
   403        __m256i error1 = check_utf8_bytes(input, checker->prev_input_block);
   404        __m256i error2 = check_utf8_bytes(input2, input);
   405        checker->error = _mm256_or_si256(checker->error, _mm256_or_si256(error1, error2));
   406        checker->prev_input_block = input2;
   407        checker->prev_incomplete = is_incomplete(input2);
   408    }
   409
   410    static always_inline void check64(utf8_checker* checker, const uint8_t* start) {
   411        // fast path for contiguous ASCII
   412        __m256i input = _mm256_loadu_si256((__m256i*)start);
   413        __m256i input2 = _mm256_loadu_si256((__m256i*)(start + 32));
   414        __m256i reducer = _mm256_or_si256(input, input2);
   415        // check utf-8
   416        if (likely(is_ascii(reducer))) {
   417            checker->error = _mm256_or_si256(checker->error, checker->prev_incomplete);
   418            return;
   419        }
   420        check64_utf(checker, start);
   421    }
   422
   423    static always_inline void check128(utf8_checker* checker, const uint8_t* start) {
   424        // fast path for contiguous ASCII
   425        __m256i input = _mm256_loadu_si256((__m256i*)start);
   426        __m256i input2 = _mm256_loadu_si256((__m256i*)(start + 32));
   427        __m256i input3 = _mm256_loadu_si256((__m256i*)(start + 64));
   428        __m256i input4 = _mm256_loadu_si256((__m256i*)(start + 96));
   429        
   430        __m256i reducer1 = _mm256_or_si256(input, input2);
   431        __m256i reducer2 = _mm256_or_si256(input3, input4);
   432        __m256i reducer  = _mm256_or_si256(reducer1, reducer2);
   433
   434        // full 128 bytes are ascii
   435        if (likely(is_ascii(reducer))) {
   436            checker->error = _mm256_or_si256(checker->error, checker->prev_incomplete);
   437            return;
   438        }
   439
   440        // frist 64 bytes is ascii, next 64 bytes must be utf8
   441        if (likely(is_ascii(reducer1))) {
   442            checker->error = _mm256_or_si256(checker->error, checker->prev_incomplete);
   443            check64_utf(checker, start + 64);
   444            return;
   445        }
   446
   447        // frist 64 bytes has utf8, next 64 bytes 
   448        check64_utf(checker, start);
   449        if (unlikely(is_ascii(reducer2))) {
   450            checker->error = _mm256_or_si256(checker->error, checker->prev_incomplete);
   451        } else {
   452            check64_utf(checker, start + 64);
   453        }
   454    }
   455
   456    static always_inline void check_eof(utf8_checker* checker) {
   457        checker->error = _mm256_or_si256(checker->error, checker->prev_incomplete);
   458    }
   459
   460    static always_inline void check_remain(utf8_checker* checker, const uint8_t* start, const uint8_t* end) {
   461        uint8_t buffer[64] = {0};
   462        int i = 0;
   463        while (start < end) {
   464            buffer[i++] = *(start++);
   465        };
   466        check64(checker, buffer);
   467        check_eof(checker);
   468    }
   469
   470    static always_inline long validate_utf8_avx2(const GoString* s) {
   471        xassert(s->buf != NULL || s->len != 0);
   472        const uint8_t* start = (const uint8_t*)(s->buf);
   473        const uint8_t* end   = (const uint8_t*)(s->buf + s->len);
   474        /* check eof */
   475        if (s->len == 0) {
   476            return 0;
   477        }
   478        utf8_checker checker;
   479        utf8_checker_init(&checker);
   480        while (start < (end - 128)) {
   481            check128(&checker, start);
   482            if (check_error(&checker)) {
   483            }
   484            start += 128;
   485        };
   486        while (start < end - 64) {
   487            check64(&checker, start);
   488            start += 64;
   489        }
   490        check_remain(&checker, start, end);
   491        return check_error(&checker) ? -1 : 0;
   492    }
   493#endif

View as plain text