Safer Rust (catch panic with catch_unwind())

Crossing boundaries of multiple languages is tricky, but we can do at
least something about this, in particular, use catch_unwind() [1] to
catch possible panic!()s.

  [1]: https://doc.rust-lang.org/std/panic/fn.catch_unwind.html

Signed-off-by: Azat Khuzhin <a.khuzhin@semrush.com>
This commit is contained in:
Azat Khuzhin 2024-01-31 22:24:51 +01:00
parent 554bb5668e
commit 65cfbaaa4b
2 changed files with 40 additions and 3 deletions

View File

@ -2,6 +2,7 @@ use prql_compiler::sql::Dialect;
use prql_compiler::{Options, Target}; use prql_compiler::{Options, Target};
use std::ffi::{c_char, CString}; use std::ffi::{c_char, CString};
use std::slice; use std::slice;
use std::panic;
fn set_output(result: String, out: *mut *mut u8, out_size: *mut u64) { fn set_output(result: String, out: *mut *mut u8, out_size: *mut u64) {
assert!(!out_size.is_null()); assert!(!out_size.is_null());
@ -13,8 +14,7 @@ fn set_output(result: String, out: *mut *mut u8, out_size: *mut u64) {
*out_ptr = CString::new(result).unwrap().into_raw() as *mut u8; *out_ptr = CString::new(result).unwrap().into_raw() as *mut u8;
} }
#[no_mangle] pub unsafe extern "C" fn prql_to_sql_impl(
pub unsafe extern "C" fn prql_to_sql(
query: *const u8, query: *const u8,
size: u64, size: u64,
out: *mut *mut u8, out: *mut *mut u8,
@ -50,6 +50,23 @@ pub unsafe extern "C" fn prql_to_sql(
} }
} }
#[no_mangle]
pub unsafe extern "C" fn prql_to_sql(
query: *const u8,
size: u64,
out: *mut *mut u8,
out_size: *mut u64,
) -> i64 {
let ret = panic::catch_unwind(|| {
return prql_to_sql_impl(query, size, out, out_size);
});
return match ret {
// NOTE: using cxxbridge we can return proper Result<> type.
Err(_err) => 1,
Ok(res) => res,
}
}
#[no_mangle] #[no_mangle]
pub unsafe extern "C" fn prql_free_pointer(ptr_to_free: *mut u8) { pub unsafe extern "C" fn prql_free_pointer(ptr_to_free: *mut u8) {
std::mem::drop(CString::from_raw(ptr_to_free as *mut c_char)); std::mem::drop(CString::from_raw(ptr_to_free as *mut c_char));

View File

@ -1,6 +1,7 @@
use skim::prelude::*; use skim::prelude::*;
use term::terminfo::TermInfo; use term::terminfo::TermInfo;
use cxx::{CxxString, CxxVector}; use cxx::{CxxString, CxxVector};
use std::panic;
#[cxx::bridge] #[cxx::bridge]
mod ffi { mod ffi {
@ -36,7 +37,7 @@ impl SkimItem for Item {
} }
} }
fn skim(prefix: &CxxString, words: &CxxVector<CxxString>) -> Result<String, String> { fn skim_impl(prefix: &CxxString, words: &CxxVector<CxxString>) -> Result<String, String> {
// Let's check is terminal available. To avoid panic. // Let's check is terminal available. To avoid panic.
if let Err(err) = TermInfo::from_env() { if let Err(err) = TermInfo::from_env() {
return Err(format!("{}", err)); return Err(format!("{}", err));
@ -89,3 +90,22 @@ fn skim(prefix: &CxxString, words: &CxxVector<CxxString>) -> Result<String, Stri
} }
return Ok(output.selected_items[0].output().to_string()); return Ok(output.selected_items[0].output().to_string());
} }
fn skim(prefix: &CxxString, words: &CxxVector<CxxString>) -> Result<String, String> {
let ret = panic::catch_unwind(|| {
return skim_impl(prefix, words);
});
return match ret {
Err(err) => {
let e = if let Some(s) = err.downcast_ref::<String>() {
format!("{}", s)
} else if let Some(s) = err.downcast_ref::<&str>() {
format!("{}", s)
} else {
format!("Unknown panic type: {:?}", err.type_id())
};
Err(format!("Rust panic: {:?}", e))
},
Ok(res) => res,
}
}