import os
import json
import re
import warnings

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing_extensions import TypedDict, Annotated

from langchain_groq import ChatGroq
from langchain_core.messages import SystemMessage
from langchain_core.output_parsers import PydanticOutputParser
from langchain_community.tools import TavilySearchResults
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages, AnyMessage
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.checkpoint.memory import MemorySaver

warnings.filterwarnings("ignore")
load_dotenv()

app = FastAPI()
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")
os.environ["TAVILY_API_KEY"] = os.getenv("TAVILY_API_KEY")
os.environ["SERPER_API_KEY"] = os.getenv("SERPER_API_KEY")

llm = ChatGroq(model="qwen-qwq-32b", temperature=0.1)
memorysaver = MemorySaver()
class State(TypedDict):
    # add_messages appends each node's returned messages to the running list
    # instead of overwriting it, which is what gives the graph its memory.
    messages: Annotated[list[AnyMessage], add_messages]
class MovieDetails(BaseModel):
    title: str = Field(..., title="Movie Title", description="The title of the programme for which to fetch details.")
    genres: str = Field(..., title="Genres", description="The genres of the programme.")
    duration: str = Field(..., title="Duration", description="The duration of the programme.")
    synopsis: str = Field(..., title="Synopsis", description="The synopsis of the programme.")
    numberOfSeasons: str = Field(..., title="NumberOfSeasons", description="The number of seasons in the programme.")
    numberOfEpisodes: str = Field(..., title="NumberOfEpisodes", description="The number of episodes in the programme.")
    summary: str = Field(..., title="Summary", description="A summary of the programme that includes the number of episodes and seasons.")
    source: str = Field(..., title="Source", description="The source (URL) from which the information was fetched.")
parser = PydanticOutputParser(pydantic_object=MovieDetails)
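# get_format_instructions() renders the MovieDetails JSON schema as prompt text;
# llm_callingTools injects it into the system message so that show-specific
# answers come back as parseable JSON.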
def build_tools():
    # Tavily web search is the only tool exposed to the model.
    return [TavilySearchResults(max_results=2)]
def get_llm():
    # Bind the search tools so the model can emit tool calls.
    return llm.bind_tools(build_tools())
def llm_callingTools(state: State):
    format_instructions = parser.get_format_instructions()
    system_msg = SystemMessage(content=f"""You are a smart movie researcher.
1. Your job is to retrieve **only real, verifiable details** from trusted sources.
2. **Never assume or generate** fake names, genres, synopses, or details about episodes or seasons.
3. Always provide a **brief summary** in **plain text** under the title.
### Format Instructions:
If the user **is asking about a specific show or programme** (for example, referencing the title, episodes, seasons, cast, language, or summary), format your response like this:
{format_instructions}
Otherwise, if the latest message does **not** refer to any specific programme or show (e.g. general queries), respond in **plain text** only, without JSON formatting.
Think carefully before responding: **is the latest message referring to a specific show or programme, even indirectly?** Only then use the formatted output.""")
    llm_with_tools = get_llm()
    # Pass the conversation history through as messages rather than stringifying
    # it into one HumanMessage, so tool calls and tool results keep their structure.
    return {"messages": [llm_with_tools.invoke([system_msg] + state["messages"])]}
def build_graph(clearMemory: bool = False):
    global memorysaver
    if clearMemory:
        # Swap in a fresh checkpointer to discard all saved conversation threads.
        memorysaver = MemorySaver()
    graph_builder = StateGraph(State)
    graph_builder.add_node("llm_with_tool", llm_callingTools)
    graph_builder.add_node("tools", ToolNode(build_tools()))
    graph_builder.add_edge(START, "llm_with_tool")
    # tools_condition routes to "tools" when the model emitted tool calls,
    # otherwise to END.
    graph_builder.add_conditional_edges("llm_with_tool", tools_condition)
    graph_builder.add_edge("tools", "llm_with_tool")
    graph = graph_builder.compile(checkpointer=memorysaver)
    return graph
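# Illustrative use of the compiled graph ("demo" is an arbitrary thread id;
# get_data_by_prompt below does the same thing with error handling):
#   graph = build_graph()
#   result = graph.invoke(
#       {"messages": [{"role": "human", "content": "How many seasons of CSI are there?"}]},
#       config={"configurable": {"thread_id": "demo"}},
#   )
#   print(result["messages"][-1].content)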
def is_pattern_in_string(string: str) -> bool:
    # Match "episode(s)" or "season(s)" as whole words, case-insensitively.
    pattern = r'\bepisodes?\b|\bseasons?\b'
    return re.search(pattern, string, re.IGNORECASE) is not None
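# Illustrative behaviour of the matcher:
#   is_pattern_in_string("How many episodes are in season 2?")  -> True
#   is_pattern_in_string("Who directed Inception?")             -> False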
def get_data_by_prompt(prompt: str, thread_id: str):
    clearMemory = False
    try:
        print(f"Prompt: {prompt}")
        if not is_pattern_in_string(prompt):
            # Queries that don't mention episodes/seasons start a fresh thread.
            print("Query does not reference episodes or seasons. Starting fresh.")
            clearMemory = True
        graph = build_graph(clearMemory)
        config = {"configurable": {"thread_id": thread_id}}
        message_prompt = {"messages": [{"role": "human", "content": prompt}]}
        data = graph.invoke(message_prompt, config=config)
        final_output = data["messages"][-1].content
        if is_pattern_in_string(prompt):
            try:
                # The model may answer with the MovieDetails JSON; if so, return
                # only its "summary" field.
                final_output_new = json.loads(final_output)
                if isinstance(final_output_new, dict) and "summary" in final_output_new:
                    return Response(content=final_output_new["summary"], media_type="text/markdown")
                return Response(content=final_output, media_type="text/markdown")
            except json.JSONDecodeError:
                # Plain-text answer; pass it through unchanged.
                return Response(content=final_output, media_type="text/markdown")
        return Response(content=final_output, media_type="text/markdown")
    except Exception as e:
        return Response(content=str(e), media_type="text/markdown")
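# No HTTP route is defined for get_data_by_prompt in this file; the endpoint
# below is a hypothetical wiring (the path and query-parameter names are
# assumptions, not part of the original app).
@app.get("/chat")
def chat(prompt: str, thread_id: str = "1"):
    return get_data_by_prompt(prompt, thread_id)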
if __name__ == "__main__":
    # get_data_by_prompt("CSI", "1")
    uvicorn.run(app, host="127.0.0.1", port=8000)
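# Example request (against the hypothetical /chat route above; adjust the path
# if the function is wired up differently):
#   curl "http://127.0.0.1:8000/chat?prompt=How%20many%20seasons%20of%20CSI%20are%20there&thread_id=1"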