use itertools::Itertools;
use reth_primitives::Address;
use tracing::{error, warn};
use super::{types::NodeWithDataRef, NodeData};
use crate::{
    normalized_actions::{MultiCallFrameClassification, NodeDataIndex, NormalizedAction},
    TreeSearchArgs, TreeSearchBuilder,
};
/// A single call frame in the trace tree.
///
/// Each node owns its children directly and refers to its classified actions
/// indirectly, through an index into the shared `NodeData` store.
#[derive(Debug, Clone)]
pub struct Node {
    /// Child nodes, kept in the order they were inserted.
    pub inner:         Vec<Node>,
    /// Whether `subactions` currently holds a valid cache (set by `finalize`).
    pub finalized:     bool,
    /// Index of this node's trace. Methods on `Node` compare it against a target
    /// `trace_index` to locate nodes.
    /// NOTE(review): the lookups assume sibling indices ascend left to right — confirm.
    pub index:         u64,
    /// Cached `NodeData` indices of this node and all of its descendants.
    /// Only meaningful while `finalized` is true.
    pub subactions:    Vec<usize>,
    /// Child positions leading from the root to this node; `insert` follows this path.
    pub trace_address: Vec<usize>,
    /// Address recorded for this call, reported by the address-collecting helpers.
    /// Presumably the called contract — TODO confirm.
    pub address:       Address,
    /// Index of this node's action list in the `NodeData` store.
    pub data:          usize,
}
impl Node {
    pub fn new(index: u64, address: Address, trace_address: Vec<usize>) -> Self {
        Self {
            index,
            trace_address,
            address,
            finalized: false,
            data: 0,
            inner: vec![],
            subactions: vec![],
        }
    }
    pub fn is_finalized(&self) -> bool {
        self.finalized
    }
    pub fn get_all_children_for_complex_classification<V: NormalizedAction>(
        &mut self,
        head: &MultiCallFrameClassification<V>,
        nodes: &mut NodeData<V>,
    ) {
        if head.trace_index == self.index {
            let mut results = Vec::new();
            self.collect(
                &mut results,
                head.collect_args(),
                &|data| {
                    (
                        NodeDataIndex {
                            trace_index:    data.node.index,
                            data_idx:       data.node.data as u64,
                            multi_data_idx: data.idx,
                        },
                        data.data.clone(),
                    )
                },
                nodes,
            );
            let this = nodes.get_mut(self.data).unwrap().first_mut().unwrap();
            let clear_collapsed_nodes = head.parse(this, results);
            clear_collapsed_nodes
                .into_iter()
                .sorted_unstable_by(|a, b| b.multi_data_idx.cmp(&a.multi_data_idx))
                .for_each(|index| {
                    self.clear_node_data(index, nodes);
                });
            return
        }
        if self.inner.len() <= 1 {
            if let Some(inner) = self.inner.first_mut() {
                return inner.get_all_children_for_complex_classification(head, nodes)
            }
            warn!("was not able to find node in tree for complex classification");
            return
        }
        let mut iter = self.inner.iter_mut();
        let mut cur_inner_node = iter.next().unwrap();
        let mut next_inner_node = iter.next().unwrap();
        for next_node in iter {
            if cur_inner_node.index == head.trace_index {
                return cur_inner_node.get_all_children_for_complex_classification(head, nodes)
            } else if next_inner_node.index == head.trace_index {
                return next_inner_node.get_all_children_for_complex_classification(head, nodes)
            }
            if next_inner_node.index <= head.trace_index {
                cur_inner_node = next_inner_node;
                next_inner_node = next_node;
            } else {
                return cur_inner_node.get_all_children_for_complex_classification(head, nodes)
            }
        }
        if cur_inner_node.index == head.trace_index {
            return cur_inner_node.get_all_children_for_complex_classification(head, nodes)
        } else if next_inner_node.index == head.trace_index {
            return next_inner_node.get_all_children_for_complex_classification(head, nodes)
        } else if next_inner_node.index > head.trace_index {
            return cur_inner_node.get_all_children_for_complex_classification(head, nodes)
        }
        else if let Some(last) = self.inner.last_mut() {
            return last.get_all_children_for_complex_classification(head, nodes)
        }
        warn!("was not able to find node in tree, should be unreachable");
    }
    pub fn modify_node_if_contains_childs<F, V: NormalizedAction>(
        &mut self,
        find: &TreeSearchBuilder<V>,
        modify: &F,
        data: &mut NodeData<V>,
    ) -> bool
    where
        F: Fn(&mut Node, &mut NodeData<V>),
    {
        let TreeSearchArgs { collect_current_node, child_node_to_collect, .. } =
            find.generate_search_args(self, &*data);
        if !child_node_to_collect {
            return false
        }
        let lower_classification_results = self
            .inner
            .iter_mut()
            .map(|node| node.modify_node_if_contains_childs(find, modify, data))
            .collect::<Vec<_>>();
        if !lower_classification_results.into_iter().any(|n| n) {
            if collect_current_node {
                modify(self, data);
                return true
            } else {
                return false
            }
        }
        false
    }
    pub fn modify_node_spans<F, V: NormalizedAction>(
        &mut self,
        find: &TreeSearchBuilder<V>,
        modify: &F,
        data: &mut NodeData<V>,
    ) -> bool
    where
        F: Fn(Vec<&mut Self>, &mut NodeData<V>),
    {
        if !find
            .generate_search_args(self, &*data)
            .child_node_to_collect
        {
            return false
        }
        let lower_has_better_collect = self
            .inner
            .iter_mut()
            .map(|n| n.modify_node_spans(find, modify, data))
            .collect::<Vec<_>>();
        let all_lower_better = lower_has_better_collect.into_iter().all(|t| t);
        if !all_lower_better {
            let mut nodes = vec![unsafe { &mut *(self as *mut Self) }];
            for i in &mut self.inner {
                nodes.push(i)
            }
            modify(nodes, data);
        }
        true
    }
    pub fn finalize(&mut self) {
        self.finalized = false;
        self.subactions = self.get_all_sub_actions();
        self.finalized = true;
        self.inner.iter_mut().for_each(|f| f.finalize());
    }
    pub fn insert<V: NormalizedAction>(
        &mut self,
        n: Node,
        data: Vec<V>,
        data_store: &mut NodeData<V>,
    ) {
        let trace_addr = n.trace_address.clone();
        self.get_all_inner_nodes(n, data, data_store, trace_addr);
    }
    pub fn get_all_inner_nodes<V: NormalizedAction>(
        &mut self,
        mut n: Node,
        data: Vec<V>,
        data_store: &mut NodeData<V>,
        mut trace_addr: Vec<usize>,
    ) {
        let revert = data_store
            .get_ref(self.data)
            .unwrap()
            .iter()
            .any(|n| n.get_action().is_revert());
        if revert {
            return
        }
        let log = trace_addr.clone();
        if trace_addr.len() == 1 {
            let idx = data_store.add(data);
            n.data = idx;
            self.inner.push(n);
        } else if let Some(inner) = self.inner.get_mut(trace_addr.remove(0)) {
            inner.get_all_inner_nodes(n, data, data_store, trace_addr)
        } else {
            error!("ERROR: {:?}\n {:?}", self.inner, log);
        }
    }
    pub fn get_all_sub_actions(&self) -> Vec<usize> {
        if self.finalized {
            self.subactions.clone()
        } else {
            let mut res = vec![self.data];
            res.extend(
                self.inner
                    .iter()
                    .flat_map(|inner| inner.get_all_sub_actions())
                    .collect::<Vec<_>>(),
            );
            res
        }
    }
    pub fn get_all_sub_actions_exclusive(&self) -> Vec<usize> {
        self.inner
            .iter()
            .flat_map(|inner| inner.get_all_sub_actions())
            .collect::<Vec<_>>()
    }
    pub fn get_last_create_call<V: NormalizedAction>(
        &self,
        start_index: &mut u64,
        data_store: &NodeData<V>,
    ) {
        if let Some(this_data) = data_store.get_ref(self.data) {
            for data in this_data {
                if data.is_create() && self.index > *start_index {
                    *start_index = self.index;
                }
            }
        }
        for i in &self.inner {
            i.get_last_create_call(start_index, data_store);
        }
    }
    pub fn get_all_parent_nodes_for_discovery(
        &self,
        res: &mut Vec<Node>,
        start_index: u64,
        trace_index: u64,
    ) {
        if self.index >= start_index && self.index < trace_index {
            res.push(self.clone());
            for i in &self.inner {
                i.get_all_parent_nodes_for_discovery(res, start_index, trace_index);
            }
        } else if self.index <= start_index && self.index < trace_index {
            for i in &self.inner {
                i.get_all_parent_nodes_for_discovery(res, start_index, trace_index);
            }
        }
    }
    pub fn get_immediate_parent_node(&self, tx_index: u64) -> Option<&Node> {
        if self.inner.last()?.index == tx_index {
            Some(self)
        } else {
            self.inner.last()?.get_immediate_parent_node(tx_index)
        }
    }
    pub fn tree_right_path(&self) -> Vec<Address> {
        self.inner
            .last()
            .map(|last| {
                let mut last = last.tree_right_path();
                last.push(self.address);
                last
            })
            .unwrap_or(vec![self.address])
    }
    pub fn all_sub_addresses(&self) -> Vec<Address> {
        self.inner
            .iter()
            .flat_map(|i| i.all_sub_addresses())
            .chain(vec![self.address])
            .collect()
    }
    pub fn current_call_stack(&self) -> Vec<Address> {
        let Some(mut stack) = self.inner.last().map(|n| n.current_call_stack()) else {
            return vec![self.address];
        };
        stack.push(self.address);
        stack
    }
    pub fn get_bounded_info<F, R>(&self, lower: u64, upper: u64, res: &mut Vec<R>, info_fn: &F)
    where
        F: Fn(&Node) -> R,
    {
        if self.index >= lower && self.index <= upper {
            res.push(info_fn(self));
        } else {
            return
        }
        self.inner
            .iter()
            .for_each(|node| node.get_bounded_info(lower, upper, res, info_fn));
    }
    pub fn clear_node_data<V: NormalizedAction>(
        &mut self,
        index: NodeDataIndex,
        data: &mut NodeData<V>,
    ) {
        if index.trace_index == self.index {
            data.get_mut(index.data_idx as usize)
                .unwrap()
                .remove(index.multi_data_idx);
            return
        }
        if self.inner.len() <= 1 {
            if let Some(inner) = self.inner.first_mut() {
                return inner.clear_node_data(index, data)
            }
            warn!("was not able to find node in tree for clearing node data");
            return
        }
        let mut iter = self.inner.iter_mut();
        let mut cur_inner_node = iter.next().unwrap();
        let mut next_inner_node = iter.next().unwrap();
        for next_node in iter {
            if cur_inner_node.index == index.trace_index {
                return cur_inner_node.clear_node_data(index, data)
            } else if next_inner_node.index == index.trace_index {
                return next_inner_node.clear_node_data(index, data)
            }
            if next_inner_node.index <= index.trace_index {
                cur_inner_node = next_inner_node;
                next_inner_node = next_node;
            } else {
                return cur_inner_node.clear_node_data(index, data)
            }
        }
        if cur_inner_node.index == index.trace_index {
            return cur_inner_node.clear_node_data(index, data)
        } else if next_inner_node.index == index.trace_index {
            return next_inner_node.clear_node_data(index, data)
        } else if next_inner_node.index > index.trace_index {
            return cur_inner_node.clear_node_data(index, data)
        } else if let Some(last) = self.inner.last_mut() {
            return last.clear_node_data(index, data)
        }
        warn!("was not able to find node in tree, should be unreachable");
    }
    pub fn remove_node_and_children<V: NormalizedAction>(
        &mut self,
        index: u64,
        data: &mut NodeData<V>,
    ) {
        if index == self.index {
            data.remove(self.data);
            self.get_all_sub_actions().into_iter().for_each(|f| {
                data.remove(f);
            });
            return
        }
        if self.inner.len() <= 1 {
            if let Some(inner) = self.inner.first_mut() {
                return inner.remove_node_and_children(index, data)
            }
            warn!("was not able to find node in tree for removing node data");
            return
        }
        let mut iter = self.inner.iter_mut();
        let mut cur_inner_node = iter.next().unwrap();
        let mut next_inner_node = iter.next().unwrap();
        for next_node in iter {
            if cur_inner_node.index == index {
                return cur_inner_node.remove_node_and_children(index, data)
            } else if next_inner_node.index == index {
                return next_inner_node.remove_node_and_children(index, data)
            }
            if next_inner_node.index <= index {
                cur_inner_node = next_inner_node;
                next_inner_node = next_node;
            } else {
                return cur_inner_node.remove_node_and_children(index, data)
            }
        }
        if cur_inner_node.index == index {
            return cur_inner_node.remove_node_and_children(index, data)
        } else if next_inner_node.index == index {
            return next_inner_node.remove_node_and_children(index, data)
        } else if next_inner_node.index > index {
            return cur_inner_node.remove_node_and_children(index, data)
        } else if let Some(last) = self.inner.last_mut() {
            return last.remove_node_and_children(index, data)
        }
        warn!("was not able to find node in tree, should be unreachable");
    }
    pub fn collect_spans<V: NormalizedAction>(
        &self,
        result: &mut Vec<Vec<V>>,
        call: &TreeSearchBuilder<V>,
        data: &NodeData<V>,
    ) -> bool {
        if !call.generate_search_args(self, data).child_node_to_collect {
            return false
        }
        let lower_has_better_collect = self
            .inner
            .iter()
            .map(|i| i.collect_spans(result, call, data))
            .collect::<Vec<bool>>();
        let lower_has_better = lower_has_better_collect.into_iter().all(|f| f);
        if !lower_has_better {
            let res = self
                .get_all_sub_actions()
                .into_iter()
                .filter_map(|node| data.get_ref(node).cloned())
                .flatten()
                .collect::<Vec<_>>();
            result.push(res);
        }
        true
    }
    pub fn collect<T, R, V: NormalizedAction>(
        &self,
        results: &mut Vec<R>,
        call: &TreeSearchBuilder<V>,
        wanted_data: &T,
        data: &NodeData<V>,
    ) where
        T: Fn(NodeWithDataRef<'_, V>) -> R,
    {
        let TreeSearchArgs { collect_current_node, child_node_to_collect, collect_idxs } =
            call.generate_search_args(self, data);
        if collect_current_node {
            if let Some(datas) = data.get_ref(self.data) {
                for idx in collect_idxs {
                    results.push(wanted_data(NodeWithDataRef::new(self, &datas[idx], idx)))
                }
            }
        }
        if child_node_to_collect {
            self.inner
                .iter()
                .for_each(|i| i.collect(results, call, wanted_data, data))
        }
    }
}