import streamlit as st
import yfinance as yf
import requests
import os
from dotenv import load_dotenv
from langchain.agents import Tool, AgentExecutor, LLMSingleActionAgent, AgentOutputParser
from langchain.schema import AgentAction, AgentFinish, HumanMessage
from langchain.prompts import BaseChatPromptTemplate
from langchain_huggingface import HuggingFacePipeline
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferWindowMemory
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from statsmodels.tsa.arima.model import ARIMA
import torch
import re
from typing import List, Union
from datetime import datetime
from lumibot.brokers import Alpaca
from lumibot.backtesting import YahooDataBacktesting
from lumibot.strategies.strategy import Strategy
from alpaca_trade_api import REST
from timedelta import Timedelta  # third-party "timedelta" package, not datetime.timedelta
from finbert_utils import estimate_sentiment  # local helper module for FinBERT sentiment
# Load environment variables from .env
load_dotenv()
NEWSAPI_KEY = os.getenv("NEWSAPI_KEY")
access_token = os.getenv("API_KEY")  # Hugging Face access token

# Check that the access token and API key are present
if not NEWSAPI_KEY or not access_token:
    raise ValueError("NEWSAPI_KEY or API_KEY not found in .env file.")
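# A minimal .env for this app might look like the following (the values below
# are placeholders, not real keys; the variable names match the os.getenv
# calls above):
#
#   NEWSAPI_KEY=your_newsapi_key_here
#   API_KEY=your_huggingface_access_token_here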
# Alpaca credentials (paper trading).
# NOTE: hard-coding keys like this is only tolerable for throwaway paper-trading
# credentials; in practice prefer loading them from the environment as well.
API_KEY = "PKWJW14IWRJMLJ4CSZ6V"
API_SECRET = "zJOGwUvhYBfYJQRz6jc309PLNfTQ4VcxuygFxxfh"
BASE_URL = "https://paper-api.alpaca.markets/v2"
ALPACA_CREDS = {
    "API_KEY": API_KEY,
    "API_SECRET": API_SECRET,
    "PAPER": True
}
# Initialize the model and tokenizer for the HuggingFace pipeline
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=access_token)
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    torch_dtype=torch.bfloat16,
    token=access_token
)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)
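# Quick sanity check for the text-generation pipeline (a sketch; the generated
# text will vary with the model and sampling settings):
#
#   print(pipe("The stock market is")[0]["generated_text"])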
# Define functions for fetching stock data, news, and moving averages
def validate_ticker(ticker):
    return ticker.strip().upper()

def fetch_stock_data(ticker):
    try:
        ticker = ticker.strip().upper()
        stock = yf.Ticker(ticker)
        hist = stock.history(period="1mo")
        if hist.empty:
            return {"error": f"No data found for ticker {ticker}"}
        return hist.tail(5).to_dict()
    except Exception as e:
        return {"error": str(e)}
def fetch_stock_news(ticker, NEWSAPI_KEY):
    api_url = f"https://newsapi.org/v2/everything?q={ticker}&apiKey={NEWSAPI_KEY}"
    response = requests.get(api_url)
    if response.status_code == 200:
        articles = response.json().get('articles', [])
        return [{"title": article['title'], "description": article['description']} for article in articles[:5]]
    else:
        return [{"error": "Unable to fetch news."}]
def calculate_moving_average(ticker, window=5):
    stock = yf.Ticker(ticker)
    hist = stock.history(period="1mo")
    hist[f"{window}-day MA"] = hist["Close"].rolling(window=window).mean()
    return hist[["Close", f"{window}-day MA"]].tail(5)
def analyze_sentiment(news_articles):
    sentiment_pipeline = pipeline("sentiment-analysis")
    results = [{"title": article["title"],
                "sentiment": sentiment_pipeline(article["description"] or article["title"])[0]}
               for article in news_articles]
    return results
def predict_stock_price(ticker, days=5):
    stock = yf.Ticker(ticker)
    hist = stock.history(period="6mo")
    if hist.empty:
        return {"error": f"No data found for ticker {ticker}"}
    # Named arima_model to avoid shadowing the module-level LLM `model`.
    arima_model = ARIMA(hist["Close"], order=(5, 1, 0))
    model_fit = arima_model.fit()
    forecast = model_fit.forecast(steps=days)
    return forecast.tolist()
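# Note: the ARIMA(5, 1, 0) order is a fixed assumption rather than the output
# of any order-selection procedure, so treat the forecast as a rough sketch.
# Example usage (numbers are made up):
#
#   predict_stock_price("AAPL", days=3)
#   # -> [230.1, 231.4, 230.8]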
def compare_stocks(ticker1, ticker2):
    data1 = fetch_stock_data(ticker1)
    data2 = fetch_stock_data(ticker2)
    if "error" in data1 or "error" in data2:
        return {"error": "Could not fetch stock data for comparison."}
    # fetch_stock_data returns a dict of {column: {timestamp: value}}, so take
    # the most recent close by insertion order instead of indexing with -1.
    comparison = {
        ticker1: {"recent_close": list(data1["Close"].values())[-1]},
        ticker2: {"recent_close": list(data2["Close"].values())[-1]},
    }
    return comparison
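# Example shape of the comparison result (hypothetical numbers):
#
#   compare_stocks("AAPL", "MSFT")
#   # -> {"AAPL": {"recent_close": 230.5}, "MSFT": {"recent_close": 430.2}}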
def execute_alpaca_trading():
    class MLTrader(Strategy):
        def initialize(self, symbol: str = "SPY", cash_at_risk: float = .5):
            self.symbol = symbol
            self.sleeptime = "24H"
            self.last_trade = None
            self.cash_at_risk = cash_at_risk
            self.api = REST(base_url=BASE_URL, key_id=API_KEY, secret_key=API_SECRET)

        def position_sizing(self):
            cash = self.get_cash()
            last_price = self.get_last_price(self.symbol)
            quantity = round(cash * self.cash_at_risk / last_price, 0)
            return cash, last_price, quantity

        def get_dates(self):
            today = self.get_datetime()
            three_days_prior = today - Timedelta(days=3)
            return today.strftime('%Y-%m-%d'), three_days_prior.strftime('%Y-%m-%d')

        def get_sentiment(self):
            today, three_days_prior = self.get_dates()
            news = self.api.get_news(symbol=self.symbol,
                                     start=three_days_prior,
                                     end=today)
            news = [ev.__dict__["_raw"]["headline"] for ev in news]
            probability, sentiment = estimate_sentiment(news)
            return probability, sentiment

        def on_trading_iteration(self):
            cash, last_price, quantity = self.position_sizing()
            probability, sentiment = self.get_sentiment()
            if cash > last_price:
                if sentiment == "positive" and probability > .999:
                    if self.last_trade == "sell":
                        self.sell_all()
                    order = self.create_order(
                        self.symbol,
                        quantity,
                        "buy",
                        type="bracket",
                        take_profit_price=last_price * 1.20,
                        stop_loss_price=last_price * .95
                    )
                    self.submit_order(order)
                    self.last_trade = "buy"
                elif sentiment == "negative" and probability > .999:
                    if self.last_trade == "buy":
                        self.sell_all()
                    order = self.create_order(
                        self.symbol,
                        quantity,
                        "sell",
                        type="bracket",
                        take_profit_price=last_price * .8,
                        stop_loss_price=last_price * 1.05
                    )
                    self.submit_order(order)
                    self.last_trade = "sell"

    start_date = datetime(2021, 1, 1)
    end_date = datetime(2024, 10, 1)
    broker = Alpaca(ALPACA_CREDS)
    strategy = MLTrader(name='mlstrat', broker=broker,
                        parameters={"symbol": "SPY",
                                    "cash_at_risk": .5})
    strategy.backtest(
        YahooDataBacktesting,
        start_date,
        end_date,
        parameters={"symbol": "SPY", "cash_at_risk": .5}
    )
    return "Alpaca trading strategy executed and backtested."
# Define LangChain tools
stock_data_tool = Tool(
    name="Stock Data Fetcher",
    func=fetch_stock_data,
    description="Fetch recent stock data for a valid stock ticker symbol (e.g., AAPL for Apple)."
)
stock_news_tool = Tool(
    name="Stock News Fetcher",
    func=lambda ticker: fetch_stock_news(ticker, NEWSAPI_KEY),
    description="Fetch recent news articles about a stock ticker."
)
moving_average_tool = Tool(
    name="Moving Average Calculator",
    func=calculate_moving_average,
    description="Calculate the moving average of a stock over a 5-day window."
)
sentiment_tool = Tool(
    name="News Sentiment Analyzer",
    func=lambda ticker: analyze_sentiment(fetch_stock_news(ticker, NEWSAPI_KEY)),
    description="Analyze the sentiment of recent news articles about a stock ticker."
)
stock_prediction_tool = Tool(
    name="Stock Price Predictor",
    func=predict_stock_price,
    description="Predict future stock prices for a given ticker based on historical data."
)
stock_comparator_tool = Tool(
    name="Stock Comparator",
    func=lambda tickers: compare_stocks(*tickers.split(',')),
    description="Compare the recent performance of two stocks given their tickers, e.g., 'AAPL,MSFT'."
)
alpaca_trading_tool = Tool(
    name="Alpaca Trading Executor",
    # execute_alpaca_trading takes no arguments, so discard the tool input.
    func=lambda _: execute_alpaca_trading(),
    description="Run a trading strategy using Alpaca API and backtest results."
)
tools = [
    stock_data_tool,
    stock_news_tool,
    moving_average_tool,
    sentiment_tool,
    stock_prediction_tool,
    stock_comparator_tool,
    alpaca_trading_tool
]
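# Each Tool can also be invoked directly, which is handy for debugging outside
# the agent loop, e.g.:
#
#   stock_data_tool.run("AAPL")
#   stock_comparator_tool.run("AAPL,MSFT")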
# Set up a prompt template with history
template_with_history = """You are SearchGPT, a professional search engine who provides informative answers to users. Answer the following questions as best you can. You have access to the following tools:

{tools}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
(this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin! Remember to give detailed, informative answers

Previous conversation history:
{history}

New question: {input}
{agent_scratchpad}"""
# Set up the prompt template
class CustomPromptTemplate(BaseChatPromptTemplate):
    template: str
    tools: List[Tool]

    def format_messages(self, **kwargs) -> List[HumanMessage]:
        intermediate_steps = kwargs.pop("intermediate_steps")
        thoughts = ""
        for action, observation in intermediate_steps:
            thoughts += action.log
            thoughts += f"\nObservation: {observation}\nThought: "
        kwargs["agent_scratchpad"] = thoughts
        kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
        kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
        formatted = self.template.format(**kwargs)
        return [HumanMessage(content=formatted)]

prompt_with_history = CustomPromptTemplate(
    template=template_with_history,
    tools=tools,
    input_variables=["input", "intermediate_steps", "history"]
)
# Custom output parser
class CustomOutputParser(AgentOutputParser):
    def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
        if "Final Answer:" in llm_output:
            return AgentFinish(
                return_values={"output": llm_output.split("Final Answer:")[-1].strip()},
                log=llm_output,
            )
        regex = r"Action: (.*?)[\n]*Action Input:[\s]*(.*)"
        match = re.search(regex, llm_output, re.DOTALL)
        if not match:
            raise ValueError(f"Could not parse LLM output: `{llm_output}`")
        action = match.group(1).strip()
        action_input = match.group(2)
        return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output)

output_parser = CustomOutputParser()
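# For reference, the parser expects model output shaped like the ReAct format
# in the prompt above, e.g.:
#
#   Thought: I need recent prices for Apple
#   Action: Stock Data Fetcher
#   Action Input: AAPL
#
# or, to finish:
#
#   Thought: I now know the final answer
#   Final Answer: AAPL closed at ...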
# Wrap the HuggingFace pipeline as a LangChain LLM
llm = HuggingFacePipeline(pipeline=pipe)

# LLM chain
llm_chain = LLMChain(llm=llm, prompt=prompt_with_history)
tool_names = [tool.name for tool in tools]
agent = LLMSingleActionAgent(
    llm_chain=llm_chain,
    output_parser=output_parser,
    stop=["\nObservation:"],
    allowed_tools=tool_names
)
memory = ConversationBufferWindowMemory(k=2)
agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, memory=memory)
# Streamlit app
st.title("Trading Helper Agent")

query = st.text_input("Enter your query:")

if st.button("Submit"):
    if query:
        st.write("Debug: User Query ->", query)
        with st.spinner("Processing..."):
            try:
                # Run the agent and get the response
                response = agent_executor.run(query)
                st.success("Response:")
                st.write(response)
            except Exception as e:
                st.error(f"An error occurred: {e}")
                # Log the full LLM output for debugging