def _tou_hours(hours_fn, start, end, tou):
    """
    Return the number of ``tou`` hours in the inclusive date range [start, end].

    An empty range (start after end) has 0 hours and ``hours_fn`` is not called.
    Raises IndexError if ``hours_fn`` returns no row whose ``peak_type`` is ``tou``.
    """
    if start > end:
        return 0
    hrs_df = hours_fn(start, end)
    return hrs_df.loc[hrs_df['peak_type'] == tou, 'n_hrs'].iloc[0]


def _with_period(template, start, end, price):
    """Return a copy of strip ``template`` covering [start, end] at ``price``."""
    strip = template.copy()
    strip['StartDate'] = start
    strip['EndDate'] = end
    strip['ShadowPricePerMWH'] = price
    return strip


def _residual_price(price, hrs, other_price, hrs_overlap, hrs_rest):
    """
    Price for the part of a strip outside an overlap, chosen so that
    ``price * hrs == other_price * hrs_overlap + result * hrs_rest``
    (the strip's hour-weighted value is preserved).
    """
    return (price * hrs - other_price * hrs_overlap) / hrs_rest


def _split_staggered(strip1, strip2, hours_fn):
    """
    Split an overlapping pair where ``strip1`` starts strictly before ``strip2``.

    Adds a "before" strip [strip1.StartDate, strip2.StartDate - 1 day] priced so
    that strip1's value up to the end of the overlap is preserved (skipped when
    that period has no hours), and, when strip1 ends after strip2, an "after"
    strip [strip2.EndDate + 1 day, strip1.EndDate] at strip1's own price.
    """
    from datetime import timedelta

    one_day = timedelta(days=1)
    tou = strip1['TimeOfUse']
    p1 = strip1['ShadowPricePerMWH']
    p2 = strip2['ShadowPricePerMWH']
    before_end = strip2['StartDate'] - one_day
    overlap_end = min(strip1['EndDate'], strip2['EndDate'])

    hrs_before = _tou_hours(hours_fn, strip1['StartDate'], before_end, tou)
    # Only strip1's hours up to the end of the overlap are balanced here: any
    # later "after" period keeps strip1's own price, so counting it again would
    # double-count its value.
    hrs_head = _tou_hours(hours_fn, strip1['StartDate'], overlap_end, tou)
    hrs_overlap = _tou_hours(hours_fn, strip2['StartDate'], overlap_end, tou)

    pieces = []
    if hrs_before:  # a zero-hour period has no derivable price
        p_before = _residual_price(p1, hrs_head, p2, hrs_overlap, hrs_before)
        pieces.append(_with_period(strip1, strip1['StartDate'], before_end, p_before))
    if strip1['EndDate'] > strip2['EndDate']:
        pieces.append(
            _with_period(strip1, strip2['EndDate'] + one_day, strip1['EndDate'], p1)
        )
    return pieces


def _split_same_start(strip1, strip2, hours_fn):
    """
    Split an overlapping pair whose strips start on the same day.

    The shorter strip covers the overlap at its own price. The longer strip's
    tail (the days after the shorter strip ends) is added with the price that
    preserves the longer strip's value, because there is no "before" period
    to absorb the difference. Identical periods, or a tail with no hours,
    produce no new strip.
    """
    from datetime import timedelta

    if strip1['EndDate'] == strip2['EndDate']:
        return []
    if strip1['EndDate'] > strip2['EndDate']:
        longer, shorter = strip1, strip2
    else:
        longer, shorter = strip2, strip1
    tou = longer['TimeOfUse']
    tail_start = shorter['EndDate'] + timedelta(days=1)

    hrs_longer = _tou_hours(hours_fn, longer['StartDate'], longer['EndDate'], tou)
    hrs_overlap = _tou_hours(hours_fn, shorter['StartDate'], shorter['EndDate'], tou)
    hrs_tail = _tou_hours(hours_fn, tail_start, longer['EndDate'], tou)
    if not hrs_tail:
        return []
    p_tail = _residual_price(
        longer['ShadowPricePerMWH'], hrs_longer,
        shorter['ShadowPricePerMWH'], hrs_overlap,
        hrs_tail,
    )
    return [_with_period(longer, tail_start, longer['EndDate'], p_tail)]


def break_down_strips(df_opt_marks, peak_hours_fn=None):
    """
    Break down overlapping strips into granular strips.

    Rows are grouped by (Source, Sink, TimeOfUse) and ordered by StartDate.
    Each pair of consecutive strips in a group whose inclusive date ranges
    overlap is split into extra pieces. Those pieces are appended to the
    original rows, which are kept unchanged.

    Prices are derived so that the hour-weighted value of the strip being
    split is preserved: the overlap keeps the other strip's price and the
    remaining period absorbs the difference, i.e.
    ``p_rest = (p_parent * hrs_parent - p_other * hrs_overlap) / hrs_rest``.

    * strip1 starts before strip2: a "before" strip
      [strip1.StartDate, strip2.StartDate - 1 day] gets the residual price.
      If strip1 also ends after strip2, an "after" strip
      [strip2.EndDate + 1 day, strip1.EndDate] is added at strip1's price.
    * both start on the same day: the longer strip's tail gets the residual
      price. Identical periods add nothing.

    Residual-priced pieces whose period contains zero TimeOfUse hours are
    skipped, since their price would divide by zero.

    Parameters
    ----------
    df_opt_marks : pandas.DataFrame
        Columns Source, Sink, TimeOfUse, StartDate, EndDate and
        ShadowPricePerMWH. Dates must support ``+/- datetime.timedelta``.
    peak_hours_fn : callable, optional
        ``peak_hours_fn(start, end)`` returning a DataFrame with ``peak_type``
        and ``n_hrs`` columns for the inclusive range. Defaults to the
        module-level ``get_peak_hours``.

    Returns
    -------
    pandas.DataFrame
        The original rows plus the derived strips, sorted by
        Source, Sink, TimeOfUse, StartDate.
    """
    hours_fn = get_peak_hours if peak_hours_fn is None else peak_hours_fn
    sort_cols = ['Source', 'Sink', 'TimeOfUse', 'StartDate']

    df_all = df_opt_marks.copy()
    df_sorted = df_all.sort_values(sort_cols)
    new_strips = []

    for _, group in df_sorted.groupby(['Source', 'Sink', 'TimeOfUse']):
        strips = [row for _, row in group.iterrows()]
        # Only consecutive pairs (by StartDate) are compared.
        for strip1, strip2 in zip(strips, strips[1:]):
            if strip1['EndDate'] < strip2['StartDate']:
                continue  # no overlap
            if strip1['StartDate'] == strip2['StartDate']:
                new_strips.extend(_split_same_start(strip1, strip2, hours_fn))
            else:
                new_strips.extend(_split_staggered(strip1, strip2, hours_fn))

    frames = [df_all]
    if new_strips:
        frames.append(pd.DataFrame(new_strips))
    df_opt_marks_all = pd.concat(frames, ignore_index=True)

    return df_opt_marks_all.sort_values(sort_cols)

    
import pandas as pd
import numpy as np

def add_factor_spreads(df, df_imp, n=3):
    """
    Add the ``n`` largest and ``n`` smallest factor spreads between source and sink nodes.

    For each row of ``df`` the spread of every factor is
    ``df_imp[source] - df_imp[sink]``. Factors are ranked by absolute spread,
    and ties keep the row order of ``df_imp``. The top ``n`` are reported in
    descending order of ``|spread|``. The bottom ``n`` are reported in
    ascending order, so ``bottom1`` is the factor with the smallest ``|spread|``.

    Parameters:
    df (pandas.DataFrame): DataFrame with 'source' and 'sink' columns
    df_imp (pandas.DataFrame): DataFrame with factors as rows and nodes as columns.
        The 'factor' column holds the factor names; every other column is a node.
        If a factor name repeats, its first row is used.
    n (int): Number of top and of bottom factors to report (default 3).

    Returns:
    pandas.DataFrame: A copy of ``df`` with ``4 * n`` new columns
    ``top{k}_factor``, ``top{k}_spread``, ``bottom{k}_factor``, ``bottom{k}_spread``
    for k = 1..n (12 columns for the default n=3). Slots beyond the number of
    factors hold None / NaN. With fewer than ``2 * n`` factors the top and
    bottom lists overlap.

    Raises:
    KeyError: if a source or sink node is not a column of ``df_imp``.
    """
    # Create a copy of input DataFrame
    result_df = df.copy()

    # Factor names in df_imp row order (duplicates kept).
    factors = list(df_imp['factor'])
    # Node loadings indexed by factor name, one row per entry of `factors`.
    # This is built once instead of re-filtering df_imp for every
    # (row, factor) pair. A repeated factor name resolves to its first row.
    loadings = df_imp.drop_duplicates('factor').set_index('factor').loc[factors]

    columns = [
        f'{side}{rank}_{field}'
        for side in ('top', 'bottom')
        for rank in range(1, n + 1)
        for field in ('factor', 'spread')
    ]

    def get_factor_spreads(source, sink):
        # Spread of every factor for this pair, in df_imp order.
        diffs = loadings[source].to_numpy() - loadings[sink].to_numpy()
        # Largest |spread| first; sorted() is stable, so ties keep df_imp order.
        ranked = sorted(zip(factors, diffs), key=lambda item: abs(item[1]), reverse=True)
        # bottom1 is the smallest |spread|, bottom2 the next smallest, ...
        picks = {'top': ranked[:n], 'bottom': ranked[::-1][:n]}
        record = {}
        for side, chosen in picks.items():
            for rank in range(1, n + 1):
                # Pad with None/NaN when there are fewer than n factors.
                factor, spread = chosen[rank - 1] if rank <= len(chosen) else (None, np.nan)
                record[f'{side}{rank}_factor'] = factor
                record[f'{side}{rank}_spread'] = spread
        return record

    # Building from explicit records (rather than df.apply) keeps the column
    # set stable even when df is empty.
    records = [get_factor_spreads(src, snk) for src, snk in zip(df['source'], df['sink'])]
    spreads_df = pd.DataFrame(records, index=df.index, columns=columns)

    return pd.concat([result_df, spreads_df], axis=1)
# Example usage: two source/sink pairs scored against four factors.
df = pd.DataFrame(
    [
        ('NodeA', 'NodeC'),
        ('NodeB', 'NodeD'),
    ],
    columns=['source', 'sink'],
)

# Factor loadings, one row per factor and one column per node.
df_imp = pd.DataFrame(
    [
        ('factor1', 1.0, 2.0, 3.0, 4.0),
        ('factor2', 2.0, 3.0, 4.0, 5.0),
        ('factor3', 3.0, 4.0, 5.0, 6.0),
        ('factor4', 4.0, 5.0, 6.0, 7.0),
    ],
    columns=['factor', 'NodeA', 'NodeB', 'NodeC', 'NodeD'],
)

result = add_factor_spreads(df, df_imp)
    
    
from dash import Dash, dcc, html, Input, Output, State
import dash_bootstrap_components as dbc

app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Sample data: labels "Option 1".."Option 20" with values "opt_1".."opt_20".
options = [dict(label=f'Option {n}', value=f'opt_{n}') for n in range(1, 21)]

# Page: a heading, then a single row/column holding the searchable
# multi-select, the "select all filtered" button and the selection read-out.
app.layout = html.Div(
    children=[
        dbc.Container(
            children=[
                html.H3(children="Filtered Multi-Select Dropdown"),
                dbc.Row(
                    children=[
                        dbc.Col(
                            children=[
                                dcc.Dropdown(
                                    id='multi-dropdown',
                                    options=options,
                                    value=[],
                                    multi=True,
                                    searchable=True,
                                    placeholder="Search and select options...",
                                ),
                                dbc.Button(
                                    "Select All Filtered",
                                    id="select-all-button",
                                    className="mt-2",
                                    color="primary",
                                ),
                                html.Div(id='output-container', className="mt-3"),
                            ]
                        )
                    ]
                ),
            ]
        )
    ]
)

@app.callback(
    Output('multi-dropdown', 'value'),
    Input('select-all-button', 'n_clicks'),
    State('multi-dropdown', 'options'),
    State('multi-dropdown', 'search_value'),
    State('multi-dropdown', 'value'),
    prevent_initial_call=True
)
def select_all_filtered(n_clicks, options, search_value, current_values):
    """
    Select every option that matches the dropdown's current search text.

    Args:
        n_clicks: Click count of the "Select All Filtered" button.
        options: The dropdown's options, ``{'label': ..., 'value': ...}`` dicts.
        search_value: Text currently typed into the dropdown's search box.
        current_values: Values selected before the click; ``None`` is treated
            as an empty selection.

    Returns:
        The new selection. With a search value, this is the previously
        selected values that do not match the filter, followed by every option
        whose label contains the search text (case-insensitive). Without one,
        it is every option.

    NOTE(review): the dropdown may clear its search text when it loses focus
    (e.g. when the button is clicked). In that case every option would be
    selected. Confirm this against the Dash version in use.
    """
    if not n_clicks:
        return current_values

    if not search_value:
        # No active filter: select all options.
        return [opt['value'] for opt in options]

    needle = search_value.lower()
    filtered_options = [
        opt['value'] for opt in options
        if needle in opt['label'].lower()
    ]
    # Keep selections outside the filter; a set makes each membership test O(1).
    filtered_set = set(filtered_options)
    preserved_values = [
        val for val in (current_values or [])
        if val not in filtered_set
    ]
    return preserved_values + filtered_options

@app.callback(
    Output('output-container', 'children'),
    Input('multi-dropdown', 'value')
)
def display_selected(selected_values):
    """
    Render the dropdown's current selection as a one-line summary.

    Returns "No options selected" for an empty (or None) selection. Otherwise
    returns "Selected: " followed by the values joined with ", ".
    """
    if not selected_values:
        return "No options selected"
    # str() each value so non-string option values would not make join() raise.
    return f"Selected: {', '.join(map(str, selected_values))}"

if __name__ == '__main__':
    # `run_server` was a deprecated alias of `run` and was removed in Dash 3.
    app.run(debug=True)