Ticket #601: jmp.c

File jmp.c, 8.5 KB (added by JimCrayne, 7 years ago)

partial re-implementation of gmp

Line 
1#define SPOOF_GMP
2
3#include <stdlib.h>
4
5#include <stdio.h>
6#define dprintf printf
7
8#ifdef SPOOF_GMP
9#define jmpn_cmp __gmpn_cmp
10#define jmpz_divexact __gmpz_divexact
11#define jmpn_gcd_1 __gmpn_gcd_1
12#define jmpz_tdiv_qr __gmpz_tdiv_qr
13#define jmp_set_memory_functions __gmp_set_memory_functions
14#define jmpz_tdiv_q __gmpz_tdiv_q
15#define jmpz_tdiv_r __gmpz_tdiv_r
16#define jmpz_fdiv_qr __gmpz_fdiv_qr
17#define jmpz_add __gmpz_add
18#define jmpz_and __gmpz_and
19#define jmpz_com __gmpz_com
20#define jmpz_gcd __gmpz_gcd
21#define jmpz_ior __gmpz_ior
22#define jmpz_init __gmpz_init
23#define jmpz_mul __gmpz_mul
24#define jmpz_sub __gmpz_sub
25#define jmpz_xor __gmpz_xor
26#endif
27
28
29typedef unsigned int limb;
30
31typedef struct mpz {
32        int alloc; // zero indicates mp zero
33        int size;  // negative indicates mp negativity
34        limb *limbs;
35} mpz;
36
37typedef void *(*alloc_func ) (size_t); 
38typedef void *(*realloc_func ) (void *, size_t, size_t);
39typedef void (*free_func ) (void *, size_t);
40
41static void default_free( void *a, size_t s ) { free( a ); }
42static void *default_realloc( void *a, size_t olds, size_t news ) { return realloc( a, news ); }
43
44static alloc_func gAlloc = malloc;
45static realloc_func gRealloc = default_realloc;
46static free_func gFree = default_free;
47
48void jmp_set_memory_functions( alloc_func my_alloc, realloc_func my_realloc, free_func my_free )
49{
50        dprintf( "mp: set_memory_functions( %p, %p, %p)\n", my_alloc, my_realloc, my_free );
51        if( my_alloc ) gAlloc = my_alloc;
52        if( my_realloc ) gRealloc = my_realloc;
53        if( my_free ) gFree = my_free;
54}
55
56void jmpz_init( mpz *m )
57{
58        dprintf( "mp: mpz_init( %p )\n", m );
59        m->alloc = 1;
60        m->size = 0;
61        m->limbs = (limb*) gAlloc( sizeof(limb) );
62}
63
64void jmpz_add( mpz *retv, mpz *a, mpz *b )
65{
66        dprintf( "mp: mpz_add( %p, %p, %p )\n", retv, a, b );
67        int sa = a->size < 0 ? -1 : 1;
68        int sb = b->size < 0 ? -1 : 1;
69        int sx, sy;
70        int overflow;
71        int n = zmax( sa * a->size , sb * b->size ) + 1;
72        int m;
73        limb *x, *y, *z, *last_nzero;
74        if( retv->alloc < n ) retv->limbs = gRealloc( retv->limbs, sizeof(limb) * retv->alloc, sizeof(limb) * n );
75        overflow = 0;
76        x  = sa*a->size < sb*b->size ? a->limbs : b->limbs;
77        sx = sa*a->size < sb*b->size ? sa : sb;
78        y  = sa*a->size < sb*b->size ? b->limbs : a->limbs;
79        sy = sa*a->size < sb*b->size ? sb : sa;
80        z = retv->limbs;
81        m = zmin( sa * a->size, sb * b->size );
82        last_nzero = retv->limbs;
83        for( ; m; m-- ) {
84                *z = (unsigned int)(sx * *x) + (unsigned int)(sy * *y) + overflow;
85                overflow = *z < *x ? 1 : 0;
86                if( sx * sy < 0 ) overflow = 1 - overflow;
87                if( *z != 0 ) { last_nzero = z; }
88                z++;
89                x++;
90                y++;
91                n--;
92        }
93        for( ; n; n-- ) {
94                *z = (unsigned int)(sy * *y) + overflow;
95                overflow = *z < *? 1 : 0;
96                if( sy < 0 ) overflow = 1 - overflow;
97                if( *z != 0 ) { last_nzero = z; }
98                z++;
99                x++;
100                y++;
101        }
102       
103        retv->size = (last_nzero - retv->limbs)/sizeof(limb) + 1;
104
105}
106
107void jmpz_sub( mpz *retv, mpz *a, mpz *b )
108{
109        dprintf( "mp: mpz_sub( %p, %p, %p )\n", retv, a, b );
110        int sa = a->size < 0 ? -1 : 1;
111        int sb = b->size < 0 ? -1 : 1;
112        int sx, sy;
113        int overflow;
114        int n = zmax( sa * a->size , sb * b->size ) + 1;
115        int m;
116        limb *x, *y, *z, *last_nzero;
117        if( retv->alloc < n ) retv->limbs = gRealloc( retv->limbs, sizeof(limb) * retv->alloc, sizeof(limb) * n );
118        overflow = 0;
119        x  = sa*a->size < sb*b->size ? a->limbs : b->limbs;
120        sx = sa*a->size < sb*b->size ? sa : -sb;
121        y  = sa*a->size < sb*b->size ? b->limbs : a->limbs;
122        sy = sa*a->size < sb*b->size ? -sb : sa;
123        z = retv->limbs;
124        m = zmin( sa * a->size, sb * b->size );
125        last_nzero = retv->limbs;
126        for( ; m; m-- ) {
127                *z = (unsigned int)(sx * *x) + (unsigned int)(sy * *y) + overflow;
128                overflow = *z < *x ? 1 : 0;
129                if( sx * sy < 0 ) overflow = 1 - overflow;
130                if( *z != 0 ) { last_nzero = z; }
131                z++;
132                x++;
133                y++;
134                n--;
135        }
136        for( ; n; n-- ) {
137                *z = (unsigned int)(sy * *y) + overflow;
138                overflow = *z < *? 1 : 0;
139                if( sy < 0 ) overflow = 1 - overflow;
140                if( *z != 0 ) { last_nzero = z; }
141                z++;
142                x++;
143                y++;
144        }
145       
146        retv->size = (last_nzero - retv->limbs)/sizeof(limb) + 1;
147
148}
149
150// binary xor
151// TODO: What to do with negatives?
152void jmpz_xor( mpz *retv, mpz *a, mpz *b )
153{
154        dprintf( "mp: mpz_xor( %p, %p, %p )\n", retv, a, b );
155        int sa = a->size < 0 ? -1 : 1;
156        int sb = b->size < 0 ? -1 : 1;
157        int n = zmax( sa*a->size, sb*b->size );
158        int m = zmin( sa*a->size, sb*b->size );
159        limb *x, *y, *z, *last_nzero;
160        int sx, sy;
161        if( retv->alloc < n ) retv->limbs = gRealloc( retv->limbs, sizeof(limb)*retv->alloc, sizeof(limb)*n );
162        x  = sa*a->size < sb*b->size ? a->limbs : b->limbs;
163        sx = sa*a->size < sb*b->size ? sa : sb;
164        y  = sa*a->size < sb*b->size ? b->limbs : a->limbs;
165        sy = sa*a->size < sb*b->size ? sb : sa;
166        z = retv->limbs;
167        last_nzero = retv->limbs;
168        for( ;m; m-- ) {
169                *z = *x ^ *y;
170                if( *z != 0 ) { last_nzero = z; }
171                z++;
172                x++;
173                y++;
174                n--;
175        }
176        for( ;n; n-- ) {
177                *z = 0 ^ *y;
178                if( *z != 0 ) { last_nzero = z; }
179                z++;
180                x++;
181                y++;
182        }
183       
184        retv->size = (last_nzero - retv->limbs)/sizeof(limb) + 1;
185
186}
187
188// binary or
189// TODO: What to do with negatives?
190void jmpz_ior( mpz *retv, mpz *a, mpz *b )
191{
192        dprintf( "mp: mpz_xor( %p, %p, %p )\n", retv, a, b );
193        int sa = a->size < 0 ? -1 : 1;
194        int sb = b->size < 0 ? -1 : 1;
195        int n = zmax( sa*a->size, sb*b->size );
196        int m = zmin( sa*a->size, sb*b->size );
197        limb *x, *y, *z, *last_nzero;
198        int sx, sy;
199        if( retv->alloc < n ) retv->limbs = gRealloc( retv->limbs, sizeof(limb)*retv->alloc, sizeof(limb)*n );
200        x  = sa*a->size < sb*b->size ? a->limbs : b->limbs;
201        sx = sa*a->size < sb*b->size ? sa : sb;
202        y  = sa*a->size < sb*b->size ? b->limbs : a->limbs;
203        sy = sa*a->size < sb*b->size ? sb : sa;
204        z = retv->limbs;
205        last_nzero = retv->limbs;
206        for( ;m; m-- ) {
207                *z = *x | *y;
208                if( *z != 0 ) { last_nzero = z; }
209                z++;
210                x++;
211                y++;
212                n--;
213        }
214        for( ;n; n-- ) {
215                *z = 0 | *y;
216                if( *z != 0 ) { last_nzero = z; }
217                z++;
218                x++;
219                y++;
220        }
221       
222        retv->size = (last_nzero - retv->limbs)/sizeof(limb) + 1;
223
224}
225
226// binary or
227// TODO: What to do with negatives?
228void jmpz_and( mpz *retv, mpz *a, mpz *b )
229{
230        dprintf( "mp: mpz_xor( %p, %p, %p )\n", retv, a, b );
231        int sa = a->size < 0 ? -1 : 1;
232        int sb = b->size < 0 ? -1 : 1;
233        int n = zmax( sa*a->size, sb*b->size );
234        int m = zmin( sa*a->size, sb*b->size );
235        limb *x, *y, *z, *last_nzero;
236        int sx, sy;
237        if( retv->alloc < n ) retv->limbs = gRealloc( retv->limbs, sizeof(limb)*retv->alloc, sizeof(limb)*n );
238        x  = sa*a->size < sb*b->size ? a->limbs : b->limbs;
239        sx = sa*a->size < sb*b->size ? sa : sb;
240        y  = sa*a->size < sb*b->size ? b->limbs : a->limbs;
241        sy = sa*a->size < sb*b->size ? sb : sa;
242        z = retv->limbs;
243        last_nzero = retv->limbs;
244        for( ;m; m-- ) {
245                *z = *x & *y;
246                if( *z != 0 ) { last_nzero = z; }
247                z++;
248                x++;
249                y++;
250                n--;
251        }
252        for( ;n; n-- ) {
253                *z = 0 & *y;
254                if( *z != 0 ) { last_nzero = z; }
255                z++;
256                x++;
257                y++;
258        }
259       
260        retv->size = (last_nzero - retv->limbs)/sizeof(limb) + 1;
261
262}
263
264// binary complement
265// TODO: What to do with negatives?
266void jmpz_com( mpz *retv, mpz *a )
267{
268        dprintf( "mp: mpz_com( %p, %p )\n", retv, a );
269        int sa = a->size < 0 ? -1 : 1;
270        int n = sa * a->size;
271        int i;
272        if( retv->alloc < n ) retv->limbs = gRealloc( retv->limbs, sizeof(limb)*retv->alloc, sizeof(limb)*n );
273        for( i=0; i< n; i++ ) {
274                retv->limbs[i] = ~ a->limbs[i];
275        }
276        retv->size = a->size;
277}
278
279/*
280int jmpz_cmp( mpz *a, mpz *b )
281{
282        int sa = a->size < 0 ? -1 : 1;
283        int sb = b->size < 0 ? -1 : 1;
284        if( sa < sb ) return -1;
285        if( sa > sb ) return 1;
286
287        limb r = 0;
288        int n = sa * a->size;
289        int i;
290        for( i=n-1 ; i>=0; i-- ) {
291                r = a->limbs[i]  - b->limbs[i];
292                if( r ) break;
293        }
294        return sa * r;
295}
296*/
297
298int jmpn_cmp (limb *a, limb *b, size_t n )
299{
300        dprintf( "mp: mpn_cmp(%p,%p, %zu)\n", a, b, n );
301        limb r = 0;
302        int i;
303        for( i=n-1 ; i>=0; i-- ) {
304                r = a[i]  - b[i];
305                if( r ) break;
306        }
307        return (int) r;
308}
309
310
311void jmpz_set( mpz *t, mpz *b )
312{
313        t->alloc = b->alloc;
314        t->size = b->size;
315        t->limbs = gAlloc( t->alloc );
316        memcpy( t->limbs, b->limbs, t->alloc );
317}
318
319// let a = b/c  where c divides b, but this algorithm is ridiculously slow (it repeatedly adds one to a and subtracts c from b)
320int jmpz_divexact( mpz *a, mpz *b, mpz *c )
321{
322        mpz one;
323        mpz t;
324        int sgn_b, sgn_c;
325        sgn_b = b->size > 0 ? 1 : -1;
326        sgn_c = c->size > 0 ? 1 : -1;
327
328        jmpz_init( &one )
329        one.size = 1;
330        one.limbs[0] = 1;
331
332        jmpz_set( &t, b );
333        t.size *= sgn_b;
334
335        // Set a = 0
336        a->limbs = gRealloc( a->limbs, a->alloc, 1 );
337        a->size = 0;
338        a->alloc = 1;
339
340        if( c->size == 0 ) { 
341                // TODO: division by zero!
342                return;
343        }
344        c->size *= sgn_c;
345
346        while( t->size>0 ) {
347                mpz tmp;
348                jmpz_set( &tmp, a );
349                jmpz_add( a, &tmp, &one );
350                jmpz_set( &tmp, &t );
351                jmpz_sub( &t, &t, c );
352        }
353
354        c->size *= sgn_c;
355
356}
357
358/*
359TODO:
360        ___gmpn_gcd_1
361        ___gmpz_tdiv_qr
362        ___gmpz_tdiv_q
363        ___gmpz_tdiv_r
364        ___gmpz_fdiv_qr
365        ___gmpz_gcd
366        ___gmpz_mul
367*/