mtyrrell committed
Commit 2d87d9b · Parent: 588173c

max context length limit

Files changed (2):
  1. app/main.py (+12 -4)
  2. params.cfg (+3 -0)
app/main.py CHANGED
@@ -23,6 +23,7 @@ config = getconfig("params.cfg")
 RETRIEVER = config.get("retriever", "RETRIEVER")
 GENERATOR = config.get("generator", "GENERATOR")
 INGESTOR = config.get("ingestor", "INGESTOR")
+MAX_CONTEXT_CHARS = config.get("general", "MAX_CONTEXT_CHARS")
 
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
@@ -83,7 +84,7 @@ def ingest_node(state: GraphState) -> GraphState:
     try:
         # Call the ingestor's ingest endpoint - use gradio_client.file() for proper formatting
         ingestor_context = client.predict(
-            file(tmp_file_path),  # Use gradio_client.file() to properly format
+            file=tmp_file_path,
             api_name="/ingest"
         )
 
@@ -168,13 +169,20 @@ def generate_node(state: GraphState) -> GraphState:
     retrieved_context = state.get("context", "")
     ingestor_context = state.get("ingestor_context", "")
 
+    # Limit context size to prevent token overflow
+    MAX_CONTEXT_CHARS = int(MAX_CONTEXT_CHARS)  # Adjust based on your model's limits
+
     combined_context = ""
     if ingestor_context and retrieved_context:
-        combined_context = f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_context}\n\n=== RETRIEVED CONTEXT ===\n{retrieved_context}"
+        # Prioritize ingestor context, truncate if needed
+        ingestor_truncated = ingestor_context[:MAX_CONTEXT_CHARS//2] if len(ingestor_context) > MAX_CONTEXT_CHARS//2 else ingestor_context
+        retrieved_truncated = retrieved_context[:MAX_CONTEXT_CHARS//2] if len(retrieved_context) > MAX_CONTEXT_CHARS//2 else retrieved_context
+        combined_context = f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_truncated}\n\n=== RETRIEVED CONTEXT ===\n{retrieved_truncated}"
     elif ingestor_context:
-        combined_context = f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_context}"
+        ingestor_truncated = ingestor_context[:MAX_CONTEXT_CHARS] if len(ingestor_context) > MAX_CONTEXT_CHARS else ingestor_context
+        combined_context = f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_truncated}"
     elif retrieved_context:
-        combined_context = retrieved_context
+        combined_context = retrieved_context[:MAX_CONTEXT_CHARS] if len(retrieved_context) > MAX_CONTEXT_CHARS else retrieved_context
 
     client = Client(GENERATOR)
     result = client.predict(
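
Two things change in app/main.py: the ingest call now passes the upload as the keyword argument file=tmp_file_path (matching the /ingest endpoint's parameter name) instead of calling file(tmp_file_path) positionally, and generate_node caps the combined context at MAX_CONTEXT_CHARS. The truncation pattern reads more simply as a standalone helper. A minimal sketch, not part of this commit (the name truncate_for_context and the even half-budget split are illustrative; the conditional guards in the diff are equivalent to plain slicing, since s[:n] returns s unchanged whenever len(s) <= n):

def truncate_for_context(ingestor_context: str, retrieved_context: str,
                         max_chars: int = 15000) -> str:
    """Combine uploaded-document and retrieved context under a character budget."""
    if ingestor_context and retrieved_context:
        half = max_chars // 2  # split the budget evenly between the two sources
        return (f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_context[:half]}"
                f"\n\n=== RETRIEVED CONTEXT ===\n{retrieved_context[:half]}")
    if ingestor_context:
        return f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_context[:max_chars]}"
    return retrieved_context[:max_chars]  # may be "" when nothing was retrieved

A character budget is only a rough proxy for tokens (roughly four characters per token for English text), so 15000 characters works out to around 3750 tokens, comfortably inside a typical 8k-token context window.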
params.cfg CHANGED
@@ -6,3 +6,6 @@ GENERATOR = giz/chatfed_generator
 
 [ingestor]
 INGESTOR = mtyrrell/chatfed_ingestor
+
+[general]
+MAX_CONTEXT_CHARS = 15000
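
The int(MAX_CONTEXT_CHARS) cast in generate_node is needed because INI values come back as strings. A minimal sketch of the lookup, assuming the project's getconfig helper wraps the standard library's configparser (the helper's actual implementation is not shown in this diff):

import configparser

config = configparser.ConfigParser()
config.read("params.cfg")

raw = config.get("general", "MAX_CONTEXT_CHARS")  # returns the string "15000"
max_context_chars = int(raw)                      # hence the int() cast in main.py

# configparser can also coerce in one step:
assert config.getint("general", "MAX_CONTEXT_CHARS") == 15000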