#include #include #include #include #include #include #include #include #include #include #include #include #include #include using json = nlohmann::json; #include "calcwit.hpp" #include "circom.hpp" #include "utils.hpp" Circom_Circuit *circuit; #define handle_error(msg) \ do { perror(msg); exit(EXIT_FAILURE); } while (0) #define SHMEM_WITNESS_KEY (123456) // assumptions // 1) There is only one key assigned for shared memory. This means // that only one witness can be computed and used at a time. If several witness // are computed before calling the prover, witness memory will be overwritten. // 2) Prover is responsible for releasing memory once is done with witness // // File format: // Type : 4B (wshm) // Version : 4B // N Section : 4B // HDR1 : 12B // N8 : 4B // Fr : N8 B // NVars : 4B // HDR2 : 12B // ShmemKey : 4B // Status : 4B (0:OK, 0xFFFF: KO) // ShmemID : 4B void writeOutShmem(Circom_CalcWit *ctx, std::string filename) { FILE *write_ptr; u64 *shbuf; int shmid, status = 0; write_ptr = fopen(filename.c_str(),"wb"); fwrite("wshm", 4, 1, write_ptr); u32 version = 2; fwrite(&version, 4, 1, write_ptr); u32 nSections = 2; fwrite(&nSections, 4, 1, write_ptr); // Header u32 idSection1 = 1; fwrite(&idSection1, 4, 1, write_ptr); u32 n8 = Fr_N64*8; u64 idSection1length = 8 + n8; fwrite(&idSection1length, 8, 1, write_ptr); fwrite(&n8, 4, 1, write_ptr); fwrite(Fr_q.longVal, Fr_N64*8, 1, write_ptr); u32 nVars = circuit->NVars; fwrite(&nVars, 4, 1, write_ptr); // Data u32 idSection2 = 2; fwrite(&idSection2, 4, 1, write_ptr); u64 idSection2length = n8*circuit->NVars; fwrite(&idSection2length, 8, 1, write_ptr); // generate key key_t key = SHMEM_WITNESS_KEY; fwrite(&key, sizeof(key_t), 1, write_ptr); // Setup shared memory if ((shmid = shmget(key, circuit->NVars * Fr_N64 * sizeof(u64), IPC_CREAT | 0666)) < 0) { // preallocated shared memory segment is too small => Retrieve id by accesing old segment // Delete old segment and create new with corret size shmid = shmget(key, 4, IPC_CREAT | 0666); shmctl(shmid, IPC_RMID, NULL); if ((shmid = shmget(key, circuit->NVars * Fr_N64 * sizeof(u64), IPC_CREAT | 0666)) < 0){ status = -1; fwrite(&status, sizeof(status), 1, write_ptr); fclose(write_ptr); return ; } } // Attach shared memory if ((shbuf = (u64 *)shmat(shmid, NULL, 0)) == (u64 *) -1) { status = -1; fwrite(&status, sizeof(status), 1, write_ptr); fclose(write_ptr); return; } fwrite(&status, sizeof(status), 1, write_ptr); fwrite(&shmid, sizeof(u32), 1, write_ptr); fclose(write_ptr); #pragma omp parallel for for (int i=0; iNVars;i++) { FrElement v; ctx->getWitness(i, &v); Fr_toLongNormal(&v, &v); memcpy(&shbuf[i*Fr_N64], v.longVal, Fr_N64*sizeof(u64)); } } void loadBin(Circom_CalcWit *ctx, std::string filename) { int fd; struct stat sb; // map input fd = open(filename.c_str(), O_RDONLY); if (fd == -1) handle_error("open"); if (fstat(fd, &sb) == -1) /* To obtain file size */ handle_error("fstat"); u8 *in; in = (u8 *)mmap(NULL, sb.st_size, PROT_READ, MAP_PRIVATE, fd, 0); if (in == MAP_FAILED) handle_error("mmap"); close(fd); FrElement v; u8 *p = in; for (int i=0; iNInputs; i++) { v.type = Fr_LONG; for (int j=0; jsetSignal(0, 0, circuit->wit2sig[1 + circuit->NOutputs + i], &v); } } typedef void (*ItFunc)(Circom_CalcWit *ctx, int idx, json val); void iterateArr(Circom_CalcWit *ctx, int o, Circom_Sizes sizes, json jarr, ItFunc f) { if (!jarr.is_array()) { assert((sizes[0] == 1)&&(sizes[1] == 0)); f(ctx, o, jarr); } else { int n = sizes[0] / sizes[1]; for (int i=0; i(); } else if (val.is_number()) { double vd = val.get(); std::stringstream stream; stream << std::fixed << std::setprecision(0) << vd; s = stream.str(); } else { handle_error("Invalid JSON type"); } Fr_str2element (&v, s.c_str()); ctx->setSignal(0, 0, o, &v); } void loadJson(Circom_CalcWit *ctx, std::string filename) { std::ifstream inStream(filename); json j; inStream >> j; u64 nItems = j.size(); printf("Items : %llu\n",nItems); for (json::iterator it = j.begin(); it != j.end(); ++it) { // std::cout << it.key() << " => " << it.value() << '\n'; u64 h = fnv1a(it.key()); int o; try { o = ctx->getSignalOffset(0, h); } catch (std::runtime_error e) { std::ostringstream errStrStream; errStrStream << "Error loadin variable: " << it.key() << "\n" << e.what(); throw std::runtime_error(errStrStream.str() ); } Circom_Sizes sizes = ctx->getSignalSizes(0, h); iterateArr(ctx, o, sizes, it.value(), itFunc); } } void writeOutBin(Circom_CalcWit *ctx, std::string filename) { FILE *write_ptr; write_ptr = fopen(filename.c_str(),"wb"); fwrite("wtns", 4, 1, write_ptr); u32 version = 2; fwrite(&version, 4, 1, write_ptr); u32 nSections = 2; fwrite(&nSections, 4, 1, write_ptr); // Header u32 idSection1 = 1; fwrite(&idSection1, 4, 1, write_ptr); u32 n8 = Fr_N64*8; u64 idSection1length = 8 + n8; fwrite(&idSection1length, 8, 1, write_ptr); fwrite(&n8, 4, 1, write_ptr); fwrite(Fr_q.longVal, Fr_N64*8, 1, write_ptr); u32 nVars = circuit->NVars; fwrite(&nVars, 4, 1, write_ptr); // Data u32 idSection2 = 2; fwrite(&idSection2, 4, 1, write_ptr); u64 idSection2length = (u64)n8*(u64)circuit->NVars; fwrite(&idSection2length, 8, 1, write_ptr); FrElement v; for (int i=0;iNVars;i++) { ctx->getWitness(i, &v); Fr_toLongNormal(&v, &v); fwrite(v.longVal, Fr_N64*8, 1, write_ptr); } fclose(write_ptr); } void writeOutJson(Circom_CalcWit *ctx, std::string filename) { std::ofstream outFile; outFile.open (filename); outFile << "[\n"; FrElement v; for (int i=0;iNVars;i++) { ctx->getWitness(i, &v); char *pcV = Fr_element2str(&v); std::string sV = std::string(pcV); outFile << (i ? "," : " ") << "\"" << sV << "\"\n"; free(pcV); } outFile << "]\n"; outFile.close(); } bool hasEnding (std::string const &fullString, std::string const &ending) { if (fullString.length() >= ending.length()) { return (0 == fullString.compare (fullString.length() - ending.length(), ending.length(), ending)); } else { return false; } } #define ADJ_P(a) *((void **)&a) = (void *)(((char *)circuit)+ (uint64_t)(a)) Circom_Circuit *loadCircuit(std::string const &datFileName) { Circom_Circuit *circuitF; Circom_Circuit *circuit; int fd; struct stat sb; fd = open(datFileName.c_str(), O_RDONLY); if (fd == -1) { std::cout << ".dat file not found: " << datFileName << "\n"; throw std::system_error(errno, std::generic_category(), "open"); } if (fstat(fd, &sb) == -1) { /* To obtain file size */ throw std::system_error(errno, std::generic_category(), "fstat"); } circuitF = (Circom_Circuit *)mmap(NULL, sb.st_size, PROT_READ , MAP_PRIVATE, fd, 0); close(fd); circuit = (Circom_Circuit *)malloc(sb.st_size); memcpy((void *)circuit, (void *)circuitF, sb.st_size); munmap(circuitF, sb.st_size); ADJ_P(circuit->wit2sig); ADJ_P(circuit->components); ADJ_P(circuit->mapIsInput); ADJ_P(circuit->constants); ADJ_P(circuit->P); ADJ_P(circuit->componentEntries); for (int i=0; iNComponents; i++) { ADJ_P(circuit->components[i].hashTable); ADJ_P(circuit->components[i].entries); circuit->components[i].fn = _functionTable[ (uint64_t)circuit->components[i].fn]; } for (int i=0; iNComponentEntries; i++) { ADJ_P(circuit->componentEntries[i].sizes); } return circuit; } int main(int argc, char *argv[]) { if (argc!=3) { std::string cl = argv[0]; std::string base_filename = cl.substr(cl.find_last_of("/\\") + 1); std::cout << "Usage: " << base_filename << " > >\n"; } else { struct timeval begin, end; long seconds, microseconds; double elapsed; gettimeofday(&begin,0); std::string datFileName = argv[0]; datFileName += ".dat"; circuit = loadCircuit(datFileName); // open output Circom_CalcWit *ctx = new Circom_CalcWit(circuit); std::string infilename = argv[1]; gettimeofday(&end,0); seconds = end.tv_sec - begin.tv_sec; microseconds = end.tv_usec - begin.tv_usec; elapsed = seconds + microseconds*1e-6; printf("Up to loadJson %.20f\n", elapsed); if (hasEnding(infilename, std::string(".bin"))) { loadBin(ctx, infilename); } else if (hasEnding(infilename, std::string(".json"))) { loadJson(ctx, infilename); } else { handle_error("Invalid input extension (.bin / .json)"); } ctx->join(); // printf("Finished!\n"); std::string outfilename = argv[2]; if (hasEnding(outfilename, std::string(".wtns"))) { gettimeofday(&end,0); seconds = end.tv_sec - begin.tv_sec; microseconds = end.tv_usec - begin.tv_usec; elapsed = seconds + microseconds*1e-6; printf("Up to WriteWtns %.20f\n", elapsed); writeOutBin(ctx, outfilename); } else if (hasEnding(outfilename, std::string(".json"))) { writeOutJson(ctx, outfilename); } else if (hasEnding(outfilename, std::string(".wshm"))) { gettimeofday(&end,0); seconds = end.tv_sec - begin.tv_sec; microseconds = end.tv_usec - begin.tv_usec; elapsed = seconds + microseconds*1e-6; printf("Up to WriteShmem %.20f\n", elapsed); writeOutShmem(ctx, outfilename); } else { handle_error("Invalid output extension (.bin / .json)"); } delete ctx; gettimeofday(&end,0); seconds = end.tv_sec - begin.tv_sec; microseconds = end.tv_usec - begin.tv_usec; elapsed = seconds + microseconds*1e-6; printf("Total %.20f\n", elapsed); exit(EXIT_SUCCESS); } }