Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, TextStreamer | |
from peft import PeftModel | |
import re | |
import os | |
import akshare as ak | |
import pandas as pd | |
import random | |
import json | |
import requests | |
import math | |
from datetime import date | |
from datetime import date, datetime, timedelta | |
# Hugging Face access token, read from the TOKEN environment variable.
# A missing variable raises KeyError as soon as the module is executed.
access_token = os.environ["TOKEN"]
# load model
# NOTE: `model` first holds the base checkpoint id (a string); it is rebound
# below to the loaded model object, which is what `ask` uses.
model = "meta-llama/Llama-2-7b-chat-hf"
peft_model = "FinGPT/fingpt-forecaster_sz50_llama2-7B_lora"
tokenizer = AutoTokenizer.from_pretrained(model, token = access_token, trust_remote_code=True)
# Reuse the end-of-sequence token as the padding token and pad on the right.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
# Streamer handed to `model.generate` in `ask`
# (presumably echoes generated tokens to stdout - confirm against transformers docs).
streamer = TextStreamer(tokenizer)
# Load the base causal LM in 8-bit on the CUDA device, then wrap it with the
# LoRA adapter identified by `peft_model`.
model = AutoModelForCausalLM.from_pretrained(model, trust_remote_code=True, token = access_token, device_map="cuda", load_in_8bit=True, offload_folder="offload/")
model = PeftModel.from_pretrained(model, peft_model, offload_folder="offload/")
# Switch to evaluation mode; this module only runs inference.
model = model.eval()
# Inference Data
# Prompt delimiters: instruction markers and system-prompt markers that
# `get_all_prompts_online` wraps around the final prompt.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
# System prompt (in Chinese). It tells the model to act as an experienced
# stock-market analyst: list the company's positive developments and potential
# concerns from recent news and quarterly financials, then give a prediction
# and analysis of next week's price change, answering in Chinese using the
# three bracketed sections shown at the end of the string.
SYSTEM_PROMPT = "你是一名经验丰富的股票市场分析师。你的任务是根据公司在过去几周内的相关新闻和季度财务状况,列出公司的积极发展和潜在担忧,然后结合你对整体金融经济市场的判断,对公司未来一周的股价变化提供预测和分析。" \
    "你的回答语言应为中文。你的回答格式应该如下:\n\n[积极发展]:\n1. ...\n\n[潜在担忧]:\n1. ...\n\n[预测和分析]:\n...\n"
# ------------------------------------------------------------------------------ | |
# Utils | |
# ------------------------------------------------------------------------------ | |
def get_curday():
    """Return today's local date formatted as a compact ``YYYYMMDD`` string."""
    today = date.today()
    return f"{today:%Y%m%d}"
def n_weeks_before(date_string, n, format = "%Y%m%d"):
    """Shift a ``YYYYMMDD`` date string back by ``n`` weeks.

    Args:
        date_string: Date in ``%Y%m%d`` form (the input format is fixed).
        n: Number of weeks to subtract; a negative value moves forward.
        format: ``strftime`` pattern used for the returned string.

    Returns:
        The shifted date rendered with ``format``.
    """
    anchor = datetime.strptime(date_string, "%Y%m%d")
    shifted = anchor - timedelta(days=7 * n)
    return shifted.strftime(format)
def check_news_quality(n, last_n, week_end_date, repeat_rate = 0.6):
    """Decide whether news item ``n`` should be kept for the weekly prompt.

    An item is rejected when its content is unavailable (empty, the string
    ``'nan'``, or starting with a digit), when it is dated after the week end,
    or when it largely repeats the previous item ``last_n``. Duplicates are
    assumed to be adjacent, so only ``last_n`` is compared.

    Args:
        n: News dict with keys ``'新闻内容'``, ``'新闻标题'`` and ``'发布时间'``
            (a ``YYYYMMDD...`` string).
        last_n: The preceding news dict, same keys.
        week_end_date: Week end as ``YYYY-MM-DD`` (dashes are stripped before
            comparing).
        repeat_rate: Overlap threshold above which ``n`` counts as a duplicate.

    Returns:
        ``True`` if the item should be kept, ``False`` otherwise.

    Raises:
        Exception: ``"Check Error"`` when a field has an unexpected type; the
            original ``TypeError`` is attached as the cause.
    """
    try:
        content = str(n['新闻内容'])
        # check content availability; an empty body is treated as unavailable
        # instead of failing on content[0]
        if not content or content[0].isdigit() or content == 'nan':
            return False
        if n['发布时间'][:8] > week_end_date.replace('-', ''):
            return False
        # check highly duplicated news
        # (assume the duplicated contents happened adjacent)
        if str(last_n['新闻内容']) == 'nan':
            return True
        # Character-set overlap of the first 20 characters of both bodies.
        head_overlap = len(set(n['新闻内容'][:20]) & set(last_n['新闻内容'][:20]))
        if head_overlap >= 20 * repeat_rate:
            return False
        # Title overlap relative to the previous title; an empty previous
        # title cannot be a duplicate and would otherwise divide by zero.
        last_title = last_n['新闻标题']
        if last_title and len(set(n['新闻标题']) & set(last_title)) / len(last_title) > repeat_rate:
            return False
        return True
    except TypeError as err:
        print(n)
        print(last_n)
        raise Exception("Check Error") from err
def sample_news(news, k=5):
    """Pick ``k`` random items from ``news`` while keeping their original order.

    Indices are drawn without replacement and then sorted, so the returned
    items appear in the same relative order as in ``news``.
    """
    chosen = random.sample(range(len(news)), k)
    chosen.sort()
    return [news[idx] for idx in chosen]
def return_transform(ret):
    """Bucket a fractional weekly return into a short Chinese label.

    The absolute return is converted to percent and rounded up. Zero maps to
    ``"平"``; otherwise the label is ``'涨'`` or ``'跌'`` followed by the bucket
    number, with anything above 5 collapsed to ``'5+'``.
    """
    magnitude = math.ceil(abs(100 * ret))
    if magnitude == 0:
        return "平"
    direction = '跌' if ret < 0 else '涨'
    bucket = '5+' if magnitude > 5 else str(magnitude)
    return direction + bucket
def map_return_label(return_lb):
    """Expand a short return label (e.g. ``'涨3'``) into its prompt wording.

    Replacements are applied one after another in the listed order; the
    ``'5'`` bucket is handled last because ``'5+'`` and plain ``'5'`` expand
    differently.
    """
    substitutions = (
        ('涨', '上涨'),
        ('跌', '下跌'),
        ('平', '股价持平'),
        ('1', '0-1%'),
        ('2', '1-2%'),
        ('3', '2-3%'),
        ('4', '3-4%'),
    )
    label = return_lb
    for short, spelled_out in substitutions:
        label = label.replace(short, spelled_out)
    if label.endswith('+'):
        return label.replace('5+', '超过5%')
    return label.replace('5', '4-5%')
# ------------------------------------------------------------------------------ | |
# Get data from website | |
# ------------------------------------------------------------------------------ | |
def stock_news_em(symbol: str = "300059", page = 1, timeout: float = 10.0) -> pd.DataFrame:
    """Fetch one page of news search results for ``symbol`` from Eastmoney.

    Args:
        symbol: Search keyword (typically a stock code).
        page: 1-based page index sent as ``pageIndex``; up to 100 items per page.
        timeout: Seconds to wait for the HTTP response before ``requests``
            raises, so a stalled connection cannot hang the caller forever.

    Returns:
        DataFrame with columns ``关键词``, ``新闻标题``, ``新闻内容``, ``发布时间``,
        ``文章来源`` and ``新闻链接``; ``<em>`` highlight markup is removed from
        title and content.

    Raises:
        KeyError: When the response lacks the expected result keys or the
            result list is empty (the expected columns are then missing).
        ValueError: When the response body is not the expected JSONP wrapper
            or does not contain valid JSON.
    """
    url = "https://search-api-web.eastmoney.com/search/jsonp"
    callback = "jQuery3510875346244069884_1668256937995"
    # Build the query with json.dumps so that the keyword is escaped properly;
    # compact separators reproduce the hand-written layout the API was sent before.
    search_param = {
        "uid": "",
        "keyword": str(symbol),
        "type": ["cmsArticleWebOld"],
        "client": "web",
        "clientType": "web",
        "clientVersion": "curr",
        "param": {
            "cmsArticleWebOld": {
                "searchScope": "default",
                "sort": "default",
                "pageIndex": page,
                "pageSize": 100,
                "preTag": "<em>",
                "postTag": "</em>",
            }
        },
    }
    params = {
        "cb": callback,
        "param": json.dumps(search_param, ensure_ascii=False, separators=(",", ":")),
        "_": "1668256937996",
    }
    r = requests.get(url, params=params, timeout=timeout)
    data_text = r.text
    # Unwrap the JSONP envelope `callback(...)`. str.strip(chars) removes a
    # *set of characters*, not a prefix, so slice between the outermost
    # parentheses instead.
    payload = data_text[data_text.index("(") + 1 : data_text.rindex(")")]
    data_json = json.loads(payload)
    temp_df = pd.DataFrame(data_json["result"]["cmsArticleWebOld"])
    temp_df.rename(
        columns={
            "date": "发布时间",
            "mediaName": "文章来源",
            "code": "-",
            "title": "新闻标题",
            "content": "新闻内容",
            "url": "新闻链接",
            "image": "-",
        },
        inplace=True,
    )
    temp_df["关键词"] = symbol
    # Copy after column selection so the assignments below write to an
    # independent frame rather than a slice of the original.
    temp_df = temp_df[
        [
            "关键词",
            "新闻标题",
            "新闻内容",
            "发布时间",
            "文章来源",
            "新闻链接",
        ]
    ].copy()
    # Remove search-highlight markup; the parenthesised variants go first so
    # the surrounding brackets disappear together with the tag.
    highlight_patterns = (r"\(<em>", r"</em>\)", r"<em>", r"</em>")
    for column in ("新闻标题", "新闻内容"):
        cleaned = temp_df[column]
        for pattern in highlight_patterns:
            cleaned = cleaned.str.replace(pattern, "", regex=True)
        temp_df[column] = cleaned
    # Drop ideographic spaces and flatten Windows line breaks in the body.
    temp_df["新闻内容"] = temp_df["新闻内容"].str.replace(r"\u3000", "", regex=True)
    temp_df["新闻内容"] = temp_df["新闻内容"].str.replace(r"\r\n", " ", regex=True)
    return temp_df
def get_news(symbol, max_page = 3):
    """Collect news pages for ``symbol`` into a single DataFrame.

    Pages ``1 .. max_page - 1`` are requested through ``stock_news_em``;
    fetching stops at the first page that raises ``KeyError``.

    NOTE(review): the upper bound is exclusive, so the default ``max_page=3``
    fetches two pages. Confirm whether that is intended before changing it.

    Args:
        symbol: Search keyword / stock code.
        max_page: Exclusive upper bound of the page index.

    Returns:
        Concatenated news DataFrame, or an empty DataFrame with the expected
        columns when no page could be fetched.
    """
    df_list = []
    for page in range(1, max_page):
        try:
            df_list.append(stock_news_em(symbol, page))
        except KeyError:
            # The previous message concatenated str + int and raised a
            # TypeError from inside this handler.
            print(f"{symbol}: obtained {page - 1} page(s) before an empty result")
            break
    if not df_list:
        # pd.concat([]) raises ValueError; hand back an empty, well-formed frame.
        return pd.DataFrame(
            columns=["关键词", "新闻标题", "新闻内容", "发布时间", "文章来源", "新闻链接"]
        )
    return pd.concat(df_list, ignore_index=True)
def get_cur_return(symbol, start_date, end_date, adjust="qfq"):
    """Build a table of consecutive weekly price moves for ``symbol``.

    Daily history is loaded via ``ak.stock_zh_a_hist``, the closing price is
    resampled with the ``"W"`` rule and forward-filled, and each pair of
    neighbouring weekly points becomes one row.

    Returns:
        DataFrame with columns ``起始日期``, ``起始价``, ``结算日期``, ``结算价``,
        ``周收益`` and ``简化周收益`` (the label from ``return_transform``). The
        last ``结算日期`` is capped at ``end_date``.
    """
    # load data
    daily = ak.stock_zh_a_hist(symbol=symbol, period="daily", start_date=start_date, end_date=end_date, adjust=adjust)
    # process timestamp
    daily["日期"] = pd.to_datetime(daily["日期"])
    daily.set_index("日期", inplace=True)
    # resample and fill with forward data
    weekly_close = daily["收盘"].resample("W").ffill()
    returns = weekly_close.pct_change().iloc[1:]
    opening = weekly_close.iloc[:-1]
    closing = weekly_close.iloc[1:]
    weekly_data = pd.DataFrame({
        '起始日期': opening.index,
        '起始价': opening.values,
        '结算日期': closing.index,
        '结算价': closing.values,
        '周收益': returns.values
    })
    weekly_data["简化周收益"] = weekly_data["周收益"].map(return_transform)
    # check enddate: column 2 is '结算日期'; the last bin may end after end_date
    cutoff = pd.to_datetime(end_date)
    if weekly_data.iloc[-1, 2] > cutoff:
        weekly_data.iloc[-1, 2] = cutoff
    return weekly_data
def get_basic(symbol, data):
    """Attach the latest quarterly fundamentals to every weekly row.

    Quarterly figures come from ``ak.stock_financial_abstract_ths`` (indicator
    ``"按单季度"``). For each row of ``data`` the report with the greatest
    ``报告期`` that is still strictly before the row's ``结算日期`` is chosen and
    stored as a JSON string in the new ``基本面`` column; ``"{}"`` is stored when
    no report qualifies.

    Args:
        symbol: Stock code passed to akshare.
        data: Weekly DataFrame with a datetime ``结算日期`` column. It is
            modified in place.

    Returns:
        The same ``data`` object with the ``基本面`` column added.
    """
    key_financials = ['报告期', '净利润同比增长率', '营业总收入同比增长率', '流动比率', '速动比率', '资产负债率']
    # load quarterly basic data
    basic_quarter_financials = ak.stock_financial_abstract_ths(symbol = symbol, indicator="按单季度")
    # Keep only the selected, non-empty fields. Iterating the records avoids
    # assuming that the frame is indexed 0..n-1.
    basic_fin_list = [
        {key: val for key, val in record.items() if key in key_financials and val}
        for record in basic_quarter_financials.to_dict("records")
    ]
    # Sort newest first so that "first report before the week end" really is
    # the most recent one, whatever order the data source returns.
    # Reports without a period cannot be matched and are skipped.
    dated_reports = sorted(
        (basic for basic in basic_fin_list if basic.get("报告期")),
        key=lambda basic: basic["报告期"],
        reverse=True,
    )
    # match basic financial data to news dataframe
    matched_basic_fin = []
    for _, row in data.iterrows():
        newsweek_enddate = row['结算日期'].strftime("%Y-%m-%d")
        # match the most current financial report
        matched_basic = next(
            (basic for basic in dated_reports if basic["报告期"] < newsweek_enddate),
            {},
        )
        matched_basic_fin.append(json.dumps(matched_basic, ensure_ascii=False))
    data['基本面'] = matched_basic_fin
    return data
# ------------------------------------------------------------------------------ | |
# Structure Data | |
# ------------------------------------------------------------------------------ | |
def cur_financial_data(symbol, start_date, end_date, with_basics = True):
    """Assemble weekly returns, news and (optionally) fundamentals for ``symbol``.

    Args:
        symbol: Stock code.
        start_date: Start of the window, ``YYYYMMDD``.
        end_date: End of the window, ``YYYYMMDD``.
        with_basics: When true, quarterly fundamentals are attached through
            ``get_basic``; otherwise the ``基本面`` column holds ``"{}"``.

    Returns:
        The weekly DataFrame from ``get_cur_return`` extended with a ``新闻``
        column (JSON list of that week's news) and a ``基本面`` column (JSON
        object).
    """
    # get data
    data = get_cur_return(symbol=symbol, start_date=start_date, end_date=end_date)
    news_df = get_news(symbol=symbol)
    # Only the date part is parsed, so every timestamp sits at midnight.
    news_df["发布时间"] = pd.to_datetime(news_df["发布时间"], exact=False, format="%Y-%m-%d")
    news_df.sort_values(by=["发布时间"], inplace=True)
    # match weekly news for return data
    news_list = []
    for _, row in data.iterrows():
        week_start_date = row['起始日期'].strftime('%Y-%m-%d')
        week_end_date = row['结算日期'].strftime('%Y-%m-%d')
        print(symbol, ': ', week_start_date, ' - ', week_end_date)
        # Half-open window (start, end]: one week's end is the next week's
        # start, so a strict bound on both sides would drop news dated on a
        # boundary from every week. The inclusive end also matches the `<=`
        # used by the news-quality check.
        in_week = (news_df["发布时间"] > week_start_date) & (news_df["发布时间"] <= week_end_date)
        weekly_news = [
            {
                "发布时间": n["发布时间"].strftime('%Y%m%d'),
                "新闻标题": n['新闻标题'],
                "新闻内容": n['新闻内容'],
            } for _, n in news_df.loc[in_week].iterrows()
        ]
        news_list.append(json.dumps(weekly_news, ensure_ascii=False))
    data["新闻"] = news_list
    if with_basics:
        data = get_basic(symbol=symbol, data=data)
    else:
        # Fill the fundamentals column with empty objects. The previous code
        # overwrote the news column here and left '基本面' missing.
        data['基本面'] = [json.dumps({})] * len(data)
    return data
# ------------------------------------------------------------------------------ | |
# Format Instruction
# ------------------------------------------------------------------------------ | |
def get_company_prompt_new(symbol):
    """Build the company-introduction paragraph for ``symbol``.

    The profile is fetched with ``ak.stock_individual_info_em`` and turned
    into a dict of item/value pairs, which fills the Chinese template below.

    Args:
        symbol: Stock code.

    Returns:
        Tuple ``(formatted_profile, stockname)`` where ``stockname`` is the
        profile's ``股票简称``.

    Raises:
        RuntimeError: When the profile request fails; the underlying error is
            attached as the cause.
        KeyError: When the profile lacks a field the template needs.
    """
    try:
        company_profile = dict(ak.stock_individual_info_em(symbol).values)
    except Exception as err:
        # The previous bare `except` only printed and then failed later with
        # an unrelated UnboundLocalError; fail here with the real cause.
        print("Company Info Request Time Out! Please wait and retry.")
        raise RuntimeError(f"Could not fetch company profile for {symbol!r}") from err
    company_profile["上市时间"] = pd.to_datetime(str(company_profile["上市时间"])).strftime("%Y年%m月%d日")
    template = "[公司介绍]:\n\n{股票简称}是一家在{行业}行业的领先实体,自{上市时间}成立并公开交易。截止今天,{股票简称}的总市值为{总市值}人民币,总股本数为{总股本},流通市值为{流通市值}人民币,流通股数为{流通股}。" \
        "\n\n{股票简称}主要在中国运营,以股票代码{股票代码}在交易所进行交易。"
    formatted_profile = template.format(**company_profile)
    stockname = company_profile['股票简称']
    return formatted_profile, stockname
def _news_content_ok(n, week_end_date):
    """Return True when a news item has usable content dated within the week."""
    content = str(n['新闻内容'])
    # An empty body is unusable; checking it first avoids indexing content[0].
    if not content or content[0].isdigit() or content == 'nan':
        return False
    return n['发布时间'][:8] <= week_end_date.replace('-', '')


def get_prompt_by_row_new(stock, row):
    """Turn one weekly row into the three text pieces used in the prompt.

    Args:
        stock: Name or code inserted into the price sentence.
        row: Mapping/Series with ``起始日期``, ``结算日期``, ``起始价``, ``结算价``,
            ``简化周收益``, ``新闻`` (JSON list) and optionally ``基本面`` (JSON
            object).

    Returns:
        Tuple ``(head, filtered_news, basics)``: the price-movement sentence,
        a list of formatted news snippets that passed the quality checks, and
        the fundamentals paragraph.
    """
    week_start_date = row['起始日期'] if isinstance(row['起始日期'], str) else row['起始日期'].strftime('%Y-%m-%d')
    week_end_date = row['结算日期'] if isinstance(row['结算日期'], str) else row['结算日期'].strftime('%Y-%m-%d')
    term = '上涨' if row['结算价'] > row['起始价'] else '下跌'
    chg = map_return_label(row['简化周收益'])
    head = "自{}至{},{}的股票价格由{:.2f}{}至{:.2f},涨跌幅为:{}。在此期间的公司新闻如下:\n\n".format(
        week_start_date, week_end_date, stock, row['起始价'], term, row['结算价'], chg)
    news = json.loads(row["新闻"])
    filtered_news = []
    for idx, n in enumerate(news):
        # Availability check for every item (the first item has no
        # predecessor, so this is its only check).
        if not _news_content_ok(n, week_end_date):
            continue
        # Later items are additionally compared with the item right before
        # them to drop near-duplicates.
        if idx > 0 and not check_news_quality(n, last_n = news[idx - 1], week_end_date= week_end_date, repeat_rate=0.5):
            continue
        filtered_news.append("[新闻标题]:{}\n[新闻内容]:{}\n".format(n['新闻标题'], n['新闻内容']))
    # Tolerate rows without a fundamentals column.
    basics = json.loads(row.get('基本面') or '{}')
    if basics:
        # NOTE(review): the 'period' filter never matches the Chinese key
        # '报告期', so the report period is listed again below. Kept as is
        # because it changes the prompt text - confirm before altering.
        basics = "如下所列为{}近期的一些金融基本面信息,记录时间为{}:\n\n[金融基本面]:\n\n".format(
            stock, basics['报告期']) + "\n".join(f"{k}: {v}" for k, v in basics.items() if k != 'period')
    else:
        basics = "[金融基本面]:\n\n 无金融基本面记录"
    return head, filtered_news, basics
def get_all_prompts_online(symbol, with_basics=True, max_news_perweek = 3, weeks_before = 2):
    """Build the model prompt for ``symbol`` from data fetched as of today.

    Args:
        symbol: Stock code.
        with_basics: Whether to include the fundamentals of the latest week.
        max_news_perweek: Upper bound on the randomly sampled news per week.
        weeks_before: Number of past weeks of history to include.

    Returns:
        Tuple ``(info, prompt)``: the assembled background text and the full
        instruction-wrapped prompt.
    """
    end_date = get_curday()
    start_date = n_weeks_before(end_date, weeks_before)
    company_prompt, stock = get_company_prompt_new(symbol)
    data = cur_financial_data(symbol=symbol, start_date=start_date, end_date=end_date, with_basics=with_basics)
    # One (head, news, basics) triple per weekly row, oldest first.
    weekly_sections = [get_prompt_by_row_new(symbol, row) for _, row in data.iterrows()]
    pieces = []
    for head, week_news, _ in weekly_sections:
        pieces.append("\n" + head)
        picked = sample_news(week_news, min(max_news_perweek, len(week_news)))
        pieces.append("\n".join(picked) if picked else "No relative news reported.")
    history = "".join(pieces)
    # The forecast window is the week following today.
    next_date = n_weeks_before(end_date, -1, format="%Y-%m-%d")
    end_date = pd.to_datetime(end_date).strftime("%Y-%m-%d")
    period = "{}至{}".format(end_date, next_date)
    basics = weekly_sections[-1][2] if with_basics else "[金融基本面]:\n\n 无金融基本面记录"
    info = company_prompt + '\n' + history + '\n' + basics
    # Expand the answer-format placeholder of the system prompt.
    new_system_prompt = SYSTEM_PROMPT.replace(':\n...', ':\n预测涨跌幅:...\n总结分析:...')
    request = (
        f"\n\n基于在{end_date}之前的所有信息,让我们首先分析{stock}的积极发展和潜在担忧。请简洁地陈述,分别提出2-4个最重要的因素。大部分所提及的因素应该从公司的相关新闻中推断出来。"
        f"接下来请预测{symbol}下周({period})的股票涨跌幅,并提供一个总结分析来支持你的预测。"
    )
    prompt = B_INST + B_SYS + new_system_prompt + E_SYS + info + request + E_INST
    return info, prompt
def ask(symbol, weeks_before):
    """Generate a forecast for ``symbol`` and return it with its background text.

    Args:
        symbol: Stock code entered by the user.
        weeks_before: Number of past weeks of history to put into the prompt.

    Returns:
        Tuple ``(info, answer)``: the background information shown to the user
        and the model output with everything up to the last ``[/INST]``
        marker removed.
    """
    # load inference data
    info, pt = get_all_prompts_online(symbol=symbol, weeks_before=weeks_before)
    encoded = tokenizer(pt, return_tensors='pt')
    inputs = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    print("Inputs loaded onto devices.")
    generated = model.generate(
        **inputs,
        use_cache=True,
        max_length = 4096,
        streamer=streamer
    )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    # Greedy match with DOTALL: drop the prompt echo up to the final [/INST].
    answer = re.sub(r'.*\[/INST\]\s*', '', decoded, flags=re.DOTALL)
    return info, answer
# Gradio UI: two inputs are passed positionally to `ask(symbol, weeks_before)`
# and its two return values fill the two output text boxes in order.
server = gr.Interface(
    ask,
    inputs=[
        # Stock code, pre-filled with 600519.
        gr.Textbox(
            label="Symbol",
            value="600519",
            info="Companys from SZ50 are recommended"
        ),
        # Number of past weeks of history, selectable from 1 to 3 in steps of 1.
        gr.Slider(
            minimum=1,
            maximum=3,
            value=2,
            step=1,
            label="weeks_before",
            info="Due to the token length constraint, you are recommended to input with 2"
        ),
    ],
    outputs=[
        # First return value of `ask`: the background information.
        gr.Textbox(
            label="Information"
        ),
        # Second return value of `ask`: the model's answer.
        gr.Textbox(
            label="Response"
        )
    ],
    title="FinGPT-Forecaster-Chinese",
    description="""This version allows the prediction based on the most current date. We will upgrade it to allow customized date soon."""
)
# Starts the web app as soon as this module is executed.
server.launch()