mtyrrell committed
Commit c245449 · 1 Parent(s): 2d87d9b

ts startup

Files changed (1):
  1. app/main.py +106 -25
app/main.py CHANGED
@@ -20,9 +20,9 @@ import tempfile
 from utils import getconfig
 
 config = getconfig("params.cfg")
-RETRIEVER = config.get("retriever", "RETRIEVER")
-GENERATOR = config.get("generator", "GENERATOR")
-INGESTOR = config.get("ingestor", "INGESTOR")
+RETRIEVER = config.get("retriever", "RETRIEVER", fallback="https://giz-chatfed-retriever.hf.space")
+GENERATOR = config.get("generator", "GENERATOR", fallback="https://giz-chatfed-generator.hf.space")
+INGESTOR = config.get("ingestor", "INGESTOR", fallback="https://mtyrrell-chatfed-ingestor.hf.space")
 MAX_CONTEXT_CHARS = config.get("general", "MAX_CONTEXT_CHARS")
 
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
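The fallback keyword is standard configparser behaviour: assuming utils.getconfig returns a plain ConfigParser (utils is not shown in this diff), a missing section or key now yields the hard-coded Space URL instead of raising. A minimal sketch:

from configparser import ConfigParser

# Sketch only; assumes getconfig("params.cfg") is a thin wrapper
# around ConfigParser.read().
config = ConfigParser()
config.read("params.cfg")

# Returns the configured value when present; otherwise the fallback
# URL, rather than raising NoSectionError/NoOptionError.
retriever = config.get(
    "retriever", "RETRIEVER",
    fallback="https://giz-chatfed-retriever.hf.space"
)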
@@ -84,7 +84,7 @@ def ingest_node(state: GraphState) -> GraphState:
         try:
             # Call the ingestor's ingest endpoint - use gradio_client.file() for proper formatting
             ingestor_context = client.predict(
-                file=tmp_file_path,
+                file(tmp_file_path),  # Use gradio_client.file() to properly format
                 api_name="/ingest"
             )
 
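gradio_client treats a bare string argument as literal text, so local paths should be wrapped in its file() helper (handle_file() in newer releases), passed positionally as in the hunk above. A minimal sketch against the ingestor Space, assuming /ingest returns the extracted context as a string:

from gradio_client import Client, file

# INGESTOR fallback URL from this commit; /ingest is the endpoint
# called by ingest_node.
client = Client("https://mtyrrell-chatfed-ingestor.hf.space")

# file() marks the argument as an upload rather than a literal string.
context = client.predict(
    file("report.pdf"),
    api_name="/ingest"
)
print(len(context))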
@@ -122,6 +122,52 @@ def ingest_node(state: GraphState) -> GraphState:
             "ingestion_error": str(e)
         })
         return {"ingestor_context": "", "metadata": metadata}
+
+    try:
+        client = Client(INGESTOR)
+
+        # Create a temporary file to upload
+        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(state["filename"])[1]) as tmp_file:
+            tmp_file.write(state["file_content"])
+            tmp_file_path = tmp_file.name
+
+        try:
+            # Call the ingestor's ingest endpoint - returns context directly
+            ingestor_context = client.predict(
+                file=tmp_file_path,
+                api_name="/ingest"
+            )
+
+            logger.info(f"Ingest result length: {len(ingestor_context) if ingestor_context else 0}")
+
+        finally:
+            # Clean up temporary file
+            os.unlink(tmp_file_path)
+
+        duration = (datetime.now() - start_time).total_seconds()
+        metadata = state.get("metadata", {})
+        metadata.update({
+            "ingestion_duration": duration,
+            "ingestor_context_length": len(ingestor_context) if ingestor_context else 0,
+            "ingestion_success": True
+        })
+
+        return {
+            "ingestor_context": ingestor_context,
+            "metadata": metadata
+        }
+
+    except Exception as e:
+        duration = (datetime.now() - start_time).total_seconds()
+        logger.error(f"Ingestion failed: {str(e)}")
+
+        metadata = state.get("metadata", {})
+        metadata.update({
+            "ingestion_duration": duration,
+            "ingestion_success": False,
+            "ingestion_error": str(e)
+        })
+        return {"ingestor_context": "", "metadata": metadata}
 
 def retrieve_node(state: GraphState) -> GraphState:
     start_time = datetime.now()
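ingest_node and retrieve_node each read and return a shared GraphState, the usual LangGraph node contract. The graph wiring itself is outside this diff; a hypothetical sketch, assuming langgraph's StateGraph plus the generate_node seen in the next hunk:

from langgraph.graph import StateGraph, END

# Hypothetical wiring; the actual graph construction is not part of
# this commit's diff.
builder = StateGraph(GraphState)
builder.add_node("ingest", ingest_node)
builder.add_node("retrieve", retrieve_node)
builder.add_node("generate", generate_node)
builder.set_entry_point("ingest")
builder.add_edge("ingest", "retrieve")
builder.add_edge("retrieve", "generate")
builder.add_edge("generate", END)
graph = builder.compile()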
@@ -170,8 +216,8 @@ def generate_node(state: GraphState) -> GraphState:
     ingestor_context = state.get("ingestor_context", "")
 
     # Limit context size to prevent token overflow
-    MAX_CONTEXT_CHARS = int(MAX_CONTEXT_CHARS)  # Adjust based on your model's limits
-
+    MAX_CONTEXT_CHARS = int(config.get("general", "MAX_CONTEXT_CHARS"))
+
     combined_context = ""
     if ingestor_context and retrieved_context:
         # Prioritize ingestor context, truncate if needed
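This change re-reads MAX_CONTEXT_CHARS from config at call time instead of shadowing the module-level string. The truncation logic itself is cut off by the hunk; a hypothetical sketch of what "prioritize ingestor context" could mean:

# Hypothetical sketch only; the real truncation code is outside this
# hunk. Ingestor context is kept first; retrieved context fills
# whatever budget remains.
ingestor_part = ingestor_context[:MAX_CONTEXT_CHARS]
remaining = MAX_CONTEXT_CHARS - len(ingestor_part)
combined_context = ingestor_part + "\n\n" + retrieved_context[:remaining]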
@@ -355,7 +401,6 @@ def process_query_langserve(input_data: ChatFedInput) -> ChatFedOutput:
     )
     return ChatFedOutput(result=result["result"], metadata=result["metadata"])
 
-# This is not working currently... Problematic because HF doesn't allow > 1 port open at the same time
 def create_gradio_interface():
     with gr.Blocks(title="ChatFed Orchestrator") as demo:
         gr.Markdown("# ChatFed Orchestrator")
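The deleted comment described a real constraint: a Hugging Face Space exposes a single port, so launching Gradio on a second port next to FastAPI fails. A common workaround, shown only as a sketch (this commit instead keeps the separate run_gradio_server below), is mounting the Blocks app on the FastAPI app so both share one port:

import gradio as gr

# Sketch: serve the Gradio UI from FastAPI's own port under /ui.
demo = create_gradio_interface()
app = gr.mount_gradio_app(app, demo, path="/ui")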
@@ -416,25 +461,42 @@ async def root():
         }
     }
 
-# LangServe routes (these are the main endpoints)
-add_routes(
-    app,
-    RunnableLambda(process_query_langserve),
-    path="/chatfed",
-    input_type=ChatFedInput,
-    output_type=ChatFedOutput
-)
-
-add_routes(
-    app,
-    RunnableLambda(chatui_adapter),
-    path="/chatfed-ui-stream",
-    input_type=ChatUIInput,
-    output_type=str,
-    enable_feedback_endpoint=True,
-    enable_public_trace_link_endpoint=True,
-)
+# Additional endpoint for file uploads via API
+@app.post("/chatfed-with-file")
+async def chatfed_with_file(
+    query: str = Form(...),
+    file: Optional[UploadFile] = File(None),
+    reports_filter: Optional[str] = Form(""),
+    sources_filter: Optional[str] = Form(""),
+    subtype_filter: Optional[str] = Form(""),
+    year_filter: Optional[str] = Form(""),
+    session_id: Optional[str] = Form(None),
+    user_id: Optional[str] = Form(None)
+):
+    """Endpoint for queries with optional file attachments"""
+    file_content = None
+    filename = None
+
+    if file:
+        file_content = await file.read()
+        filename = file.filename
+
+    result = process_query_core(
+        query=query,
+        reports_filter=reports_filter,
+        sources_filter=sources_filter,
+        subtype_filter=subtype_filter,
+        year_filter=year_filter,
+        file_content=file_content,
+        filename=filename,
+        session_id=session_id,
+        user_id=user_id,
+        return_metadata=True
+    )
+
+    return ChatFedOutput(result=result["result"], metadata=result["metadata"])
 
+# Additional endpoint for file uploads via API
 @app.post("/chatfed-with-file")
 async def chatfed_with_file(
     query: str = Form(...),
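The new endpoint consumes multipart form data: the query travels as a form field and the document as a file part. A client-side sketch with requests (the base URL is an assumption for local development):

import requests

# Assumed local URL; substitute the deployed Space's URL.
resp = requests.post(
    "http://localhost:8000/chatfed-with-file",
    data={"query": "Summarize the attached report"},
    files={"file": ("report.pdf", open("report.pdf", "rb"), "application/pdf")},
)
print(resp.json()["result"])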
@@ -469,6 +531,25 @@ async def chatfed_with_file(
 
     return ChatFedOutput(result=result["result"], metadata=result["metadata"])
 
+# LangServe routes (these are the main endpoints)
+add_routes(
+    app,
+    RunnableLambda(process_query_langserve),
+    path="/chatfed",
+    input_type=ChatFedInput,
+    output_type=ChatFedOutput
+)
+
+add_routes(
+    app,
+    RunnableLambda(chatui_adapter),
+    path="/chatfed-ui-stream",
+    input_type=ChatUIInput,
+    output_type=str,
+    enable_feedback_endpoint=True,
+    enable_public_trace_link_endpoint=True,
+)
+
 def run_gradio_server():
     demo = create_gradio_interface()
     demo.launch(
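Relocating add_routes after the endpoint definitions leaves the generated API unchanged: LangServe still serves /chatfed/invoke, /chatfed/batch, and /chatfed/stream. A client sketch, assuming ChatFedInput accepts at least a query field (its full schema is outside this diff):

from langserve import RemoteRunnable

# Assumed local URL; the input dict must match ChatFedInput's schema.
chatfed = RemoteRunnable("http://localhost:8000/chatfed")
output = chatfed.invoke({"query": "What does the ingested report conclude?"})
print(output)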
 