SoumyaJ commited on
Commit
83092ff
·
verified ·
1 Parent(s): edd79b4

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +148 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import langgraph
2
+ from langgraph.graph import StateGraph, START, END
3
+ from langgraph.prebuilt import ToolNode, tools_condition
4
+ from langchain_groq import ChatGroq
5
+ from typing_extensions import TypedDict, Annotated
6
+ from pydantic import BaseModel, Field
7
+ from langchain_community.utilities import GoogleSerperAPIWrapper, WikipediaAPIWrapper
8
+ from langchain.tools import GoogleSerperRun, WikipediaQueryRun, DuckDuckGoSearchRun
9
+ from langchain_core.messages import SystemMessage, HumanMessage
10
+ from langchain_community.tools import TavilySearchResults
11
+ from langgraph.graph.message import add_messages
12
+ from dotenv import load_dotenv
13
+ from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
14
+ from langchain_core.output_parsers import PydanticOutputParser
15
+ from langgraph.graph.message import AnyMessage
16
+ from langgraph.checkpoint.memory import MemorySaver
17
+ from pydantic import BaseModel,Field
18
+ from fastapi import FastAPI, Response
19
+ import uvicorn
20
+ from fastapi.middleware.cors import CORSMiddleware
21
+ import os
22
+ import warnings
23
+ import json
24
+ import re
25
+
26
warnings.filterwarnings("ignore")

load_dotenv()

app = FastAPI()
origins = ["*"]

# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers per the CORS spec — confirm whether credentials are
# actually needed, or pin the origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Fail fast with a clear message when a key is absent. The original
# `os.environ[k] = os.getenv(k)` pattern is a no-op when the variable exists
# and raises an opaque TypeError (None value) when it does not.
for _key in ("GROQ_API_KEY", "TAVILY_API_KEY", "SERPER_API_KEY"):
    _value = os.getenv(_key)
    if _value is None:
        raise RuntimeError(f"Missing required environment variable: {_key}")
    os.environ[_key] = _value

# Low temperature keeps the "researcher" answers close to retrieved facts.
llm = ChatGroq(model="qwen-qwq-32b", temperature=0.1)

# Conversation checkpointer shared across requests; replaced by build_graph()
# when a caller asks for memory to be cleared.
memorysaver = MemorySaver()
48
+
49
class State(TypedDict):
    # Graph state: `add_messages` is a reducer that appends new messages to
    # the running conversation history instead of overwriting it.
    messages: Annotated[list[AnyMessage], add_messages]
51
+
52
class MovieDetails(BaseModel):
    """Structured details for a movie/TV programme.

    Serves two roles: the schema handed to the parser whose format
    instructions are injected into the system prompt, and the expected shape
    of the model's JSON answer for programme-specific queries.
    Field descriptions are model-facing text; the "programmee" typos in the
    originals have been corrected.
    """

    title: str = Field(..., title="Movie Title", description="The title of the programme for which you want to fetch details.")
    genres: str = Field(..., title="Genres", description="The genres of the programme.")
    duration: str = Field(..., title="Duration", description="The duration of the programme.")
    synopsis: str = Field(..., title="Synopsis", description="The synopsis of the programme.")
    # NOTE(review): casing is inconsistent with numberOfEpisodes, but the field
    # name is part of the JSON contract the LLM is instructed to emit — left as is.
    numberofSeasons: str = Field(..., title="NumberOfSeasons", description="The number of seasons in the programme.")
    numberOfEpisodes: str = Field(..., title="NumberOfEpisodes", description="The number of episodes in the programme.")
    summary: str = Field(..., title="Summary", description="The summary of the programme which contains number of episodes and seasons.")
    source: str = Field(..., title="Source", description="The source(url) from where the information is fetched.")
61

# Renders the MovieDetails schema as format instructions embedded in the
# system prompt, and can parse the model's JSON back into MovieDetails.
parser = PydanticOutputParser(pydantic_object=MovieDetails)
63
+
64
def build_tools():
    """Return the list of search tools exposed to the LLM.

    Only Tavily (capped at two results) is wired in. The Serper wrapper the
    original code constructed was never added to the returned list, so that
    dead code has been removed.
    """
    return [TavilySearchResults(max_results=2)]
70
+
71
def get_llm():
    """Return the shared Groq LLM with the search tools bound to it.

    The structured-output variant (`with_structured_output(MovieDetails)`)
    built by the original code was discarded unused — structured formatting is
    instead requested via the prompt's format instructions — so it is dropped.
    """
    return llm.bind_tools(build_tools())
76
+
77
def llm_callingTools(state: State):
    """Graph node: ask the tool-bound LLM to answer the conversation so far.

    Builds a system prompt instructing the model to answer in the
    MovieDetails JSON shape only when the user asks about a specific
    programme, then invokes the tool-bound LLM. Returns the new message in
    the `messages` channel (appended via the state's add_messages reducer).
    The final prompt sentence's grammar ("Is ... is referring") is fixed.
    """
    format_instructions = parser.get_format_instructions()

    system_msg = SystemMessage(content=f"""You are a smart movie researcher.
1. Your job is to retrieve **only real, verifiable details** from trusted sources.
2. **Never assume or generate** fake names, genres, or synopses, details about episodes or seasons.
3. Always provide a **brief summary** in **plain text** under the title.

### Format Instructions:
If the user **is asking about a specific show or programme** — for example, referencing the title, episodes, seasons, cast, language, or summary — format your response like this:
{format_instructions}

Otherwise, if the latest message does **not** refer to any specific programme or show (e.g. general queries), respond in **plain text** only without JSON formatting.

Think carefully before responding: **Is the latest message referring to a specific show or programme, even indirectly?** Only then use the formatted output.""")

    # NOTE(review): this stringifies the entire message history into a single
    # human message rather than passing the message objects through — confirm
    # this is intentional before changing it.
    human_message = HumanMessage(content=f"{state['messages']}.")
    llm_with_tools = get_llm()
    return {"messages": [llm_with_tools.invoke([system_msg] + [human_message])]}
97
+
98
def build_graph(clearMemory: bool = False):
    """Compile the agent graph: an LLM node and a tool node with conditional routing.

    When `clearMemory` is True the module-level checkpointer is replaced with a
    fresh MemorySaver, discarding every stored conversation thread.
    """
    global memorysaver
    if clearMemory:
        memorysaver = MemorySaver()

    builder = StateGraph(State)
    builder.add_node("llm_with_tool", llm_callingTools)
    builder.add_node("tools", ToolNode(build_tools()))

    # START -> LLM; LLM routes to tools when it emitted tool calls, else END;
    # tool results loop back into the LLM node.
    builder.add_edge(START, "llm_with_tool")
    builder.add_conditional_edges("llm_with_tool", tools_condition)
    builder.add_edge("tools", "llm_with_tool")

    return builder.compile(checkpointer=memorysaver)
111
+
112
# Compiled once at import time; matches "episode(s)" / "season(s)" as whole
# words, case-insensitively.
_EPISODE_SEASON_RE = re.compile(r'\bepisodes?\b|\bseasons?\b', re.IGNORECASE)


def is_pattern_in_string(string: str) -> bool:
    """Return True if *string* mentions episodes or seasons.

    Fixes the original pattern `episode?s?` / `season?s?`, which misplaced the
    optional quantifier and so also matched the non-words "episod" and
    "seaso"; IGNORECASE additionally recognises "Episodes", "Season", etc.
    """
    return _EPISODE_SEASON_RE.search(string) is not None
115
+
116
@app.post("/api/v1/get_programme_info")
def get_data_by_prompt(prompt: str, thread_id: str):
    """Answer a programme query and return the result as markdown text.

    A prompt that does not mention episodes/seasons is treated as a fresh
    query and clears the conversation memory. When it does mention them, the
    model's JSON answer is reduced to its "summary" field where possible.
    """
    try:
        print(f"Prompt: {prompt}")
        clear_memory = not is_pattern_in_string(prompt)
        if clear_memory:
            print("No previous conversation found. Starting fresh.")
        graph = build_graph(clear_memory)
        config = {"configurable": {"thread_id": thread_id}}

        message_prompt = {"messages": [{"role": "human", "content": prompt}]}
        data = graph.invoke(message_prompt, config=config)
        final_output = data["messages"][-1].content

        if not clear_memory:  # prompt mentioned episodes/seasons
            try:
                parsed = json.loads(final_output)
            except json.JSONDecodeError:
                # Model answered in plain text; pass it through unchanged.
                return Response(content=final_output, media_type="text/markdown")
            if isinstance(parsed, dict):
                # .get guards against a missing "summary" key: previously that
                # KeyError escaped to the outer handler and the client received
                # the error string instead of the answer.
                return Response(content=parsed.get("summary", final_output), media_type="text/markdown")
            return Response(content=final_output, media_type="text/markdown")

        return Response(content=final_output, media_type="text/markdown")

    except Exception as e:
        # Top-level boundary: surface the error text to the client rather than
        # letting FastAPI return a bare 500.
        return Response(content=str(e), media_type="text/markdown")
143
+
144
if __name__ == "__main__":
    # Local development entry point.
    uvicorn.run(app, host="127.0.0.1", port=8000)
147
+
148
+
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
langgraph
python-dotenv
langchain
langchain-community
langchain-groq
wikipedia
duckduckgo-search
fastapi
uvicorn
tavily-python