"""FastAPI service that answers questions about TV programmes via a LangGraph agent.

A Groq-hosted chat model is given a web-search tool (Tavily) and asked to reply
either with JSON matching ``MovieDetails`` (for questions about a specific
programme) or with plain text (for general questions).  Conversation history is
kept per ``thread_id`` in an in-memory LangGraph checkpointer.
"""
import json
import os
import re
import traceback
import warnings
from typing import Optional

import langgraph
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware
from langchain.tools import DuckDuckGoSearchRun, GoogleSerperRun, WikipediaQueryRun
from langchain_community.tools import TavilySearchResults
from langchain_community.utilities import GoogleSerperAPIWrapper, WikipediaAPIWrapper
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_groq import ChatGroq
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import AnyMessage, add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from pydantic import BaseModel, Field
from typing_extensions import Annotated, TypedDict

warnings.filterwarnings("ignore")
load_dotenv()

app = FastAPI()

# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; consider restricting `origins` before exposing this publicly.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# load_dotenv() has already copied any .env entries into os.environ, so the keys
# only need to be checked here.  Do not write os.getenv(...) back into
# os.environ: os.environ only accepts str values, so a missing key would raise
# TypeError and abort the import.
for _required_key in ("GROQ_API_KEY", "TAVILY_API_KEY"):
    if not os.getenv(_required_key):
        print(f"Warning: {_required_key} is not set; requests will likely fail.")

llm = ChatGroq(model="qwen-qwq-32b", temperature=0.1)

# Fallback checkpointer used by build_graph() when no thread_id is supplied.
memorysaver = MemorySaver()
# Mirrors the `memory` flag passed to the most recent build_graph() call.
clearMemory = False
# One checkpointer per conversation thread, so that starting a fresh
# conversation on one thread does not discard every other thread's history.
# NOTE(review): entries are never evicted; bound this if thread ids are unbounded.
_thread_memorysavers: dict[str, MemorySaver] = {}

# Words that mark a prompt as a follow-up about an already-discussed programme.
# The optional letters keep the original tolerance ("episod", "seaso", ...);
# IGNORECASE makes "Season 2" / "Episodes" count as follow-ups too.
_FOLLOW_UP_PATTERN = re.compile(r'\bepisode?s?\b|\bseason?s?\b', re.IGNORECASE)
# Reasoning models may prefix the reply with a <think>...</think> block.
_THINK_BLOCK = re.compile(r"<think>.*?</think>", re.DOTALL)
# Models often wrap JSON in a ```json ... ``` markdown fence.
_CODE_FENCE = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL)


class State(TypedDict):
    """LangGraph state: the running message list, merged via ``add_messages``."""

    messages: Annotated[list[AnyMessage], add_messages]


class MovieDetails(BaseModel):
    """Schema the model is asked to follow when describing a specific programme."""

    title: str = Field(..., title="Movie Title", description="The title of the programme for which you want to fetch details.")
    genres: str = Field(..., title="Genres", description="The genres of the programmee.")
    duration: str = Field(..., title="Duration", description="The duration of the programme.")
    synopsis: str = Field(..., title="Synopsis", description="The synopsis of the programme.")
    numberofSeasons: str = Field(..., title="NumberOfSeasons", description="The number of seasons in the programme.")
    numberOfEpisodes: str = Field(..., title="NumberOfEpisodes", description="The number of episodes in the programme.")
    summary: str = Field(..., title="Summary", description="The summary of the programme which contains number of episodes and seasons.")
    source: str = Field(..., title="Source", description="The source(url) from where the information is fetched.")


# Only used to render JSON format instructions into the system prompt.
parser = PydanticOutputParser(pydantic_object=MovieDetails)


def build_tools():
    """Return the list of tools exposed to the model.

    Only Tavily search is wired in.  The Serper / Wikipedia / DuckDuckGo
    utilities are imported but deliberately not constructed: building the
    Serper wrapper requires SERPER_API_KEY even when the tool goes unused.
    """
    return [TavilySearchResults(max_results=2)]


def get_llm():
    """Return the shared chat model with the search tools bound to it.

    The ChatGroq client keeps no conversation state (history lives in the
    graph checkpointer), so the same client is reused for every request.
    """
    return llm.bind_tools(build_tools())


def llm_callingTools(state: State):
    """Graph node: run the tool-enabled model over the conversation so far.

    The system prompt is prepended to the real message list rather than
    stringifying the list into a single human message.  The model therefore
    sees its own earlier tool calls and the matching tool results as
    structured messages.
    """
    format_instructions = parser.get_format_instructions()
    system_msg = SystemMessage(content=f"""You are a smart movie researcher.
1. Your job is to retrieve **only real, verifiable details** from trusted sources.
2. **Never assume or generate** fake names, genres, or synopses, details about episodes or seasons.
3. Always provide a **brief summary** in **plain text** under the title.

### Format Instructions:
If the user **is asking about a specific show or programme** — for example, referencing the title, episodes, seasons, cast, language, or summary — format your response like this:
{format_instructions}

Otherwise, if the latest message does **not** refer to any specific programme or show (e.g. general queries), respond in **plain text** only without JSON formatting.
Think carefully before responding: **Is the latest message is referring to a specific show or programme, even indirectly?** Only then use the formatted output.""")
    llm_with_tools = get_llm()
    response = llm_with_tools.invoke([system_msg, *state["messages"]])
    return {"messages": [response]}


def build_graph(memory: bool = False, thread_id: Optional[str] = None):
    """Compile the agent graph: model node <-> tool node until no tool call remains.

    Args:
        memory: when True, start the conversation from an empty history.
        thread_id: conversation whose checkpointer should be used.  When given,
            only that thread's history is reset by ``memory=True``.  When
            omitted, the legacy single shared ``memorysaver`` is used, and
            resetting it clears every thread.

    Returns:
        The compiled LangGraph graph, wired to the selected checkpointer.
    """
    global memorysaver
    global clearMemory
    clearMemory = memory

    if thread_id is None:
        if clearMemory:
            memorysaver = MemorySaver()
        checkpointer = memorysaver
    elif clearMemory:
        checkpointer = MemorySaver()
        _thread_memorysavers[thread_id] = checkpointer
    else:
        checkpointer = _thread_memorysavers.setdefault(thread_id, MemorySaver())

    graph_builder = StateGraph(State)
    graph_builder.add_node("llm_with_tool", llm_callingTools)
    graph_builder.add_node("tools", ToolNode(build_tools()))
    graph_builder.add_edge(START, "llm_with_tool")
    # tools_condition routes to "tools" when the last AI message requests a
    # tool call, otherwise to END.
    graph_builder.add_conditional_edges("llm_with_tool", tools_condition)
    graph_builder.add_edge("tools", "llm_with_tool")
    return graph_builder.compile(checkpointer=checkpointer)


def is_pattern_in_string(string: str) -> bool:
    """Return True if ``string`` mentions episodes or seasons (case-insensitive)."""
    return _FOLLOW_UP_PATTERN.search(string) is not None


def _extract_json_object(text):
    """Return the JSON object contained in a model reply, or None if there is none.

    Strips any <think>...</think> block and unwraps a markdown code fence
    before parsing.  Non-string input and JSON that is not an object yield None.
    """
    if not isinstance(text, str):
        return None
    cleaned = _THINK_BLOCK.sub("", text).strip()
    fenced = _CODE_FENCE.search(cleaned)
    if fenced:
        cleaned = fenced.group(1)
    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


@app.post("/api/v1/get_programme_info")
def get_data_by_prompt(prompt: str, thread_id: str):
    """Answer ``prompt`` within conversation ``thread_id`` and return markdown.

    Prompts that mention episodes/seasons are treated as follow-ups and keep
    the thread's history.  For those, the ``summary`` field is returned when
    the reply is a JSON object that contains one.  Any other prompt starts the
    thread afresh, and the raw model reply is returned.  Errors are reported
    as the exception text in the response body.
    """
    try:
        print(f"Prompt: {prompt}")
        is_follow_up = is_pattern_in_string(prompt)
        if not is_follow_up:
            print("No previous conversation found. Starting fresh.")

        graph = build_graph(not is_follow_up, thread_id=thread_id)
        config = {"configurable": {"thread_id": thread_id}}
        message_prompt = {"messages": [{"role": "human", "content": prompt}]}
        data = graph.invoke(message_prompt, config=config)
        final_output = data["messages"][-1].content

        if is_follow_up:
            details = _extract_json_object(final_output)
            # A JSON object without "summary" previously raised KeyError and
            # returned "'summary'" to the client; fall back to the raw reply.
            if details is not None and "summary" in details:
                return Response(content=str(details["summary"]), media_type="text/markdown")
        return Response(content=final_output, media_type="text/markdown")
    except Exception as e:
        # Top-level boundary: keep the server alive, but log the full traceback.
        traceback.print_exc()
        return Response(content=str(e), media_type="text/markdown")


if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)