Spaces:
Sleeping
Sleeping
import pandas as pd | |
import tensorflow as tf | |
import pickle | |
import plotly.express as px | |
import os | |
import numpy as np | |
from stock_and import GetNewData | |
class Model:
    """Make forecasts, build charts and compute feature importance for one stock."""

    def __init__(self, stock_name, model_name) -> None:
        """Store the configuration and load the pre-trained model from ``models/``.

        Args:
            stock_name: Stock identifier; it is interpolated into the model
                file name and later passed to ``GetNewData``.
            model_name: One of ``'NN'``, ``'LinearRegression'`` or ``'LGB'``.

        Raises:
            ValueError: If ``model_name`` is not one of the supported names.
        """
        self.stock_name = stock_name  # Stock name
        # Model input features. The order is kept exactly as in the original
        # code because the feature matrix is passed to the model positionally.
        self.features = [
            'lag_25', 'lag_34', 'lag_33', 'lag_26', 'lag_32', 'lag_31',
            'lag_30', 'lag_29', 'lag_27', 'sentiment_neutral', 'lag_28',
            'sentiment_positive', 'sentiment_negative', 'month', 'day',
        ]
        self.model_name = model_name  # Model name

        # Load the model that matches the requested name.
        if model_name == 'NN':
            self.model = tf.keras.models.load_model(
                f'models/nn_predict_1day_ver2_{stock_name}.h5',
                custom_objects={'mae': tf.keras.metrics.MeanAbsoluteError()},
            )
        elif model_name == 'LinearRegression':
            # NOTE: pickle.load can execute arbitrary code; only ship trusted
            # model files in the models/ directory.
            with open(os.path.join('models', f'linear_predict_1day_ver2_{stock_name}.pkl'), 'rb') as f:
                self.model = pickle.load(f)
        elif model_name == 'LGB':
            # NOTE: same pickle caveat as above.
            with open(os.path.join('models', f'lgb_predict_1day_ver2_{stock_name}.pkl'), 'rb') as f:
                self.model = pickle.load(f)
        else:
            # Previously an unknown name left ``self.model`` unset and the
            # failure only surfaced later inside ``predict``.
            raise ValueError(
                f"Unknown model_name {model_name!r}; "
                "expected 'NN', 'LinearRegression' or 'LGB'"
            )

    def generate_dataset(self, stock_name, num_day):
        """Fetch the dataset and append ``24 * num_day`` hourly rows to forecast.

        Args:
            stock_name: Stock identifier passed to ``GetNewData``.
            num_day: Number of days to forecast (24 hourly rows per day).

        Returns:
            A ``(dataframe, text)`` tuple: the fetched frame extended with the
            future rows, and the text returned by ``get_full_data()`` unchanged.

        Raises:
            ValueError: If ``num_day`` is less than 1 or the fetched frame has
                fewer than ``24 * num_day`` rows.
        """
        # Per the original comment, this frame combines news and stock data.
        merged_df, string = GetNewData(stock_name).get_full_data()
        return self._append_future_rows(merged_df, num_day), string

    def _append_future_rows(self, merged_df, num_day):
        """Return ``merged_df`` followed by ``24 * num_day`` future hourly rows."""
        horizon = 24 * num_day
        if horizon < 1:
            raise ValueError(f'num_day must be at least 1, got {num_day!r}')
        if len(merged_df) < horizon:
            raise ValueError(
                f'Need at least {horizon} rows of history to forecast '
                f'{num_day} day(s), got {len(merged_df)}'
            )

        # A Timedelta is used as the frequency instead of the 'H' alias, which
        # newer pandas versions deprecate; the generated index is the same.
        one_hour = pd.Timedelta(hours=1)
        last_date = merged_df['DATE'].max()
        new_date_rng = pd.date_range(start=last_date + one_hour, periods=horizon, freq=one_hour)
        new_df = pd.DataFrame(new_date_rng, columns=['DATE'])

        # Calendar features come from the forecast dates themselves. They must
        # not be overwritten by the history copy below (the original loop did
        # that, so future rows carried stale month/day values).
        calendar = {
            'month': new_df['DATE'].dt.month,
            'day': new_df['DATE'].dt.day,
        }
        for c in self.features:
            if c in calendar:
                new_df[c] = calendar[c]
            else:
                # Reuse the most recent ``horizon`` observed values.
                new_df[c] = merged_df[c].values[-horizon:]

        # Columns missing from the new rows (e.g. 'CLOSE') become NaN here.
        return pd.concat([merged_df, new_df[self.features + ['DATE']]], ignore_index=True)

    def predict(self, num_day):
        """Run the model over history plus future rows and describe feature importance.

        Args:
            num_day: Number of days to forecast.

        Returns:
            A ``(dataframe, text)`` tuple. The frame has the columns
            ``predict``, ``DATE`` and ``CLOSE``; the text is the string from
            ``generate_dataset`` with a feature-importance summary appended.

        Raises:
            ValueError: If ``self.model_name`` is not a supported model name.
        """
        merged_df, string = self.generate_dataset(self.stock_name, num_day)

        # Fill gaps once and reuse the result for importance and prediction.
        filled = merged_df[self.features].ffill().bfill()

        if self.model_name == 'NN':
            # Mean absolute value over axis 1 of the first weight array of the
            # first layer (presumably the input kernel - TODO confirm).
            weights = np.mean(np.abs(self.model.layers[0].get_weights()[0]), axis=1)
        elif self.model_name == 'LinearRegression':
            weights = self.model.coef_
        elif self.model_name == 'LGB':
            weights = self.model.feature_importances_
        else:
            raise ValueError(f'Unknown model_name {self.model_name!r}')

        # Weight only the appended forecast rows.
        df_weighted = filled[-num_day * 24:] * weights
        average_values = df_weighted.mean(axis=0).abs().sort_values(ascending=False)
        # 'lag_25' is excluded from the ranking, as in the original code.
        average_values_filtered = average_values.drop('lag_25')
        total_sum = average_values_filtered.sum()
        average_values_percentage = (average_values_filtered / total_sum) * 100

        string += '\n Самые полезные признаки для прогнозов: \n'
        for f, v in zip(average_values_percentage.index, average_values_percentage.values):
            string += f'- {f}: важность = {v:.2f}%\n'

        predictions = self.model.predict(filled.values)
        if self.model_name == 'NN':
            # The network output is 2-D; keep its first column.
            predictions = predictions[:, 0]

        return pd.DataFrame({
            'predict': predictions,
            'DATE': merged_df['DATE'].values,
            'CLOSE': merged_df['CLOSE'].values
        }), string

    def plot_predict(self, predict, add_smoothing):
        """Plot actual ``CLOSE`` against rescaled predictions for the last 288 rows.

        Args:
            predict: Frame with ``DATE``, ``CLOSE`` and ``predict`` columns,
                as returned by :meth:`predict`.
            add_smoothing: When truthy, also draw an exponentially weighted
                mean of the rescaled predictions.

        Returns:
            The plotly figure.
        """
        predict = predict[-24 * 12:]
        # Rescale predictions so their mean matches the mean of CLOSE over the
        # plotted window.
        scaling_factor = predict['CLOSE'].mean() / predict['predict'].mean()
        scaled_preds = predict['predict'] * scaling_factor

        fig = px.line(predict, x=predict.DATE, y='CLOSE', labels={'value': 'Цена'}, title='CLOSE')
        fig.add_scatter(x=predict.DATE, y=scaled_preds, mode='lines', name='Predict', opacity=0.7)
        if add_smoothing:
            # ewm(3): the first positional argument is ``com``.
            smoothed_preds = pd.Series(scaled_preds).ewm(3).mean()
            fig.add_scatter(x=predict.DATE, y=smoothed_preds, mode='lines', name='Сглаженные предсказания', opacity=0.7)
        fig.update_layout(xaxis=dict(type='category'))
        return fig