"""Plotly figure builders and Gradio update helpers for pitch data.

Contains the location (KDE contour), velocity (violin) and usage (pie)
plots plus the helpers that filter the data frames and build the
``gr.update`` payloads for the per-pitch cards.
"""

import os
from functools import wraps
from math import ceil

import gradio as gr
import numpy as np
import plotly.colors as pc
import plotly.express as px
import plotly.graph_objects as go
import polars as pl
from scipy.stats import gaussian_kde

from translate import max_pitch_types, jp_pitch_to_en_pitch
from data import (
    df,
    pitch_stats,
    rhb_pitch_stats,
    lhb_pitch_stats,
    league_pitch_stats,
    rhb_league_pitch_stats,
    lhb_league_pitch_stats,
)

# Layout of the pitch-card grid: one card per known pitch type.
MAX_LOCS = len(jp_pitch_to_en_pitch)
LOCS_PER_ROW = 4
MAX_ROWS = ceil(MAX_LOCS / LOCS_PER_ROW)

INSUFFICIENT_PITCHES_MSG = 'No visualization: Not enough pitches thrown'
# Plotly renders `<br>` (not a newline character) as a line break in text.
INSUFFICIENT_PITCHES_MSG_MULTI_LINE = 'No visualization:<br>Not enough pitches thrown'


# GRADIO FUNCTIONS

def clone_if_dataframe(item):
    """Return a clone of `item` if it is a polars DataFrame, else `item` itself."""
    return item.clone() if isinstance(item, pl.DataFrame) else item


def clone_df(fn):
    """Decorate `fn` so every polars DataFrame argument is cloned before the call.

    Positional and keyword arguments are both handled; non-DataFrame
    arguments are passed through untouched.
    """
    @wraps(fn)  # keep the wrapped function's name/signature for introspection
    def _fn(*args, **kwargs):
        args = [clone_if_dataframe(arg) for arg in args]
        kwargs = {k: clone_if_dataframe(arg) for k, arg in kwargs.items()}
        return fn(*args, **kwargs)
    return _fn


# location maps

def fit_pred_kde(data, X, Y):
    """Fit a Gaussian KDE to `data` and evaluate it on the mesh grid (X, Y).

    Args:
        data: array passed straight to ``scipy.stats.gaussian_kde``; callers
            in this file pass a (2, n_points) array.
        X, Y: mesh-grid coordinate arrays of identical shape.

    Returns:
        Array of density values with the same shape as `X`.
    """
    kde = gaussian_kde(data)
    grid_points = np.stack((X, Y)).reshape(2, -1)
    return kde(grid_points).reshape(*X.shape)


# Side length of the square plotting area, in the same units as plate_x/plate_z.
plot_s = 256
# Outer rectangle drawn on the location plot (presumably the strike zone).
sz_h = 200
sz_w = 160
# Inner rectangle drawn on the location plot.
h_h = 200 - 40*2
h_w = 160 - 32*2

# Grid on which the KDE is evaluated, one sample per unit.
kde_range = np.arange(-plot_s/2, plot_s/2, 1)
X, Y = np.meshgrid(
    kde_range,
    kde_range
)


def coordinatify(h, w):
    """Return shape corner kwargs for an h-by-w rectangle centred on the origin."""
    return dict(
        x0=-w/2,
        y0=-h/2,
        x1=w/2,
        y1=h/2
    )


# OrRd colour scale with a fully transparent stop prepended at 0 so that
# near-zero density is not painted.
colorscale = pc.sequential.OrRd
colorscale = [
    [0, 'rgba(0, 0, 0, 0)'],
] + [
    [i / len(colorscale), color]
    for i, color in enumerate(colorscale, start=1)
]


@clone_df
def plot_loc(df, league_df=None, min_pitches=3, max_pitches=5000):
    """Build a pitch-location density figure.

    Args:
        df: polars DataFrame with `plate_x` and `plate_z` columns.
        league_df: optional DataFrame with the same columns; when given, a
            dashed outline at the 0.9 quantile of its density is overlaid
            and labelled 'NPB'.
        min_pitches: minimum number of rows needed to fit a KDE.
        max_pitches: `league_df` is randomly down-sampled (fixed seed) to at
            most this many rows before fitting.

    Returns:
        A ``plotly.graph_objects.Figure``. If `df` has fewer than
        `min_pitches` usable rows, an explanatory annotation is shown
        instead of the contour.
    """
    # Rows with a missing coordinate cannot contribute to the KDE.
    loc = df.select(['plate_x', 'plate_z']).drop_nulls()

    fig = go.Figure()
    if len(loc) >= min_pitches:
        Z = fit_pred_kde(loc.to_numpy().T, X, Y)
        z_max = Z.max()
        fig.add_shape(
            type="rect",
            **coordinatify(sz_h, sz_w),
            line_color='gray',
        )
        fig.add_shape(
            type="rect",
            **coordinatify(h_h, h_w),
            line_color='dimgray',
        )
        fig.add_trace(go.Contour(
            z=Z,
            x=kde_range,
            y=kde_range,
            colorscale=colorscale,
            zmin=1e-5,
            zmax=z_max,
            contours={
                'start': 1e-5,
                'end': z_max,
                'size': z_max / 5
            },
            showscale=False
        ))
    else:
        fig.add_annotation(
            x=0,
            y=0,
            text=INSUFFICIENT_PITCHES_MSG_MULTI_LINE,
            showarrow=False
        )

    if league_df is not None:
        league_loc = league_df.select(pl.col('plate_x', 'plate_z')).drop_nulls()
        if len(league_loc) > max_pitches:
            # Fixed seed keeps the league outline stable between renders.
            league_loc = league_loc.sample(max_pitches, seed=0)
        if len(league_loc) >= min_pitches:
            league_Z = fit_pred_kde(league_loc.to_numpy().T, X, Y)
            league_z_max = league_Z.max()
            percentile = np.quantile(league_Z, 0.9)
            fig.add_trace(go.Contour(
                z=league_Z,
                x=kde_range,
                y=kde_range,
                # Fully transparent fill: only the dashed contour line shows.
                colorscale=[
                    [0, 'rgba(0, 0, 0, 0)'],
                    [1, 'rgba(0, 0, 0, 0)']
                ],
                zmin=percentile,
                zmax=league_z_max,
                # A single step from the quantile to the max yields one outline.
                contours={
                    'start': percentile,
                    'end': league_z_max,
                    'size': league_z_max - percentile,
                },
                line={
                    'width': 2,
                    'color': 'black',
                    'dash': 'dash'
                },
                showlegend=True,
                showscale=False,
                name='NPB'
            ))

    fig.update_layout(
        xaxis=dict(range=[-plot_s/2, plot_s/2+1], showticklabels=False),
        yaxis=dict(range=[-plot_s/2, plot_s/2+1], scaleanchor='x', scaleratio=1, showticklabels=False),
        legend=dict(orientation='h', y=0, yanchor='top'),
    )
    return fig


# velo distribution

@clone_df
def plot_velo(df=None, player=None, velos=None, pitch_type=None, pitch_name=None, min_pitches=2):
    """Build a violin plot of a velocity distribution.

    Exactly one of `velos` or `player` must be given. With `player`, the
    velocities are looked up in `df` by the player's name and exactly one
    of `pitch_type` / `pitch_name`.

    Args:
        df: polars DataFrame with `name`, `release_speed` and the chosen
            pitch column; required when `velos` is not given.
        player: value matched against the `name` column.
        velos: polars Series of velocities to plot directly.
        pitch_type: value matched against the `pitch_type` column.
        pitch_name: value matched against the `pitch_name` column.
        min_pitches: minimum number of velocities needed to draw the violin.

    Returns:
        A ``plotly.graph_objects.Figure``; shows an explanatory annotation
        when fewer than `min_pitches` velocities are available.

    Raises:
        AssertionError: if the argument combination is invalid.
    """
    assert (velos is None) != (player is None), 'exactly one of `player` or `velos` must be specified'

    if velos is None and player is not None:
        assert (pitch_type is None) != (pitch_name is None), 'exactly one of `pitch_type` or `pitch_name` must be specified'
        assert df is not None, '`df` must be provided if `velos` not provided'
        pitch_val = pitch_type or pitch_name
        pitch_col = 'pitch_type' if pitch_type else 'pitch_name'
        # Drop missing speeds so the count check below matches the data that
        # is actually plotted (callers passing `velos` already do this).
        velos = df.filter(
            (pl.col('name') == player) & (pl.col(pitch_col) == pitch_val)
        )['release_speed'].drop_nulls()

    fig = go.Figure()
    if len(velos) >= min_pitches:
        fig.add_trace(go.Violin(
            x=velos,
            side='positive',
            hoveron='points',
            points=False,
            meanline_visible=True,
            name='Velocity Distribution'
        ))
        median = velos.median()
        # Centre a fixed-width window on the median.
        x_range = [median-25, median+25]
    else:
        fig.add_annotation(
            x=(170+125)/2,
            y=0.3/2,
            text=INSUFFICIENT_PITCHES_MSG_MULTI_LINE,
            showarrow=False,
        )
        x_range = [125, 170]

    fig.update_layout(
        xaxis=dict(
            title='Velocity',
            range=x_range,
            scaleratio=2
        ),
        yaxis=dict(
            title='Frequency',
            range=[0, 0.3],
            scaleanchor='x',
            scaleratio=1,
            tickvals=np.linspace(0, 0.3, 3),
            ticktext=np.linspace(0, 0.3, 3),
        ),
        autosize=True,
        modebar_remove=['zoom', 'autoScale', 'resetScale'],
    )
    return fig


@clone_df
def plot_velo_summary(df, league_df, player):
    """Build stacked horizontal violins comparing velocities per pitch.

    For every pitch name in `df`, a gray league violin (from `league_df`)
    is drawn together with either the player's violin or, when the player
    has too few pitches of that kind, a text marker.

    Args:
        df: polars DataFrame of the player's pitches with `pitch_name` and
            `release_speed` columns.
        league_df: polars DataFrame of league pitches with the same columns.
        player: unused by the current implementation; kept for interface
            compatibility.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    min_pitches = 2

    has_speed = pl.col('release_speed').is_not_null()
    player_df = df.filter(has_speed)
    league_df = league_df.filter(has_speed)

    fig = go.Figure()

    if len(player_df) == 0:
        # min()/max() below would be None and the arithmetic would fail;
        # show the standard message instead of raising.
        fig.add_annotation(x=0, y=0, text=INSUFFICIENT_PITCHES_MSG, showarrow=False)
        fig.update_xaxes(visible=False)
        fig.update_yaxes(visible=False)
        return fig

    # Ascending by count; legendrank below is derived from the position.
    pitch_counts = player_df['pitch_name'].value_counts().sort('count')

    velo_min = player_df['release_speed'].min()
    velo_max = player_df['release_speed'].max()
    velo_center = (velo_min + velo_max) / 2

    for i, (pitch_name, count) in enumerate(pitch_counts.iter_rows()):
        is_pitch = pl.col('pitch_name') == pitch_name
        velos = player_df.filter(is_pitch)['release_speed']
        league_velos = league_df.filter(is_pitch)['release_speed']

        fig.add_trace(go.Violin(
            x=league_velos,
            y=[pitch_name]*len(league_velos),
            line_color='gray',
            side='positive',
            orientation='h',
            meanline_visible=True,
            points=False,
            legendgroup='NPB',
            legendrank=1,
            # One shared legend entry for all league violins.
            showlegend=i == 0,
            name='NPB',
        ))

        if count >= min_pitches:
            fig.add_trace(go.Violin(
                x=velos,
                y=[pitch_name]*len(velos),
                side='positive',
                orientation='h',
                meanline_visible=True,
                points=False,
                legendgroup=pitch_name,
                legendrank=len(pitch_counts) - i,
                name=pitch_name
            ))
        else:
            # Too few pitches for a violin: place the message on this row.
            fig.add_trace(go.Scatter(
                x=[velo_center],
                y=[pitch_name],
                text=[INSUFFICIENT_PITCHES_MSG],
                textposition='top center',
                hovertext=False,
                mode="lines+text",
                legendgroup=pitch_name,
                legendrank=len(pitch_counts) - i,
                name=pitch_name,
            ))

    fig.update_xaxes(title='Velocity', range=[velo_min - 2, velo_max + 2])
    fig.update_yaxes(range=[0, len(pitch_counts)-0.25], visible=False)
    fig.update_layout(
        violingap=0,
        violingroupgap=0,
        legend=dict(orientation='h', y=-0.15, yanchor='top'),
        modebar_remove=['zoom', 'select2d', 'lasso2d', 'pan', 'autoScale'],
        dragmode=False
    )
    return fig


def update_dfs(player, handedness, df):
    """Filter the data for one player and one batter handedness.

    Args:
        player: value matched against the `name` column.
        handedness: 'Both', 'Right' or 'Left'; matched against the `stand`
            column as {'R', 'L'}, 'R' or 'L' respectively.
        df: polars DataFrame with `name` and `stand` columns.

    Returns:
        A 4-tuple ``(player_df, league_df, player_pitch_stats,
        league_pitch_stats)`` where the first two are `df` filtered by
        player+handedness and handedness only, and the last two come from
        the module-level stats tables matching `handedness`.

    Raises:
        ValueError: if `handedness` is not one of the accepted values.
    """
    if handedness == 'Both':
        handedness_filter = pl.col('stand').is_in(['R', 'L'])
        _pitch_stats = pitch_stats
        _league_pitch_stats = league_pitch_stats
    elif handedness == 'Right':
        handedness_filter = pl.col('stand') == 'R'
        _pitch_stats = rhb_pitch_stats
        _league_pitch_stats = rhb_league_pitch_stats
    elif handedness == 'Left':
        handedness_filter = pl.col('stand') == 'L'
        _pitch_stats = lhb_pitch_stats
        _league_pitch_stats = lhb_league_pitch_stats
    else:
        raise ValueError(
            f"handedness must be 'Both', 'Right' or 'Left', got {handedness!r}"
        )

    player_filter = pl.col('name') == player
    final_filter = player_filter & handedness_filter

    return (
        df.filter(final_filter),
        df.filter(handedness_filter),
        _pitch_stats.filter(player_filter),
        _league_pitch_stats,
    )


def set_download_file(df):
    """Write `df` as CSV to 'files/npb.csv' and return that path."""
    file_path = 'files/npb.csv'
    # Make sure the target directory exists so the write cannot fail on it.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    df.write_csv(file_path)
    return file_path


def preview_df(df):
    """Return the first rows of `df` (``DataFrame.head()`` default length)."""
    return df.head()


@clone_df
def plot_usage(df, player):
    """Build a pie chart of how often each pitch name occurs in `df`.

    Args:
        df: polars DataFrame with a `pitch_name` column.
        player: label inserted at the top of the hover text.

    Returns:
        A Plotly figure.
    """
    fig = px.pie(df.select('pitch_name'), names='pitch_name')
    fig.update_traces(
        texttemplate='%{percent:.1%}',
        # `<br>` is Plotly's line break inside hover text.
        hovertemplate=f'{player}<br>' + 'threw a %{label}<br>%{percent:.1%} of the time (%{value} pitches)'
    )
    return fig


@clone_df
def plot_pitch_cards(df, league_df, pitch_stats):
    """Build the ``gr.update`` payloads for the per-pitch cards.

    Pitches are ordered by descending count in `df`. Visible updates are
    produced for each pitch and hidden updates pad the lists up to the
    fixed number of card slots.

    Args:
        df: polars DataFrame of the player's pitches (`pitch_name`,
            `release_speed`, `plate_x`, `plate_z`).
        league_df: polars DataFrame of league pitches, used for the league
            outline on each location plot.
        pitch_stats: polars DataFrame with `pitch_name`, `Whiff%` and
            `CSW%` columns.

    Returns:
        One flat list: row updates, then group, name, info, velocity-plot
        and location-plot updates, in that order.
    """
    pitch_counts = df['pitch_name'].value_counts().sort('count', descending=True)

    # Show only as many rows as are needed for the pitches present.
    n_visible_rows = ceil(len(pitch_counts) / LOCS_PER_ROW)
    pitch_rows = [gr.update(visible=True) for _ in range(n_visible_rows)]
    pitch_rows += [gr.update(visible=False) for _ in range(len(pitch_rows), MAX_ROWS)]

    pitch_groups = []
    pitch_names = []
    pitch_infos = []
    pitch_velos = []
    pitch_locs = []

    for pitch_name, _count in pitch_counts.iter_rows():
        is_pitch = pl.col('pitch_name') == pitch_name
        pitch_df = df.filter(is_pitch)

        pitch_groups.append(gr.update(visible=True))
        pitch_names.append(gr.update(value=f'### {pitch_name}', visible=True))
        pitch_infos.append(gr.update(
            value=pitch_stats.filter(is_pitch).select(['Whiff%', 'CSW%']),
            visible=True
        ))
        pitch_velos.append(gr.update(
            value=plot_velo(
                velos=pitch_df.filter(pl.col('release_speed').is_not_null())['release_speed']
            ),
            visible=True
        ))
        pitch_locs.append(gr.update(
            value=plot_loc(
                df=pitch_df,
                league_df=league_df.filter(is_pitch)
            ),
            label='Pitch location',
            visible=True
        ))

    # Hide the unused card slots.
    for _ in range(max_pitch_types - len(pitch_names)):
        pitch_groups.append(gr.update(visible=False))
        pitch_names.append(gr.update(value=None, visible=False))
        pitch_infos.append(gr.update(value=None, visible=False))
        pitch_velos.append(gr.update(value=None, visible=False))
        pitch_locs.append(gr.update(value=None, visible=False))

    return pitch_rows + pitch_groups + pitch_names + pitch_infos + pitch_velos + pitch_locs


@clone_df
def update_velo_stats(pitch_stats, league_pitch_stats):
    """Join player and league average velocities per pitch.

    Args:
        pitch_stats: polars DataFrame with `pitch_name`, `Velocity` and
            `Count` columns.
        league_pitch_stats: polars DataFrame with `pitch_name` and
            `Velocity` columns.

    Returns:
        DataFrame with columns 'Pitch', 'Avg. Velo' and 'League Avg. Velo',
        ordered by the player's pitch count (descending); only pitches
        present in both inputs are kept.
    """
    player_velos = pitch_stats.select(
        pl.col('pitch_name').alias('Pitch'),
        pl.col('Velocity').alias('Avg. Velo'),
        pl.col('Count')
    )
    league_velos = league_pitch_stats.select(
        pl.col('pitch_name').alias('Pitch'),
        pl.col('Velocity').alias('League Avg. Velo')
    )
    return (
        player_velos
        .join(league_velos, on='Pitch', how='inner')
        .sort('Count', descending=True)
        .drop('Count')
    )