Spaces:
Sleeping
Sleeping
| import langgraph | |
| from langgraph.graph import StateGraph, START, END | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from langchain_groq import ChatGroq | |
| from typing_extensions import TypedDict, Annotated | |
| from pydantic import BaseModel, Field | |
| from langchain_community.utilities import GoogleSerperAPIWrapper, WikipediaAPIWrapper | |
| from langchain.tools import GoogleSerperRun, WikipediaQueryRun, DuckDuckGoSearchRun | |
| from langchain_core.messages import SystemMessage, HumanMessage | |
| from langchain_community.tools import TavilySearchResults | |
| from langgraph.graph.message import add_messages | |
| from dotenv import load_dotenv | |
| from langchain_core.prompts import ChatPromptTemplate, PromptTemplate | |
| from langchain_core.output_parsers import PydanticOutputParser | |
| from langgraph.graph.message import AnyMessage | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from pydantic import BaseModel,Field | |
| from fastapi import FastAPI, Response | |
| import uvicorn | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from datetime import timedelta | |
| import os | |
| import warnings | |
| import json | |
| import re | |
# Suppress every Python warning, then load variables from a local .env file.
warnings.filterwarnings("ignore")
load_dotenv()

app = FastAPI()

# NOTE(review): a wildcard origin together with allow_credentials=True is a
# very permissive CORS policy; confirm it is intended before deploying.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Assigning None into os.environ raises TypeError, so a missing key used to
# crash the import with an unhelpful message. Report which key is missing
# instead; the clients that need a key are expected to raise their own error.
for _key in ("GROQ_API_KEY", "TAVILY_API_KEY", "SERPER_API_KEY"):
    _value = os.getenv(_key)
    if _value is None:
        print(f"Warning: environment variable {_key} is not set.")
    else:
        os.environ[_key] = _value

# Shared, module-level state used by the graph-building helpers below.
llm = ChatGroq(model="qwen-qwq-32b", temperature=0.1)
memorysaver = MemorySaver()
clearMemory = False
class State(TypedDict):
    """State dictionary passed between the nodes of the graph."""

    # Conversation messages; ``add_messages`` is attached as the reducer
    # through the ``Annotated`` metadata.
    messages: Annotated[list[AnyMessage], add_messages]
class MovieDetails(BaseModel):
    """Structured description of a programme; every field is a required string."""

    title: str = Field(
        ...,
        title="Movie Title",
        description="The title of the programme for which you want to fetch details.",
    )
    genres: str = Field(
        ...,
        title="Genres",
        description="The genres of the programmee.",
    )
    duration: str = Field(
        ...,
        title="Duration",
        description="The duration of the programme is formatted in minutes ONLY. Donot include any range",
    )
    synopsis: str = Field(
        ...,
        title="Synopsis",
        description="The synopsis of the programme.",
    )
    numberofSeasons: str = Field(
        ...,
        title="NumberOfSeasons",
        description="The number of seasons in the programme.",
    )
    numberOfEpisodes: str = Field(
        ...,
        title="NumberOfEpisodes",
        description="The number of episodes in the programme.",
    )
    summary: str = Field(
        ...,
        title="Summary",
        description="The summary of the programme which contains number of episodes and seasons.",
    )
    source: str = Field(
        ...,
        title="Source",
        description="The source(url) from where the information is fetched.",
    )


# Parser whose format instructions are embedded in the system prompt.
parser = PydanticOutputParser(pydantic_object=MovieDetails)
def build_tools():
    """Build the search tools offered to the model.

    Returns:
        A one-element list holding a ``TavilySearchResults`` tool configured
        with ``max_results=2``.
    """
    # A GoogleSerperAPIWrapper / GoogleSerperRun pair used to be constructed
    # here on every call but was never added to the returned list; that dead
    # work has been removed. The returned tools are unchanged.
    return [TavilySearchResults(max_results=2)]
def get_llm():
    """Return the shared chat model with the search tools bound to it.

    When the module-level ``clearMemory`` flag is true, the shared ``llm`` is
    first replaced by a new ``ChatGroq`` client for ``openai/gpt-oss-120b``.

    Returns:
        The result of ``llm.bind_tools(build_tools())``.
    """
    global llm
    # NOTE(review): this rebinding is global, so the replacement model stays in
    # use for later calls even when ``clearMemory`` is false again — confirm
    # that the permanent switch is intended.
    if clearMemory:
        llm = ChatGroq(model="openai/gpt-oss-120b", temperature=0.1)
    # An unused ``llm.with_structured_output(MovieDetails)`` result was dropped.
    return llm.bind_tools(build_tools())
def llm_callingTools(state: State):
    """Graph node: ask the tool-enabled model for the next message.

    Args:
        state: Graph state; ``state["messages"]`` holds the conversation so far.

    Returns:
        A state update ``{"messages": [reply]}`` containing the model's reply.
    """
    format_instructions = parser.get_format_instructions()
    system_msg = SystemMessage(content=f"""You are a smart movie researcher.
1. Your job is to retrieve **only real, verifiable details** from trusted sources.
2. **Never assume or generate** fake names, genres, or synopses, details about episodes or seasons.
3. Always provide a **brief summary** in **plain text** under the title.
4. Do not use markdown formatting like ```json or quotes around the entire JSON block in your answer.
5. Always provide meaningful information. Donot leave it blank
### Format Instructions:
If the user **is asking about a specific show or programme** — for example, referencing the title, episodes, seasons, cast, language, or summary — format your response like this:
{format_instructions}
Otherwise, if the latest message does **not** refer to any specific programme or show (e.g. general queries), respond in **plain text** only without JSON formatting.
Think carefully before responding: **Is the latest message is referring to a specific show or programme, even indirectly?** Only then use the formatted output.""")
    llm_with_tools = get_llm()
    # Pass the real message objects. Previously the whole list was formatted
    # with an f-string into a single human message, so the model received the
    # Python repr of the history (metadata included) instead of distinct
    # human / assistant / tool turns.
    reply = llm_with_tools.invoke([system_msg, *state["messages"]])
    return {"messages": [reply]}
def build_graph(memory: bool = False):
    """Assemble and compile the model/tools graph.

    Args:
        memory: Stored into the module-level ``clearMemory`` flag; when true,
            the module-level ``memorysaver`` is replaced by a new
            ``MemorySaver`` before the graph is compiled.

    Returns:
        The compiled graph, using ``memorysaver`` as its checkpointer.
    """
    global memorysaver
    global clearMemory

    clearMemory = memory
    if memory:
        memorysaver = MemorySaver()

    builder = StateGraph(State)
    builder.add_node("llm_with_tool", llm_callingTools)
    builder.add_node("tools", ToolNode(build_tools()))
    # START -> model; after the model, ``tools_condition`` picks the next
    # step; the tools node always hands control back to the model.
    builder.add_edge(START, "llm_with_tool")
    builder.add_conditional_edges("llm_with_tool", tools_condition)
    builder.add_edge("tools", "llm_with_tool")
    return builder.compile(checkpointer=memorysaver)
# Whole-word "episode(s)" / "season(s)", in any letter case. The earlier
# pattern ``episode?s?`` / ``season?s?`` made the final letter optional (so it
# matched "episod" and "seaso") and missed capitalised words such as "Season".
_EPISODE_OR_SEASON = re.compile(r"\b(?:episode|season)s?\b", re.IGNORECASE)


def is_pattern_in_string(string: str) -> bool:
    """Return True when *string* mentions an episode or a season.

    Args:
        string: Text to inspect.

    Returns:
        ``True`` if "episode", "episodes", "season" or "seasons" occurs as a
        whole word (case-insensitively), otherwise ``False``.
    """
    return _EPISODE_OR_SEASON.search(string) is not None
def formatDuration(duration: str) -> str:
    """Convert a duration given in minutes to an ``HH:MM:SS`` string.

    Args:
        duration: A whole number of minutes, optionally accompanied by a unit
            word ("min", "mins", "minute" or "minutes", in any letter case).

    Returns:
        The duration as ``HH:MM:SS`` (hours are not capped at 24), or ``"N/A"``
        when the input is not a string, is not a whole number of minutes, or
        is negative.
    """
    try:
        # Drop the unit word, whatever its case or plural form.
        cleaned = re.sub(r"minutes?|mins?", "", duration, flags=re.IGNORECASE)
        total_minutes = int(cleaned.strip())
    except (AttributeError, TypeError, ValueError):
        return "N/A"
    # A negative duration is meaningless and used to render as e.g. "-1:55:00".
    if total_minutes < 0:
        return "N/A"
    hours, minutes = divmod(total_minutes, 60)
    # The input has minute resolution, so the seconds part is always zero.
    return f"{hours:02}:{minutes:02}:00"
def _summary_or_raw(raw_output):
    """Return the ``summary`` string of a JSON-object reply, else *raw_output*."""
    try:
        parsed = json.loads(raw_output)
    except json.JSONDecodeError:
        return raw_output
    if isinstance(parsed, dict):
        summary = parsed.get("summary")
        # A missing key used to raise KeyError, which the outer handler then
        # returned to the client as the response body; fall back to the raw
        # reply instead.
        if isinstance(summary, str):
            return summary
    return raw_output


def get_data_by_prompt(prompt: str, thread_id: str):
    """Run the graph for *prompt* and wrap the final reply in a ``Response``.

    A prompt that mentions episodes or seasons (see ``is_pattern_in_string``)
    is treated as a follow-up: existing memory is kept and, when the reply is
    a JSON object with a string ``summary``, only that summary is returned.
    Any other prompt sets ``clearMemory`` so the graph is rebuilt with fresh
    memory, and the reply is returned unchanged.

    Args:
        prompt: The user's message.
        thread_id: Conversation identifier passed to the graph's config.

    Returns:
        A ``text/markdown`` ``Response`` with the reply, or with the exception
        text if anything fails.
    """
    global clearMemory
    clearMemory = False
    try:
        print(f"Prompt: {prompt}")
        is_follow_up = is_pattern_in_string(prompt)
        if not is_follow_up:
            print("No previous conversation found. Starting fresh.")
            clearMemory = True
        graph = build_graph(clearMemory)
        config = {"configurable": {"thread_id": thread_id}}
        message_prompt = {"messages": [{"role": "human", "content": prompt}]}
        data = graph.invoke(message_prompt, config=config)
        final_output = data["messages"][-1].content
        if is_follow_up:
            final_output = _summary_or_raw(final_output)
        return Response(content=final_output, media_type="text/markdown")
    except Exception as e:
        # Top-level boundary: report the failure text to the caller.
        # NOTE(review): this is sent with the default success status code;
        # consider an error status once clients can handle it.
        return Response(content=str(e), media_type="text/markdown")
if __name__ == "__main__":
    # For a quick manual check, call get_data_by_prompt("CSI", "1") here.
    # Serve the FastAPI app on the local interface only.
    uvicorn.run(app, host="127.0.0.1", port=8000)