# Review notes (added; preamble text only, not part of any hunk):
# - Summary: this patch splits the Sqlite3 query API. `exec`, `fetch`,
#   `exec_f` and `fetch_f` no longer take a `sql_value list` of bindings;
#   binding variants are added as `exec_v`, `fetch_v`, `exec_fv` and
#   `fetch_fv` (see the sqlite3.ml and sqlite3.mli sections). database.ml
#   call sites drop their trailing `[]`, and the three queries that bind
#   a `TEXT` parameter switch from `fetch` to `fetch_v`.
# - New helpers in sqlite3.ml: `sql_escape` doubles every single quote and
#   returns its argument unchanged when it contains none; `hex_enc`
#   (exported as `blob_escape`) emits two lowercase hex digits per byte;
#   `transaction ?kind db f` runs "BEGIN <kind>" (default DEFERRED), then
#   "COMMIT" after `f db` returns, or "ROLLBACK" and re-raise on exception.
# - NOTE(review): in `transaction`, the COMMIT call sits inside the `try`,
#   so a failing COMMIT also triggers ROLLBACK. If the ROLLBACK call itself
#   raises, its error replaces the original exception. Consider guarding
#   the ROLLBACK and re-raising the original exception (with backtrace).
# - NOTE(review): `sql_escape` and `hex_enc` build their result with
#   `String.create` plus in-place string mutation. That requires mutable
#   strings and will not compile under -safe-string (the default since
#   OCaml 4.06); porting needs `Bytes`.
# - NOTE(review): `exec_v`/`fetch_v` pass `String.copy sql` to the worker,
#   but the reason is not visible here. Presumably the C stub retains or
#   mutates the SQL buffer; confirm.
# - NOTE(review): database.ml `fetch_revision_set` applies `List.hd` to the
#   `fetch_v` result, which raises if no revision row matches the id.
# - NOTE(review): this copy of the patch has lost its line breaks and
#   indentation, so it cannot be applied as-is; regenerate it from the VCS.
# # # patch "database.ml" # from [6675f62b995a0380b329bdd5c9134023df4416ec] # to [d54e4865a25cf71403c8fc777a4aa377dcad1ca6] # # patch "mlsqlite/sqlite3.ml" # from [fecc3f3795b4baeb75e1816275cfcd7c96c4b8c6] # to [b0fdf5dd266be78db186bd51e9f15cc789f0954a] # # patch "mlsqlite/sqlite3.mli" # from [61a3daf169c5cfa3c693dcce707a3489b69c420b] # to [b08f8df3b47830a8a9a19d7d1a308aceb28ef281] # ============================================================ --- database.ml 6675f62b995a0380b329bdd5c9134023df4416ec +++ database.ml d54e4865a25cf71403c8fc777a4aa377dcad1ca6 @@ -21,7 +21,7 @@ let setup_sqlite ?busy_handler db = then Sqlite3.trace_set db (fun s -> prerr_string "### sql: " ; prerr_endline s) ; - Sqlite3.exec db "PRAGMA temp_store = MEMORY" [] ; + Sqlite3.exec db "PRAGMA temp_store = MEMORY" ; may (Sqlite3.busy_set db) busy_handler @@ -34,7 +34,7 @@ let schema_id db = WHERE (type = 'table' OR type = 'index') \ AND sql IS NOT NULL \ AND name NOT LIKE 'sqlite_stat%' \ - ORDER BY name" [] + ORDER BY name" (acc_one_col false) [] in let schema_data = String.concat "\n" (List.rev lines) in let schema = Schema_lexer.massage_sql_tokens schema_data in @@ -42,7 +42,7 @@ let has_rosters db = let has_rosters db = Sqlite3.fetch db - "SELECT name FROM sqlite_master WHERE name = 'rosters'" [] + "SELECT name FROM sqlite_master WHERE name = 'rosters'" (fun _ _ -> true) false @@ -51,7 +51,7 @@ let fetch_pubkeys db base64 tbl = let fetch_pubkeys db base64 tbl = Sqlite3.fetch db - "SELECT id, keydata, ROWID FROM public_keys" [] + "SELECT id, keydata, ROWID FROM public_keys" (fun () stmt -> let id = Sqlite3.column_text stmt 0 in let data = blob_col base64 stmt 1 in @@ -65,7 +65,7 @@ let fetch_branches base64 db = let fetch_branches base64 db = List.sort compare (Sqlite3.fetch db - "SELECT DISTINCT value FROM revision_certs WHERE name = 'branch'" [] + "SELECT DISTINCT value FROM revision_certs WHERE name = 'branch'" (acc_one_col base64) []) @@ -139,7 +139,7 @@ let collect_tags db base64 view
g = let collect_tags db base64 view g = Sqlite3.fetch_f db "SELECT C.id, C.value FROM revision_certs AS C, %s AS D WHERE name = 'tag' AND C.id = D.id" - view [] + view (fun () stmt -> let id = Sqlite3.column_text stmt 0 in let n = NodeMap.find id g.nodes in @@ -203,7 +203,7 @@ let fetch_agraph_with_view db base64 (qu (* grab all our main nodes *) let agraph = Sqlite3.fetch_f db - "SELECT id FROM %s" view_name_limit [] + "SELECT id FROM %s" view_name_limit process_regular_node agraph in (* neighbor IN *) @@ -212,7 +212,7 @@ let fetch_agraph_with_view db base64 (qu "SELECT parent, child \ FROM %s AS D, revision_ancestry AS A \ WHERE D.id = A.child AND A.parent != '' AND A.parent NOT IN %s" - view_name_limit view_name_domain [] + view_name_limit view_name_domain process_neighb_in agraph in (* neighbor OUT *) @@ -221,7 +221,7 @@ let fetch_agraph_with_view db base64 (qu "SELECT parent, child \ FROM %s AS D, revision_ancestry AS A \ WHERE D.id = A.parent AND A.child NOT IN %s" - view_name_limit view_name_domain [] + view_name_limit view_name_domain (process_neighb_out db) agraph in (* ancestry *) @@ -230,7 +230,7 @@ let fetch_agraph_with_view db base64 (qu "SELECT parent, child \ FROM %s AS D1, revision_ancestry AS A, %s AS D2 \ WHERE D1.id = A.parent AND A.child = D2.id" - view_name_limit view_name_limit [] + view_name_limit view_name_limit process_ancestry agraph in (* find merge/propagate nodes (they have more than one parent) *) @@ -258,7 +258,7 @@ let fetch_agraph_with_view db base64 (qu WHERE C.id = A.child AND P.id = A.parent \ AND C.name = 'branch' AND P.name = 'branch' \ AND C.value = P.value)" - view_name_limit [] + view_name_limit process_branching_edge agraph end in @@ -315,24 +315,24 @@ let fetch_with_view query base64 db f = ~before:(fun () -> (* We fetch the ids matching the query (ie those on certain branches) *) (* and store them in a view.
*) - Sqlite3.exec db view_query_domain [] ; + Sqlite3.exec db view_query_domain ; Sqlite3.exec_f db - "CREATE INDEX %s__id ON %s (id)" view_name_domain view_name_domain [] ; + "CREATE INDEX %s__id ON %s (id)" view_name_domain view_name_domain ; if query_limit <> QUERY_NO_LIMIT then begin register_date_p () ; - Sqlite3.exec db (view_query_date_limit ()) [] ; + Sqlite3.exec db (view_query_date_limit ()) ; Sqlite3.exec_f db - "CREATE INDEX %s__id ON %s (id)" view_name_limit view_name_limit [] + "CREATE INDEX %s__id ON %s (id)" view_name_limit view_name_limit end) ~action:(fun () -> f db base64 query) ~after:(fun () -> if query_limit <> QUERY_NO_LIMIT then begin Sqlite3.delete_function db "date_p" ; - Sqlite3.exec_f db "DROP TABLE %s" view_name_limit [] + Sqlite3.exec_f db "DROP TABLE %s" view_name_limit end ; - Sqlite3.exec_f db "DROP TABLE %s" view_name_domain []) + Sqlite3.exec_f db "DROP TABLE %s" view_name_domain) () let fetch_agraph query base64 db = @@ -356,7 +356,7 @@ let fetch_revision_set rostered b64 db i decode_and_parse_revision rostered (List.hd - (Sqlite3.fetch db + (Sqlite3.fetch_v db "SELECT data FROM revisions WHERE id = ?" [`TEXT id] (acc_one_col b64) [])) @@ -383,7 +383,7 @@ let fetch_certs db pubkeys b64 id = c_signature = verify_cert_sig pubkeys keypair name id dec_v dec_sig } :: acc let fetch_certs db pubkeys b64 id = - Sqlite3.fetch db + Sqlite3.fetch_v db "SELECT id, name, value, keypair, signature \ FROM revision_certs WHERE id = ?" [`TEXT id] (process_certs pubkeys b64) [] @@ -408,7 +408,7 @@ let get_matching_cert db b64 name p = let get_matching_cert db b64 name p = List.rev - (Sqlite3.fetch db + (Sqlite3.fetch_v db "SELECT id, value FROM revision_certs WHERE name = ?"
[`TEXT name] (fun acc s -> ============================================================ --- mlsqlite/sqlite3.ml fecc3f3795b4baeb75e1816275cfcd7c96c4b8c6 +++ mlsqlite/sqlite3.ml b0fdf5dd266be78db186bd51e9f15cc789f0954a @@ -284,10 +284,10 @@ let rec do_step stmt = | `DONE -> () | `ROW -> do_step stmt -let _exec db sql data = - _fold_prepare_bind +let _exec db sql = + _fold_prepare ~final:true - db sql data + db sql (fun () stmt -> do_step stmt) () @@ -297,7 +297,21 @@ let exec_f db fmt = let exec_f db fmt = Printf.kprintf (_exec db) fmt +let _exec_v db sql data = + _fold_prepare_bind + ~final:true + db sql data + (fun () stmt -> do_step stmt) + () +let exec_v db sql = + _exec_v db (String.copy sql) + +let exec_fv db fmt = + Printf.kprintf (_exec_v db) fmt + + + (* Execute statements and get some results back *) let rec fold_step f acc stmt = match step stmt with @@ -305,9 +319,9 @@ let rec fold_step f acc stmt = | `ROW -> fold_step f (f acc stmt) stmt -let _fetch db sql data f init = - _fold_prepare_bind - db sql data +let _fetch db sql f init = + _fold_prepare + db sql (fold_step f) init let fetch db sql = @@ -316,7 +330,19 @@ let fetch_f db fmt = let fetch_f db fmt = Printf.kprintf (_fetch db) fmt +let _fetch_v db sql data f init = + _fold_prepare_bind + db sql data + (fold_step f) init +let fetch_v db sql = + _fetch_v db (String.copy sql) + +let fetch_fv db fmt = + Printf.kprintf (_fetch_v db) fmt + + + (* Reset-Bind-Step *) let bind_and_exec stmt bindings = reset stmt ; @@ -327,3 +353,58 @@ let bind_fetch stmt bindings f init = reset stmt ; ignore (do_bind stmt bindings) ; fold_step f init stmt + +let sql_escape s = + let n = ref 0 in + let len = String.length s in + for i = 0 to len - 1 do + let c = String.unsafe_get s i in + if c = '\'' then incr n + done ; + if !n = 0 + then s + else begin + let n_len = len + !n in + let o = String.create n_len in + let j = ref 0 in + for i = 0 to len - 1 do + let c = String.unsafe_get s i in + if c = '\'' then begin +
String.unsafe_set o !j '\'' ; + incr j + end ; + String.unsafe_set o !j c ; + incr j + done ; + assert (!j = n_len) ; + o + end + +let char_of_hex v = + if v < 0xa + then Char.chr (v + Char.code '0') + else Char.chr (v - 0xa + Char.code 'a') + +let hex_enc s = + let len = String.length s in + let o = String.create (2 * len) in + for i = 0 to len - 1 do + let c = int_of_char s.[i] in + let hi = c lsr 4 in + o.[2*i] <- char_of_hex hi ; + let lo = c land 0xf in + o.[2*i + 1] <- char_of_hex lo + done ; + o + +let blob_escape = hex_enc + +let string_of_transaction = function + | `DEFERRED -> "DEFERRED" + | `IMMEDIATE -> "IMMEDIATE" + | `EXCLUSIVE -> "EXCLUSIVE" + +let transaction ?(kind=`DEFERRED) db f = + exec db ("BEGIN " ^ string_of_transaction kind) ; + try let r = f db in exec db "COMMIT" ; r + with exn -> exec db "ROLLBACK" ; raise exn ============================================================ --- mlsqlite/sqlite3.mli 61a3daf169c5cfa3c693dcce707a3489b69c420b +++ mlsqlite/sqlite3.mli b08f8df3b47830a8a9a19d7d1a308aceb28ef281 @@ -159,13 +159,29 @@ val bind_fetch : stmt -> sql_value li val bind_and_exec : stmt -> sql_value list -> unit val bind_fetch : stmt -> sql_value list -> ('a -> stmt -> 'a) -> 'a -> 'a -val fetch : db -> string -> sql_value list -> ('a -> stmt -> 'a) -> 'a -> 'a -val exec : db -> string -> sql_value list -> unit -val fold_prepare_bind : db -> string -> sql_value list -> ('a -> stmt -> 'a) -> 'a -> 'a -val fold_prepare : db -> string -> ('a -> stmt -> 'a) -> 'a -> 'a +val fold_prepare : db -> string -> ('a -> stmt -> 'a) -> 'a -> 'a +val fold_prepare_bind : db -> string -> sql_value list -> ('a -> stmt -> 'a) -> 'a -> 'a -val fetch_f : db -> (sql_value list -> ('a -> stmt -> 'a) -> 'a -> 'a, 'b) fmt -val exec_f : db -> (sql_value list -> unit, 'b) fmt +val fetch : db -> string -> ('a -> stmt -> 'a) -> 'a -> 'a +val exec : db -> string -> unit + +val fetch_v : db -> string -> sql_value list -> ('a -> stmt -> 'a) -> 'a -> 'a +val exec_v : db ->
string -> sql_value list -> unit + +val fold_prepare_f : db -> (('a -> stmt -> 'a) -> 'a -> 'a, 'b) fmt val fold_prepare_bind_f : db -> (sql_value list -> ('a -> stmt -> 'a) -> 'a -> 'a, 'b) fmt -val fold_prepare_f : db -> (('a -> stmt -> 'a) -> 'a -> 'a, 'b) fmt +val fetch_f : db -> (('a -> stmt -> 'a) -> 'a -> 'a, 'b) fmt +val exec_f : db -> (unit, 'b) fmt + +val fetch_fv : db -> (sql_value list -> ('a -> stmt -> 'a) -> 'a -> 'a, 'b) fmt +val exec_fv : db -> (sql_value list -> unit, 'b) fmt + +(** {2 Convenience functions} *) + +val sql_escape : string -> string +val blob_escape : string -> string + +val transaction : + ?kind:[`DEFERRED|`IMMEDIATE|`EXCLUSIVE] -> + db -> (db -> 'a) -> 'a