use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::rc::Rc;
use std::result::Result as StdResult;
use common_recordbatch::RecordBatch;
use common_telemetry::info;
use datatypes::vectors::VectorRef;
use rustpython_vm::builtins::{PyBaseExceptionRef, PyDict, PyStr, PyTuple};
use rustpython_vm::class::PyClassImpl;
use rustpython_vm::convert::ToPyObject;
use rustpython_vm::scope::Scope;
use rustpython_vm::{vm, AsObject, Interpreter, PyObjectRef, PyPayload, VirtualMachine};
use snafu::{OptionExt, ResultExt};
use crate::engine::EvalContext;
use crate::python::error::{ensure, ret_other_error_with, NewRecordBatchSnafu, OtherSnafu, Result};
use crate::python::ffi_types::copr::PyQueryEngine;
use crate::python::ffi_types::py_recordbatch::PyRecordBatch;
use crate::python::ffi_types::{check_args_anno_real_type, select_from_rb, Coprocessor, PyVector};
use crate::python::metric;
use crate::python::rspython::builtins::init_greptime_builtins;
use crate::python::rspython::dataframe_impl::data_frame::set_dataframe_in_scope;
use crate::python::rspython::dataframe_impl::init_data_frame;
use crate::python::rspython::utils::{format_py_error, is_instance, py_obj_to_vec};
// Per-thread lazily-initialized cache of the RustPython interpreter (see
// `init_interpreter`); `Rc` (not `Arc`) keeps each instance confined to the
// thread that created it.
thread_local!(static INTERPRETER: RefCell<Option<Rc<Interpreter>>> = const { RefCell::new(None) });
/// Executes a parsed coprocessor using the RustPython backend.
///
/// When an input [`RecordBatch`] is supplied, the columns named by the
/// decorator's `args` are selected and checked against the declared
/// annotations before execution; otherwise the script runs with no column
/// arguments.
pub(crate) fn rspy_exec_parsed(
    copr: &Coprocessor,
    rb: &Option<RecordBatch>,
    params: &HashMap<String, String>,
    eval_ctx: &EvalContext,
) -> Result<RecordBatch> {
    let _t = metric::METRIC_RSPY_EXEC_TOTAL_ELAPSED.start_timer();
    let args: Vec<PyVector> = match rb {
        Some(batch) => {
            let arg_names = copr.deco_args.arg_names.clone().unwrap_or_default();
            let selected = select_from_rb(batch, &arg_names)?;
            check_args_anno_real_type(&arg_names, &selected, copr, batch)?;
            selected
        }
        None => Vec::new(),
    };
    exec_with_cached_vm(copr, rb, args, params, &init_interpreter(), eval_ctx)
}
/// Binds each argument vector to its corresponding name in the scope's
/// locals, so the coprocessor script can refer to input columns by name.
///
/// `arg_names` and `args` are paired positionally; the first Python-side
/// failure while setting an item short-circuits and is converted into our
/// error type via `format_py_error`.
fn set_items_in_scope(
    scope: &Scope,
    vm: &VirtualMachine,
    arg_names: &[String],
    args: Vec<PyVector>,
) -> Result<()> {
    // `try_for_each` short-circuits on the first error without collecting a
    // throwaway `Vec<()>` as the previous `.map().collect()` form did.
    arg_names
        .iter()
        .zip(args)
        .try_for_each(|(name, vector)| {
            scope
                .locals
                .as_object()
                .set_item(name, vm.new_pyobj(vector), vm)
        })
        .map_err(|e| format_py_error(e, vm))
}
/// Exposes `query_engine` to the script by binding it to `name` in the
/// scope's locals; Python-side failures are mapped to our error type.
fn set_query_engine_in_scope(
    scope: &Scope,
    vm: &VirtualMachine,
    name: &str,
    query_engine: PyQueryEngine,
) -> Result<()> {
    let engine_obj = query_engine.to_pyobject(vm);
    match scope.locals.as_object().set_item(name, engine_obj, vm) {
        Ok(()) => Ok(()),
        Err(e) => Err(format_py_error(e, vm)),
    }
}
/// Runs the coprocessor's compiled code object inside the given interpreter
/// and assembles the script's return value into a [`RecordBatch`].
///
/// Before execution, the scope is populated with (each only when present):
/// - `__dataframe__`: the input `RecordBatch` wrapped for dataframe access,
/// - one local per declared argument name, bound to its column vector,
/// - `__query__`: a handle to the query engine,
/// - the user-declared kwarg name bound to a dict built from `params`.
pub(crate) fn exec_with_cached_vm(
    copr: &Coprocessor,
    rb: &Option<RecordBatch>,
    args: Vec<PyVector>,
    params: &HashMap<String, String>,
    vm: &Rc<Interpreter>,
    eval_ctx: &EvalContext,
) -> Result<RecordBatch> {
    vm.enter(|vm| -> Result<RecordBatch> {
        let _t = metric::METRIC_RSPY_EXEC_ELAPSED.start_timer();
        let scope = vm.new_scope_with_builtins();
        if let Some(rb) = rb {
            set_dataframe_in_scope(&scope, vm, "__dataframe__", rb)?;
        }
        if let Some(arg_names) = &copr.deco_args.arg_names {
            // Invariant: the caller supplies exactly one vector per declared name.
            assert_eq!(arg_names.len(), args.len());
            set_items_in_scope(&scope, vm, arg_names, args)?;
        }
        if let Some(engine) = &copr.query_engine {
            let query_engine =
                PyQueryEngine::from_weakref(engine.clone(), eval_ctx.query_ctx.clone());
            set_query_engine_in_scope(&scope, vm, "__query__", query_engine)?;
        }
        if let Some(kwarg) = &copr.kwarg {
            // Expose user-supplied params as a Python dict under the kwarg name.
            let dict = PyDict::new_ref(&vm.ctx);
            for (k, v) in params {
                dict.set_item(k, PyStr::from(v.clone()).into_pyobject(vm), vm)
                    .map_err(|e| format_py_error(e, vm))?;
            }
            scope
                .locals
                .as_object()
                .set_item(kwarg, vm.new_pyobj(dict), vm)
                .map_err(|e| format_py_error(e, vm))?;
        }
        // `code_obj` is expected to be populated by the earlier parse/compile
        // step; the `unwrap` panics if that invariant is violated.
        let code_obj = vm.ctx.new_code(copr.code_obj.clone().unwrap());
        let ret = vm
            .run_code_obj(code_obj, scope)
            .map_err(|e| format_py_error(e, vm))?;
        // With no input batch, result columns are treated as having length 1.
        let col_len = rb.as_ref().map(|rb| rb.num_rows()).unwrap_or(1);
        let mut cols = try_into_columns(&ret, vm, col_len)?;
        // The script must return exactly as many columns as declared names.
        ensure!(
            cols.len() == copr.deco_args.ret_names.len(),
            OtherSnafu {
                reason: format!(
                    "The number of return Vector is wrong, expect {}, found {}",
                    copr.deco_args.ret_names.len(),
                    cols.len()
                )
            }
        );
        // Coerce columns to the annotated return types, then derive the schema.
        copr.check_and_cast_type(&mut cols)?;
        let schema = copr.gen_schema(&cols)?;
        RecordBatch::new(schema, cols).context(NewRecordBatchSnafu)
    })
}
/// Converts a coprocessor's Python return value into column vectors.
///
/// A Python tuple yields one column per element; any other object is
/// converted into a single column. `col_len` is the expected column length,
/// used when a scalar result must be expanded.
fn try_into_columns(
    obj: &PyObjectRef,
    vm: &VirtualMachine,
    col_len: usize,
) -> Result<Vec<VectorRef>> {
    if is_instance::<PyTuple>(obj, vm) {
        let tuple = obj
            .payload::<PyTuple>()
            // Fixed a stray trailing ')' in the error message.
            .with_context(|| ret_other_error_with(format!("can't cast obj {obj:?} to PyTuple")))?;
        let cols = tuple
            .iter()
            .map(|obj| py_obj_to_vec(obj, vm, col_len))
            .collect::<Result<Vec<VectorRef>>>()?;
        Ok(cols)
    } else {
        let col = py_obj_to_vec(obj, vm, col_len)?;
        Ok(vec![col])
    }
}
/// Returns this thread's cached RustPython interpreter, creating and
/// initializing it on first use (cached in the `INTERPRETER` thread-local).
///
/// Initialization restricts native modules to an allow-list, loads the
/// frozen Python stdlib, registers the greptime-specific classes and
/// builtin modules, and logs the interpreter version.
pub(crate) fn init_interpreter() -> Rc<Interpreter> {
    let _t = metric::METRIC_RSPY_INIT_ELAPSED.start_timer();
    INTERPRETER.with(|i| {
        i.borrow_mut()
            .get_or_insert_with(|| {
                // Only these native (Rust-implemented) stdlib modules are
                // exposed to user scripts.
                let native_module_allow_list = HashSet::from([
                    "array", "cmath", "gc", "hashlib", "_json", "_random", "math",
                ]);
                let mut settings = vm::Settings::default();
                // Don't let the embedded interpreter install a SIGINT handler.
                settings.no_sig_int = true;
                let interpreter = Rc::new(vm::Interpreter::with_init(settings, |vm| {
                    vm.add_native_modules(
                        rustpython_stdlib::get_module_inits()
                            .filter(|(k, _)| native_module_allow_list.contains(k.as_ref())),
                    );
                    // Frozen (pre-compiled) pure-Python standard library.
                    vm.add_frozen(rustpython_pylib::FROZEN_STDLIB);
                    // Register the Python-visible wrapper classes.
                    let _ = PyVector::make_class(&vm.ctx);
                    let _ = PyQueryEngine::make_class(&vm.ctx);
                    let _ = PyRecordBatch::make_class(&vm.ctx);
                    init_greptime_builtins("greptime", vm);
                    init_data_frame("data_frame", vm);
                }));
                // Log the interpreter's Python version as a startup sanity check.
                interpreter
                    .enter(|vm| {
                        let sys = vm.sys_module.clone();
                        let version = sys.get_attr("version", vm)?.str(vm)?;
                        info!("Initialized RustPython interpreter {version}");
                        Ok::<(), PyBaseExceptionRef>(())
                    })
                    .expect("fail to display RustPython interpreter version");
                interpreter
            })
            .clone()
    })
}