girishwangikar commited on
Commit
61bc6f1
·
verified ·
1 Parent(s): 8953e77

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +161 -0
app.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import networkx as nx
4
+ import matplotlib.pyplot as plt
5
+ from langchain_experimental.graph_transformers import LLMGraphTransformer
6
+ from langchain.chains import GraphQAChain
7
+ from langchain_core.documents import Document
8
+ from langchain_community.graphs.networkx_graph import NetworkxEntityGraph
9
+ from langchain_core.prompts import ChatPromptTemplate
10
+ from langchain_groq import ChatGroq
11
+ import pandas as pd
12
+ from gradio_client import Client
13
+ import numpy as np
14
+ from PIL import Image as PILImage
15
+ import base64
16
+ from io import BytesIO
17
+
18
+ # Set the base directory
19
+ BASE_DIR = os.getcwd()
20
+
21
+ # Set up API keys (you'll need to set these as environment variables in Hugging Face Spaces)
22
+ hf_api_key = os.environ.get('HF_TOKEN')
23
+ groq_api_key = os.environ.get('GROQ_API_KEY')
24
+
25
+ # Set up LLM and Flux client
26
+ llm = ChatGroq(temperature=0, model_name='llama-3.1-8b-instant', groq_api_key=groq_api_key)
27
+ flux_client = Client("black-forest-labs/Flux.1-schnell")
28
+
29
+ def create_graph(text):
30
+ documents = [Document(page_content=text)]
31
+ llm_transformer_filtered = LLMGraphTransformer(llm=llm)
32
+ graph_documents_filtered = llm_transformer_filtered.convert_to_graph_documents(documents)
33
+ graph = NetworkxEntityGraph()
34
+
35
+ for node in graph_documents_filtered[0].nodes:
36
+ graph.add_node(node.id)
37
+
38
+ for edge in graph_documents_filtered[0].relationships:
39
+ graph._graph.add_edge(
40
+ edge.source.id,
41
+ edge.target.id,
42
+ relation=edge.type
43
+ )
44
+
45
+ return graph, graph_documents_filtered
46
+
47
+ def visualize_graph(graph):
48
+ plt.figure(figsize=(12, 8))
49
+ pos = nx.spring_layout(graph._graph)
50
+ nx.draw(graph._graph, pos, with_labels=True, node_color='lightblue', node_size=500, font_size=8, font_weight='bold')
51
+ edge_labels = nx.get_edge_attributes(graph._graph, 'relation')
52
+ nx.draw_networkx_edge_labels(graph._graph, pos, edge_labels=edge_labels, font_size=6)
53
+ plt.title("Graph Visualization")
54
+ plt.axis('off')
55
+
56
+ # Save the plot as an image file
57
+ graph_viz_path = os.path.join(BASE_DIR, 'graph_visualization.png')
58
+ plt.savefig(graph_viz_path)
59
+ plt.close()
60
+
61
+ return graph_viz_path
62
+
63
+ def generate_image(prompt):
64
+ try:
65
+ print(f"Generating image with prompt: {prompt}")
66
+ result = flux_client.predict(
67
+ prompt=prompt,
68
+ seed=0,
69
+ randomize_seed=True,
70
+ width=1024,
71
+ height=1024,
72
+ num_inference_steps=4,
73
+ api_name="/infer"
74
+ )
75
+
76
+ if isinstance(result, tuple) and len(result) > 0 and isinstance(result[0], str):
77
+ img_str = result[0]
78
+ img_str += '=' * (-len(img_str) % 4)
79
+ img_data = base64.b64decode(img_str)
80
+ image = PILImage.open(BytesIO(img_data))
81
+ elif isinstance(result, tuple) and len(result) > 0 and isinstance(result[0], np.ndarray):
82
+ image = PILImage.fromarray((result[0] * 255).astype(np.uint8))
83
+ elif isinstance(result, PILImage.Image):
84
+ image = result
85
+ else:
86
+ raise ValueError(f"Unexpected result format from flux_client.predict: {type(result)}")
87
+
88
+ image_path = os.path.join(BASE_DIR, 'generated_image.png')
89
+ image.save(image_path)
90
+
91
+ print(f"Image saved to: {image_path}")
92
+ return image_path
93
+ except Exception as e:
94
+ print(f"Error in generate_image: {str(e)}")
95
+ import traceback
96
+ traceback.print_exc()
97
+ return None
98
+
99
+ def process_text(text, question):
100
+ try:
101
+ print("Creating graph...")
102
+ graph, graph_documents_filtered = create_graph(text)
103
+
104
+ print("Setting up GraphQAChain...")
105
+ graph_rag = GraphQAChain.from_llm(
106
+ llm=llm,
107
+ graph=graph,
108
+ verbose=True
109
+ )
110
+
111
+ print("Running question through GraphQAChain...")
112
+ answer = graph_rag.run(question)
113
+ print(f"Answer: {answer}")
114
+
115
+ print("Visualizing graph...")
116
+ graph_viz_path = visualize_graph(graph)
117
+ print(f"Graph visualization saved to: {graph_viz_path}")
118
+
119
+ print("Generating summary...")
120
+ summary_prompt = f"Summarize the following text in one sentence: {text}"
121
+ summary = llm.invoke(summary_prompt).content
122
+ print(f"Summary: {summary}")
123
+
124
+ print("Generating image...")
125
+ image_path = generate_image(summary)
126
+ if image_path and os.path.exists(image_path):
127
+ print(f"Generated image saved to: {image_path}")
128
+ else:
129
+ print("Failed to generate or save image")
130
+
131
+ return answer, graph_viz_path, summary, image_path
132
+ except Exception as e:
133
+ print(f"An error occurred in process_text: {str(e)}")
134
+ import traceback
135
+ traceback.print_exc()
136
+ return str(e), None, str(e), None
137
+
138
+ def ui_function(text, question):
139
+ answer, graph_viz_path, summary, image_path = process_text(text, question)
140
+ if isinstance(answer, str) and answer.startswith("An error occurred"):
141
+ return answer, None, answer, None
142
+ return answer, graph_viz_path, summary, image_path
143
+
144
+ # Create Gradio interface
145
+ iface = gr.Interface(
146
+ fn=ui_function,
147
+ inputs=[
148
+ gr.Textbox(label="Input Text"),
149
+ gr.Textbox(label="Question")
150
+ ],
151
+ outputs=[
152
+ gr.Textbox(label="Answer"),
153
+ gr.Image(label="Graph Visualization", type="filepath"),
154
+ gr.Textbox(label="Summary"),
155
+ gr.Image(label="Generated Image", type="filepath")
156
+ ],
157
+ title="GraphRAG and Image Generation UI",
158
+ description="Enter text to create a graph, ask a question, and generate a relevant image."
159
+ )
160
+
161
+ iface.launch()