1
2
3
4
5 package bn256
6
7 import (
8 "math/big"
9 )
10
11
12
13
// curvePoint is a point on the curve y² = x³ + curveB (the equation checked by
// IsOnCurve), held in Jacobian coordinates: the affine point is (x/z², y/z³)
// (see MakeAffine), and z == 0 denotes the point at infinity (see IsInfinity).
//
// Coordinates are reduced lazily: Add and Double may leave x and y negative or
// outside [0, p).
//
// t is an auxiliary coordinate. MakeAffine sets it to 1 (0 at infinity),
// Negative zeroes it, and Add/Double do not update it.
// NOTE(review): the intended invariant for t is not visible here — confirm
// before relying on it.
type curvePoint struct {
	x *big.Int
	y *big.Int
	z *big.Int
	t *big.Int
}
17
18 var curveB = new(big.Int).SetInt64(3)
19
20
21 var curveGen = &curvePoint{
22 new(big.Int).SetInt64(1),
23 new(big.Int).SetInt64(-2),
24 new(big.Int).SetInt64(1),
25 new(big.Int).SetInt64(1),
26 }
27
28 func newCurvePoint(pool *bnPool) *curvePoint {
29 return &curvePoint{
30 pool.Get(),
31 pool.Get(),
32 pool.Get(),
33 pool.Get(),
34 }
35 }
36
37 func (c *curvePoint) String() string {
38 c.MakeAffine(new(bnPool))
39 return "(" + c.x.String() + ", " + c.y.String() + ")"
40 }
41
42 func (c *curvePoint) Put(pool *bnPool) {
43 pool.Put(c.x)
44 pool.Put(c.y)
45 pool.Put(c.z)
46 pool.Put(c.t)
47 }
48
49 func (c *curvePoint) Set(a *curvePoint) {
50 c.x.Set(a.x)
51 c.y.Set(a.y)
52 c.z.Set(a.z)
53 c.t.Set(a.t)
54 }
55
56
57 func (c *curvePoint) IsOnCurve() bool {
58 yy := new(big.Int).Mul(c.y, c.y)
59 xxx := new(big.Int).Mul(c.x, c.x)
60 xxx.Mul(xxx, c.x)
61 yy.Sub(yy, xxx)
62 yy.Sub(yy, curveB)
63 if yy.Sign() < 0 || yy.Cmp(p) >= 0 {
64 yy.Mod(yy, p)
65 }
66 return yy.Sign() == 0
67 }
68
69 func (c *curvePoint) SetInfinity() {
70 c.z.SetInt64(0)
71 }
72
73 func (c *curvePoint) IsInfinity() bool {
74 return c.z.Sign() == 0
75 }
76
77 func (c *curvePoint) Add(a, b *curvePoint, pool *bnPool) {
78 if a.IsInfinity() {
79 c.Set(b)
80 return
81 }
82 if b.IsInfinity() {
83 c.Set(a)
84 return
85 }
86
87
88
89
90
91
92 z1z1 := pool.Get().Mul(a.z, a.z)
93 z1z1.Mod(z1z1, p)
94 z2z2 := pool.Get().Mul(b.z, b.z)
95 z2z2.Mod(z2z2, p)
96 u1 := pool.Get().Mul(a.x, z2z2)
97 u1.Mod(u1, p)
98 u2 := pool.Get().Mul(b.x, z1z1)
99 u2.Mod(u2, p)
100
101 t := pool.Get().Mul(b.z, z2z2)
102 t.Mod(t, p)
103 s1 := pool.Get().Mul(a.y, t)
104 s1.Mod(s1, p)
105
106 t.Mul(a.z, z1z1)
107 t.Mod(t, p)
108 s2 := pool.Get().Mul(b.y, t)
109 s2.Mod(s2, p)
110
111
112
113
114
115
116
117
118 h := pool.Get().Sub(u2, u1)
119 xEqual := h.Sign() == 0
120
121 t.Add(h, h)
122
123 i := pool.Get().Mul(t, t)
124 i.Mod(i, p)
125
126 j := pool.Get().Mul(h, i)
127 j.Mod(j, p)
128
129 t.Sub(s2, s1)
130 yEqual := t.Sign() == 0
131 if xEqual && yEqual {
132 c.Double(a, pool)
133 return
134 }
135 r := pool.Get().Add(t, t)
136
137 v := pool.Get().Mul(u1, i)
138 v.Mod(v, p)
139
140
141 t4 := pool.Get().Mul(r, r)
142 t4.Mod(t4, p)
143 t.Add(v, v)
144 t6 := pool.Get().Sub(t4, j)
145 c.x.Sub(t6, t)
146
147
148
149
150 t.Sub(v, c.x)
151 t4.Mul(s1, j)
152 t4.Mod(t4, p)
153 t6.Add(t4, t4)
154 t4.Mul(r, t)
155 t4.Mod(t4, p)
156 c.y.Sub(t4, t6)
157
158
159 t.Add(a.z, b.z)
160 t4.Mul(t, t)
161 t4.Mod(t4, p)
162 t.Sub(t4, z1z1)
163 t4.Sub(t, z2z2)
164 c.z.Mul(t4, h)
165 c.z.Mod(c.z, p)
166
167 pool.Put(z1z1)
168 pool.Put(z2z2)
169 pool.Put(u1)
170 pool.Put(u2)
171 pool.Put(t)
172 pool.Put(s1)
173 pool.Put(s2)
174 pool.Put(h)
175 pool.Put(i)
176 pool.Put(j)
177 pool.Put(r)
178 pool.Put(v)
179 pool.Put(t4)
180 pool.Put(t6)
181 }
182
183 func (c *curvePoint) Double(a *curvePoint, pool *bnPool) {
184
185 A := pool.Get().Mul(a.x, a.x)
186 A.Mod(A, p)
187 B := pool.Get().Mul(a.y, a.y)
188 B.Mod(B, p)
189 C := pool.Get().Mul(B, B)
190 C.Mod(C, p)
191
192 t := pool.Get().Add(a.x, B)
193 t2 := pool.Get().Mul(t, t)
194 t2.Mod(t2, p)
195 t.Sub(t2, A)
196 t2.Sub(t, C)
197 d := pool.Get().Add(t2, t2)
198 t.Add(A, A)
199 e := pool.Get().Add(t, A)
200 f := pool.Get().Mul(e, e)
201 f.Mod(f, p)
202
203 t.Add(d, d)
204 c.x.Sub(f, t)
205
206 t.Add(C, C)
207 t2.Add(t, t)
208 t.Add(t2, t2)
209 c.y.Sub(d, c.x)
210 t2.Mul(e, c.y)
211 t2.Mod(t2, p)
212 c.y.Sub(t2, t)
213
214 t.Mul(a.y, a.z)
215 t.Mod(t, p)
216 c.z.Add(t, t)
217
218 pool.Put(A)
219 pool.Put(B)
220 pool.Put(C)
221 pool.Put(t)
222 pool.Put(t2)
223 pool.Put(d)
224 pool.Put(e)
225 pool.Put(f)
226 }
227
228 func (c *curvePoint) Mul(a *curvePoint, scalar *big.Int, pool *bnPool) *curvePoint {
229 sum := newCurvePoint(pool)
230 sum.SetInfinity()
231 t := newCurvePoint(pool)
232
233 for i := scalar.BitLen(); i >= 0; i-- {
234 t.Double(sum, pool)
235 if scalar.Bit(i) != 0 {
236 sum.Add(t, a, pool)
237 } else {
238 sum.Set(t)
239 }
240 }
241
242 c.Set(sum)
243 sum.Put(pool)
244 t.Put(pool)
245 return c
246 }
247
248
249
250 func (c *curvePoint) MakeAffine(pool *bnPool) *curvePoint {
251 if words := c.z.Bits(); len(words) == 1 && words[0] == 1 {
252 return c
253 }
254 if c.IsInfinity() {
255 c.x.SetInt64(0)
256 c.y.SetInt64(1)
257 c.z.SetInt64(0)
258 c.t.SetInt64(0)
259 return c
260 }
261
262 zInv := pool.Get().ModInverse(c.z, p)
263 t := pool.Get().Mul(c.y, zInv)
264 t.Mod(t, p)
265 zInv2 := pool.Get().Mul(zInv, zInv)
266 zInv2.Mod(zInv2, p)
267 c.y.Mul(t, zInv2)
268 c.y.Mod(c.y, p)
269 t.Mul(c.x, zInv2)
270 t.Mod(t, p)
271 c.x.Set(t)
272 c.z.SetInt64(1)
273 c.t.SetInt64(1)
274
275 pool.Put(zInv)
276 pool.Put(t)
277 pool.Put(zInv2)
278
279 return c
280 }
281
282 func (c *curvePoint) Negative(a *curvePoint) {
283 c.x.Set(a.x)
284 c.y.Neg(a.y)
285 c.z.Set(a.z)
286 c.t.SetInt64(0)
287 }
288
View as plain text