import plotly.express as px
import plotly.graph_objects as go
import plotly.colors as pc
from scipy.stats import gaussian_kde
import numpy as np
import polars as pl
import gradio as gr
from math import ceil
from translate import max_pitch_types, jp_pitch_to_en_pitch
from data import (
df,
pitch_stats, rhb_pitch_stats,lhb_pitch_stats,
league_pitch_stats, rhb_league_pitch_stats, lhb_league_pitch_stats
)
# Card-grid sizing: one slot per entry in the JP->EN pitch-name map,
# laid out LOCS_PER_ROW to a row.
MAX_LOCS = len(jp_pitch_to_en_pitch)
LOCS_PER_ROW = 4
MAX_ROWS = -(-MAX_LOCS // LOCS_PER_ROW)  # integer ceiling of MAX_LOCS / LOCS_PER_ROW
# Placeholder text shown instead of a chart when there is too little data to plot.
INSUFFICIENT_PITCHES_MSG = 'No visualization: Not enough pitches thrown'
# Two-line variant for Plotly annotations. Plotly text uses the HTML `<br>` tag
# for line breaks (a raw newline inside a '...' literal is a SyntaxError).
INSUFFICIENT_PITCHES_MSG_MULTI_LINE = 'No visualization:<br>Not enough pitches thrown'
# GRADIO FUNCTIONS
def clone_if_dataframe(item):
    """Return `item.clone()` when `item` is a polars DataFrame; otherwise return `item` itself."""
    return item.clone() if isinstance(item, pl.DataFrame) else item
def clone_df(fn):
    """Decorate `fn` so its polars DataFrame arguments are cloned before the call.

    Every positional and keyword argument is passed through `clone_if_dataframe`.
    Non-DataFrame values are forwarded unchanged. `functools.wraps` copies the
    wrapped function's name and docstring and sets `__wrapped__`, so
    `inspect.signature` (used by introspecting callers) reports the real
    parameters instead of `(*args, **kwargs)`.
    """
    import functools

    @functools.wraps(fn)
    def _fn(*args, **kwargs):
        args = [clone_if_dataframe(arg) for arg in args]
        kwargs = {key: clone_if_dataframe(value) for key, value in kwargs.items()}
        return fn(*args, **kwargs)

    return _fn
# location maps
def fit_pred_kde(data, X, Y):
    """Fit a Gaussian KDE to `data` and evaluate it at every point of the (X, Y) grid.

    `data` is handed directly to `scipy.stats.gaussian_kde`, so it is laid out as
    (n_dims, n_points). Because it is evaluated on 2-row grid points, n_dims
    must be 2. The returned density array has the same shape as `X`.
    """
    density = gaussian_kde(data)
    # Flatten the grid into a (2, n_grid_points) array of (x, y) columns.
    grid_points = np.vstack([X.ravel(), Y.ravel()])
    return density(grid_points).reshape(X.shape)
# Side length of the square plotting canvas, in the same coordinates as the
# plate_x / plate_z values that plot_loc feeds to the KDE.
plot_s = 256
# Outer (sz_*) and inner (h_*) rectangles drawn by plot_loc, both centred on
# the origin. The inner one is inset 40 vertically and 32 horizontally on each side.
sz_h = 200
sz_w = 160
h_h = sz_h - 40 * 2  # 120
h_w = sz_w - 32 * 2  # 96
# Unit-spaced evaluation grid covering [-plot_s/2, plot_s/2).
kde_range = np.arange(-plot_s / 2, plot_s / 2)
X, Y = np.meshgrid(kde_range, kde_range)
def coordinatify(h, w):
    """Return Plotly shape corners (x0, y0, x1, y1) for an h-by-w rectangle centred on the origin."""
    half_w = w / 2
    half_h = h / 2
    return {'x0': -half_w, 'y0': -half_h, 'x1': half_w, 'y1': half_h}
def _prepend_transparent_stop(colors):
    """Build a Plotly colorscale: transparent at 0, then `colors` evenly spaced up to 1.

    The first color sits at 1/len(colors) and the last at exactly 1, so the
    lowest densities fade out to fully transparent.
    """
    n_colors = len(colors)
    stops = [[0, 'rgba(0, 0, 0, 0)']]
    stops.extend([k / n_colors, color] for k, color in enumerate(colors, start=1))
    return stops


colorscale = _prepend_transparent_stop(pc.sequential.OrRd)
def _location_kde(loc, min_pitches):
    """Evaluate the KDE of a two-column location frame on the module grid (X, Y).

    Returns None instead of a density when `loc` has fewer than `min_pitches`
    rows, or when `gaussian_kde` cannot fit it. It raises LinAlgError for a
    singular covariance, e.g. when every pitch has the same location.
    """
    if len(loc) < min_pitches:
        return None
    try:
        return fit_pred_kde(loc.to_numpy().T, X, Y)
    except np.linalg.LinAlgError:
        return None


@clone_df
def plot_loc(df, league_df=None, min_pitches=3, max_pitches=5000):
    """Build a pitch-location density plot, optionally outlined against the league.

    Args:
        df: polars DataFrame with `plate_x` and `plate_z` columns for the pitches
            to plot. Rows with a missing coordinate are ignored.
        league_df: optional frame with the same columns. When it can be fit, a
            single dashed black contour at the 90th percentile of its density over
            the grid is added and labelled 'NPB' in the legend.
        min_pitches: minimum number of located pitches required to fit a KDE.
        max_pitches: league rows are randomly down-sampled (seed 0) to at most
            this many before fitting.

    Returns:
        A plotly Figure. When `df` cannot be fit, the density and rectangles are
        replaced by a "not enough pitches" annotation.
    """
    # Null coordinates would become NaN in the KDE input and must not count
    # toward min_pitches.
    loc = df.select(['plate_x', 'plate_z']).drop_nulls()
    fig = go.Figure()
    Z = _location_kde(loc, min_pitches)
    if Z is not None:
        fig.add_shape(
            type="rect",
            **coordinatify(sz_h, sz_w),
            line_color='gray',
        )
        fig.add_shape(
            type="rect",
            **coordinatify(h_h, h_w),
            line_color='dimgray',
        )
        z_max = Z.max()
        fig.add_trace(go.Contour(
            z=Z,
            x=kde_range,
            y=kde_range,
            colorscale=colorscale,
            zmin=1e-5,
            zmax=z_max,
            contours={
                'start': 1e-5,
                'end': z_max,
                'size': z_max / 5,
            },
            showscale=False,
        ))
    else:
        fig.add_annotation(
            x=0,
            y=0,
            text=INSUFFICIENT_PITCHES_MSG_MULTI_LINE,
            showarrow=False,
        )

    if league_df is not None:
        league_loc = league_df.select(pl.col('plate_x', 'plate_z')).drop_nulls()
        if len(league_loc) > max_pitches:
            league_loc = league_loc.sample(max_pitches, seed=0)
        league_Z = _location_kde(league_loc, min_pitches)
        if league_Z is not None:
            percentile = np.quantile(league_Z, 0.9)
            league_max = league_Z.max()
            fig.add_trace(go.Contour(
                z=league_Z,
                x=kde_range,
                y=kde_range,
                # Fully transparent fill: only the contour line is visible.
                colorscale=[
                    [0, 'rgba(0, 0, 0, 0)'],
                    [1, 'rgba(0, 0, 0, 0)'],
                ],
                zmin=percentile,
                zmax=league_max,
                contours={
                    'start': percentile,
                    'end': league_max,
                    # One step from start to end -> a single outline at the 90th percentile.
                    'size': league_max - percentile,
                },
                line={
                    'width': 2,
                    'color': 'black',
                    'dash': 'dash',
                },
                showlegend=True,
                showscale=False,
                name='NPB',
            ))

    fig.update_layout(
        xaxis=dict(range=[-plot_s/2, plot_s/2+1], showticklabels=False),
        yaxis=dict(range=[-plot_s/2, plot_s/2+1], scaleanchor='x', scaleratio=1, showticklabels=False),
        legend=dict(orientation='h', y=0, yanchor='top'),
    )
    return fig
# velo distribution
@clone_df
def plot_velo(df=None, player=None, velos=None, pitch_type=None, pitch_name=None, min_pitches=2):
    """Plot a one-sided violin of release speeds for a single pitch.

    Velocities come either directly from `velos`, or from `df` filtered to
    `player` and one pitch chosen by exactly one of `pitch_type` / `pitch_name`.

    Args:
        df: polars DataFrame with `name`, `release_speed` and the pitch column;
            required when `velos` is not given.
        player: value matched against `df['name']`. Mutually exclusive with `velos`.
        velos: a ready-made velocity series (must support `len` and `.median()`).
        pitch_type, pitch_name: exactly one selects the pitch when reading from `df`.
        min_pitches: fewer velocities than this shows a placeholder instead.

    Returns:
        A plotly Figure. The x-axis spans median +/- 25 when a violin is drawn,
        otherwise [125, 170] with a "not enough pitches" annotation.

    Raises:
        AssertionError: if the argument combination is invalid.
    """
    assert (velos is None) != (player is None), 'exactly one of `player` or `velos` must be specified'
    if velos is None and player is not None:
        assert (pitch_type is None) != (pitch_name is None), 'exactly one of `pitch_type` or `pitch_name` must be specified'
        assert df is not None, '`df` must be provided if `velos` not provided'
        # Choose by `is not None` (not truthiness) so a falsy-but-given value
        # still selects its own column.
        if pitch_type is not None:
            pitch_col, pitch_val = 'pitch_type', pitch_type
        else:
            pitch_col, pitch_val = 'pitch_name', pitch_name
        velos = (
            df.filter((pl.col('name') == player) & (pl.col(pitch_col) == pitch_val))['release_speed']
            # Missing readings must not count toward min_pitches.
            .drop_nulls()
        )

    fig = go.Figure()
    if len(velos) >= min_pitches:
        fig.add_trace(go.Violin(x=velos, side='positive', hoveron='points', points=False, meanline_visible=True, name='Velocity Distribution'))
        median = velos.median()
        x_range = [median-25, median+25]
    else:
        fig.add_annotation(
            x=(170+125)/2,
            y=0.3/2,
            text=INSUFFICIENT_PITCHES_MSG_MULTI_LINE,
            showarrow=False,
        )
        x_range = [125, 170]
    fig.update_layout(
        xaxis=dict(
            title='Velocity',
            range=x_range,
            scaleratio=2
        ),
        yaxis=dict(
            title='Frequency',
            range=[0, 0.3],
            scaleanchor='x',
            scaleratio=1,
            tickvals=np.linspace(0, 0.3, 3),
            ticktext=np.linspace(0, 0.3, 3),
        ),
        autosize=True,
        modebar_remove=['zoom', 'autoScale', 'resetScale'],
    )
    return fig
@clone_df
def plot_velo_summary(df, league_df, player):
    """Plot per-pitch velocity violins for a player against the league ('NPB').

    Pitches in `df` are drawn in ascending order of count. For each one, a gray
    league violin built from `league_df` is added. The player's violin is
    overlaid when they have at least 2 recorded velocities for it; otherwise a
    text marker with INSUFFICIENT_PITCHES_MSG is shown. Rows with a null
    `release_speed` are ignored in both frames.

    Args:
        df: the player's pitches (`pitch_name`, `release_speed` columns).
        league_df: league pitches with the same columns.
        player: currently unused. Kept so existing callers keep working.

    Returns:
        A plotly Figure. With no velocities at all, it shows an annotation over a
        default [125, 170] velocity axis instead of raising.
    """
    min_pitches = 2
    player_df = df.filter(pl.col('release_speed').is_not_null())
    pitch_counts = player_df['pitch_name'].value_counts().sort('count')
    league_df = league_df.filter(pl.col('release_speed').is_not_null())
    fig = go.Figure()

    velo_min = player_df['release_speed'].min()
    velo_max = player_df['release_speed'].max()
    if velo_min is None:
        # Empty frame (e.g. the player never faced this batter side): min()/max()
        # are None, and the old `None - 2` raised TypeError.
        velo_center = (125 + 170) / 2
        x_range = [125, 170]
        fig.add_annotation(
            x=0.5,
            y=0.5,
            xref='paper',
            yref='paper',
            text=INSUFFICIENT_PITCHES_MSG,
            showarrow=False,
        )
    else:
        velo_center = (velo_min + velo_max) / 2
        x_range = [velo_min - 2, velo_max + 2]

    for i, (pitch_name, count) in enumerate(pitch_counts.iter_rows()):
        velos = player_df.filter(pl.col('pitch_name') == pitch_name)['release_speed']
        league_velos = league_df.filter(pl.col('pitch_name') == pitch_name)['release_speed']
        fig.add_trace(go.Violin(
            x=league_velos,
            y=[pitch_name]*len(league_velos),
            line_color='gray',
            side='positive',
            orientation='h',
            meanline_visible=True,
            points=False,
            legendgroup='NPB',
            legendrank=1,
            # One shared legend entry for all league violins.
            showlegend=i == 0,
            name='NPB',
        ))
        if count >= min_pitches:
            fig.add_trace(go.Violin(
                x=velos,
                y=[pitch_name]*len(velos),
                side='positive',
                orientation='h',
                meanline_visible=True,
                points=False,
                legendgroup=pitch_name,
                legendrank=len(pitch_counts) - i,
                name=pitch_name,
            ))
        else:
            fig.add_trace(go.Scatter(
                x=[velo_center],
                y=[pitch_name],
                text=[INSUFFICIENT_PITCHES_MSG],
                textposition='top center',
                # NOTE(review): if the intent is to disable hover, hoverinfo='skip'
                # is Plotly's documented switch; confirm what `hovertext=False` renders.
                hovertext=False,
                mode="lines+text",
                legendgroup=pitch_name,
                legendrank=len(pitch_counts) - i,
                name=pitch_name,
            ))

    fig.update_xaxes(title='Velocity', range=x_range)
    fig.update_yaxes(range=[0, len(pitch_counts)-0.25], visible=False)
    fig.update_layout(
        violingap=0,
        violingroupgap=0,
        legend=dict(orientation='h', y=-0.15, yanchor='top'),
        modebar_remove=['zoom', 'select2d', 'lasso2d', 'pan', 'autoScale'],
        dragmode=False
    )
    return fig
def update_dfs(player, handedness, df):
    """Filter pitch data and pick the stat tables for a player and batter handedness.

    Args:
        player: value matched against the `name` column.
        handedness: 'Both', 'Right' or 'Left'. It selects rows by the `stand`
            column ('R'/'L') and the matching module-level pitch-stat tables.
        df: pitch-level polars DataFrame with `name` and `stand` columns.

    Returns:
        A 4-tuple of:
        - the player's pitches against that handedness;
        - all pitches against that handedness;
        - the player's rows of the matching pitch-stat table;
        - the matching league pitch-stat table.

    Raises:
        ValueError: if `handedness` is not one of the three options.
    """
    by_handedness = {
        'Both': (pl.col('stand').is_in(['R', 'L']), pitch_stats, league_pitch_stats),
        'Right': (pl.col('stand') == 'R', rhb_pitch_stats, rhb_league_pitch_stats),
        'Left': (pl.col('stand') == 'L', lhb_pitch_stats, lhb_league_pitch_stats),
    }
    try:
        handedness_filter, _pitch_stats, _league_pitch_stats = by_handedness[handedness]
    except KeyError:
        # Previously this surfaced as an UnboundLocalError further down.
        raise ValueError(
            f"handedness must be 'Both', 'Right' or 'Left', got {handedness!r}"
        ) from None
    player_filter = pl.col('name') == player
    return (
        df.filter(player_filter & handedness_filter),
        df.filter(handedness_filter),
        _pitch_stats.filter(player_filter),
        _league_pitch_stats,
    )
def set_download_file(df):
    """Write `df` as CSV to 'files/npb.csv' and return that path.

    The 'files' directory is created first if it is missing, so the export
    does not fail on a fresh checkout. The path is fixed, so each call
    overwrites the previous export.
    """
    from pathlib import Path

    file_path = 'files/npb.csv'
    Path(file_path).parent.mkdir(parents=True, exist_ok=True)
    df.write_csv(file_path)
    return file_path
def preview_df(df):
    """Return the leading rows of `df`, as given by its no-argument `head()`."""
    preview = df.head()
    return preview
@clone_df
def plot_usage(df, player):
    """Build a pie chart of pitch usage from `df`'s `pitch_name` column.

    Slice labels show each pitch's share with one decimal place. The hover text
    shows `player`, the pitch, its share and its pitch count on separate lines.
    Plotly hover templates break lines with `<br>`; a raw newline inside a '...'
    literal is a SyntaxError.
    """
    fig = px.pie(df.select('pitch_name'), names='pitch_name')
    fig.update_traces(
        texttemplate='%{percent:.1%}',
        # Only `player` is interpolated by Python; the %{...} fields are Plotly
        # template variables, so they live in a plain (non-f) string.
        hovertemplate=f'{player}<br>' + 'threw a %{label}<br>%{percent:.1%} of the time (%{value} pitches)',
    )
    return fig
@clone_df
def plot_pitch_cards(df, league_df, pitch_stats):
    """Build the Gradio updates for the grid of per-pitch cards.

    One card is filled per pitch in `df`, most-thrown first. Each card gets:
    - a markdown title;
    - its Whiff%/CSW% row from `pitch_stats`;
    - a velocity violin (plot_velo) from non-null `release_speed` values;
    - a location density against the same pitch in `league_df` (plot_loc).
    Unused card slots and rows are hidden.

    Returns:
        A flat list of `gr.update` objects: MAX_ROWS row updates, then
        max_pitch_types updates each for groups, names, infos, velocity plots
        and location plots. Pitches beyond max_pitch_types (and rows beyond
        MAX_ROWS) are dropped so the list always has that fixed length.
    """
    pitch_counts = (
        df['pitch_name'].value_counts()
        .sort('count', descending=True)
        .head(max_pitch_types)
    )

    n_rows = min(ceil(len(pitch_counts) / LOCS_PER_ROW), MAX_ROWS)
    pitch_rows = [gr.update(visible=row < n_rows) for row in range(MAX_ROWS)]

    pitch_groups, pitch_names, pitch_infos, pitch_velos, pitch_locs = [], [], [], [], []
    for pitch_name, _count in pitch_counts.iter_rows():
        # Filter once per pitch and reuse it for both plots.
        pitch_df = df.filter(pl.col('pitch_name') == pitch_name)
        pitch_groups.append(gr.update(visible=True))
        pitch_names.append(gr.update(value=f'### {pitch_name}', visible=True))
        pitch_infos.append(gr.update(
            value=pitch_stats.filter(pl.col('pitch_name') == pitch_name).select(['Whiff%', 'CSW%']),
            visible=True,
        ))
        pitch_velos.append(gr.update(
            value=plot_velo(velos=pitch_df.filter(pl.col('release_speed').is_not_null())['release_speed']),
            visible=True,
        ))
        pitch_locs.append(gr.update(
            value=plot_loc(
                df=pitch_df,
                league_df=league_df.filter(pl.col('pitch_name') == pitch_name),
            ),
            label='Pitch location',
            visible=True,
        ))

    # Hide the remaining card slots so every output component gets an update.
    for _ in range(max_pitch_types - len(pitch_names)):
        pitch_groups.append(gr.update(visible=False))
        for updates in (pitch_names, pitch_infos, pitch_velos, pitch_locs):
            updates.append(gr.update(value=None, visible=False))

    return pitch_rows + pitch_groups + pitch_names + pitch_infos + pitch_velos + pitch_locs
@clone_df
def update_velo_stats(pitch_stats, league_pitch_stats):
    """Tabulate the player's average velocity per pitch next to the league average.

    Output columns are 'Pitch', 'Avg. Velo' and 'League Avg. Velo'. Only pitches
    present in both inputs are kept (inner join). Rows are ordered by the
    player's 'Count', highest first.
    """
    player_velos = pitch_stats.select(
        pl.col('pitch_name').alias('Pitch'),
        pl.col('Velocity').alias('Avg. Velo'),
        pl.col('Count'),
    )
    league_velos = league_pitch_stats.select(
        pl.col('pitch_name').alias('Pitch'),
        pl.col('Velocity').alias('League Avg. Velo'),
    )
    combined = player_velos.join(league_velos, on='Pitch', how='inner')
    ordered = combined.sort('Count', descending=True)
    return ordered.drop('Count')