"""
Module for grouping and analyzing wallet transactions from ClickHouse
Provides functions to aggregate transaction data by token and calculate summary statistics
"""

from collections import defaultdict
from typing import Dict, List, Any, Optional
from utils.config import *
from utils.logger import *
from datetime import date, datetime
from utils.StM import *
from utils.ClickHouseManager import ClickHouseManager


class TransactionGrouper:
    """Fetch a wallet's transactions from ClickHouse and aggregate them per token.

    Only ``get_wallet_transactions`` talks to ClickHouse (through
    ``self.clickhouse_manager.execute_query``); the grouping and statistics
    methods operate on plain lists of transaction dicts.
    """

    # Columns selected from ``wallet_transactions``, in SELECT order. The same
    # tuple builds both the SELECT list and the row -> dict mapping, so the two
    # cannot drift apart.
    _TRANSACTION_COLUMNS = (
        'wallet_address', 'token_extracted', 'signature', 'type',
        'profit', 'blocktime', 'delta_sol', 'delta_token', 'fee', 'MC',
    )

    # Group key used for transactions whose ``token_extracted`` is missing or None.
    _UNKNOWN_TOKEN = 'UNKNOWN'

    def __init__(self):
        self.clickhouse_manager = ClickHouseManager()

    @staticmethod
    def _to_float(value: Any) -> float:
        """Coerce a numeric transaction field to ``float``, mapping ``None`` to ``0.0``.

        Database drivers may hand back ``None`` for NULL values or ``Decimal``
        for decimal columns; neither can be added to the ``float`` accumulators
        used in this class, so every numeric field goes through this helper.
        """
        return 0.0 if value is None else float(value)

    @staticmethod
    def _empty_token_stats() -> Dict[str, Any]:
        """Return a fresh per-token accumulator for ``group_transactions_by_token``."""
        return {
            'transactions': [],
            'total_profit': 0.0,
            'total_buys': 0,
            'total_sells': 0,
            'total_buy_amount': 0.0,
            'total_sell_amount': 0.0,
            'total_delta_sol': 0.0,
            'total_delta_token': 0.0,
            'total_fees': 0.0,
            'first_transaction_time': None,
            'last_transaction_time': None,
            'unique_signatures': set(),
        }

    def get_wallet_transactions(self, wallet_address: str) -> List[Dict[str, Any]]:
        """
        Query all transactions for a specific wallet from ClickHouse.

        Args:
            wallet_address (str): The wallet address to query transactions for

        Returns:
            List[Dict]: One dict per row, keyed by the selected column names and
            ordered by ``blocktime`` ascending. Returns an empty list if the query
            or row conversion fails; the error is logged via ``db_logger``.
        """
        # The SELECT list comes from a class constant (never user input); the
        # wallet address itself is passed as a bound query parameter.
        select_list = ', '.join(self._TRANSACTION_COLUMNS)
        query = f"""
            SELECT {select_list}
            FROM wallet_transactions
            WHERE wallet_address = %(wallet_address)s
            ORDER BY blocktime ASC
        """
        try:
            rows = self.clickhouse_manager.execute_query(
                query, {'wallet_address': wallet_address}
            )
            # NOTE(review): `or ()` tolerates execute_query returning None —
            # confirm ClickHouseManager's contract.
            return [dict(zip(self._TRANSACTION_COLUMNS, row)) for row in rows or ()]
        except Exception as e:
            db_logger.error(f"Error querying wallet transactions: {e}", exc_info=True)
            return []

    def group_transactions_by_token(self, transactions: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
        """
        Group transactions by token and calculate summary statistics.

        Missing/None numeric fields count as 0; non-float numerics (e.g. Decimal)
        are converted to float. Transactions without a token are grouped under
        ``'UNKNOWN'``.

        Args:
            transactions (List[Dict]): Transaction dicts as produced by
                ``get_wallet_transactions``

        Returns:
            Dict: token -> stats dict containing the raw ``transactions`` list,
            profit/buy/sell/delta/fee totals, first/last ``blocktime`` seen,
            ``unique_transaction_count``, ``total_transactions``,
            ``net_token_change``, ``net_sol_change`` and
            ``avg_profit_per_transaction``.
        """
        grouped: Dict[Any, Dict[str, Any]] = defaultdict(self._empty_token_stats)

        for tx in transactions:
            token = tx.get('token_extracted')
            if token is None:
                token = self._UNKNOWN_TOKEN
            # `or ''` also covers an explicit None (NULL) type, which would
            # otherwise crash on .lower().
            tx_type = (tx.get('type') or '').lower()
            delta_token = self._to_float(tx.get('delta_token'))
            blocktime = tx.get('blocktime')

            stats = grouped[token]
            stats['transactions'].append(tx)
            # Only used to report a distinct-signature count; the totals below
            # still include every row, duplicates included.
            stats['unique_signatures'].add(tx.get('signature', ''))
            stats['total_profit'] += self._to_float(tx.get('profit'))

            if tx_type == 'buy':
                stats['total_buys'] += 1
                stats['total_buy_amount'] += abs(delta_token)
            elif tx_type == 'sell':
                stats['total_sells'] += 1
                stats['total_sell_amount'] += abs(delta_token)

            stats['total_delta_sol'] += self._to_float(tx.get('delta_sol'))
            stats['total_delta_token'] += delta_token
            stats['total_fees'] += self._to_float(tx.get('fee'))

            # Track the time range; rows are not assumed to be sorted.
            if blocktime is not None:
                first = stats['first_transaction_time']
                last = stats['last_transaction_time']
                stats['first_transaction_time'] = blocktime if first is None else min(first, blocktime)
                stats['last_transaction_time'] = blocktime if last is None else max(last, blocktime)

        # Convert to a plain dict and add derived metrics.
        result: Dict[Any, Dict[str, Any]] = {}
        for token, stats in grouped.items():
            # A set is not JSON-serialisable; expose only its size.
            stats['unique_transaction_count'] = len(stats.pop('unique_signatures'))
            count = len(stats['transactions'])
            stats['total_transactions'] = count
            stats['net_token_change'] = stats['total_delta_token']
            stats['net_sol_change'] = stats['total_delta_sol']
            stats['avg_profit_per_transaction'] = stats['total_profit'] / count if count else 0.0
            result[token] = stats

        return result

    def get_wallet_transaction_summary(self, wallet_address: str) -> Dict[str, Any]:
        """
        Get complete transaction summary for a wallet.

        Args:
            wallet_address (str): The wallet address to analyze

        Returns:
            Dict: ``wallet_address``, ``total_transactions``, ``tokens`` (output of
            ``group_transactions_by_token``) and ``overall_stats``. On an
            unexpected error an ``error`` key is added, ``tokens`` is empty and
            ``overall_stats`` is ``{}``.
        """
        try:
            transactions = self.get_wallet_transactions(wallet_address)

            if not transactions:
                return {
                    'wallet_address': wallet_address,
                    'total_transactions': 0,
                    'tokens': {},
                    'overall_stats': {
                        'total_profit': 0.0,
                        'total_buys': 0,
                        'total_sells': 0,
                        'total_fees': 0.0,
                        'unique_tokens': 0,
                        # Same key as the non-empty path so callers see one shape.
                        'total_transactions': 0,
                    },
                }

            grouped_transactions = self.group_transactions_by_token(transactions)
            per_token = grouped_transactions.values()

            overall_stats = {
                'total_profit': sum(stats['total_profit'] for stats in per_token),
                'total_buys': sum(stats['total_buys'] for stats in per_token),
                'total_sells': sum(stats['total_sells'] for stats in per_token),
                'total_fees': sum(stats['total_fees'] for stats in per_token),
                'unique_tokens': len(grouped_transactions),
                'total_transactions': len(transactions),
            }

            return {
                'wallet_address': wallet_address,
                'total_transactions': len(transactions),
                'tokens': grouped_transactions,
                'overall_stats': overall_stats,
            }

        except Exception as e:
            db_logger.error(f"Error generating wallet transaction summary: {e}", exc_info=True)
            return {
                'wallet_address': wallet_address,
                'error': str(e),
                'total_transactions': 0,
                'tokens': {},
                'overall_stats': {},
            }

    def get_wallet_token_performance(self, wallet_address: str, token_address: Optional[str] = None) -> Dict[str, Any]:
        """
        Get detailed performance metrics for a wallet's token trades.

        Args:
            wallet_address (str): The wallet address to analyze
            token_address (str, optional): Specific token to analyze; if falsy,
                all tokens are analyzed

        Returns:
            Dict: ``{'wallet_address', 'token_performance': {token: metrics}}``,
            or ``{'error': ...}`` when no transactions match or an error occurs.
            ``win_rate_percent`` is the share of sell transactions with
            positive profit.
        """
        try:
            transactions = self.get_wallet_transactions(wallet_address)

            if token_address:
                transactions = [tx for tx in transactions if tx.get('token_extracted') == token_address]

            if not transactions:
                return {'error': 'No transactions found'}

            grouped = self.group_transactions_by_token(transactions)

            performance_data = {}
            for token, data in grouped.items():
                profit = data['total_profit']
                sells = data['total_sells']

                # Simplified win rate: a sell with positive profit is a win.
                win_rate = 0.0
                if sells > 0:
                    winning_trades = sum(
                        1 for tx in data['transactions']
                        if (tx.get('type') or '').lower() == 'sell'
                        and self._to_float(tx.get('profit')) > 0
                    )
                    win_rate = (winning_trades / sells) * 100

                performance_data[token] = {
                    'total_profit': profit,
                    'total_buys': data['total_buys'],
                    'total_sells': sells,
                    'win_rate_percent': win_rate,
                    'avg_profit_per_trade': data['avg_profit_per_transaction'],
                    'total_fees': data['total_fees'],
                    'net_profit_after_fees': profit - data['total_fees'],
                    'transaction_count': data['total_transactions'],
                }

            return {
                'wallet_address': wallet_address,
                'token_performance': performance_data,
            }

        except Exception as e:
            db_logger.error(f"Error calculating wallet token performance: {e}", exc_info=True)
            return {'error': str(e)}


# Shared module-level instance used by the convenience wrappers below (and by
# web.py). Constructing it runs TransactionGrouper.__init__, which creates a
# ClickHouseManager at import time.
transaction_grouper: TransactionGrouper = TransactionGrouper()


def get_wallet_transactions_grouped(wallet_address: str) -> Dict[str, Any]:
    """
    Return the per-token transaction summary for ``wallet_address``.

    Thin wrapper (intended for web.py) that delegates to the shared
    module-level ``transaction_grouper``.

    Args:
        wallet_address (str): Wallet whose transactions should be summarised

    Returns:
        Dict: Result of ``get_wallet_transaction_summary`` — ``wallet_address``,
        ``total_transactions``, ``tokens`` and ``overall_stats`` (plus
        ``error`` when summarising failed)
    """
    summarise = transaction_grouper.get_wallet_transaction_summary
    return summarise(wallet_address)


def get_wallet_performance_metrics(wallet_address: str, token_address: Optional[str] = None) -> Dict[str, Any]:
    """
    Return per-token trading performance metrics for ``wallet_address``.

    Thin wrapper (intended for web.py) that delegates to the shared
    module-level ``transaction_grouper``.

    Args:
        wallet_address (str): Wallet to analyze
        token_address (str, optional): Restrict the analysis to this token;
            when omitted, every token the wallet traded is analyzed

    Returns:
        Dict: Result of ``get_wallet_token_performance`` — either
        ``wallet_address`` plus ``token_performance``, or an ``error`` entry
    """
    grouper = transaction_grouper
    return grouper.get_wallet_token_performance(wallet_address, token_address=token_address)