Monideep Chakraborti committed 36b34ac (1 parent: 1ffce95)

Deploy GQuery AI - Biomedical Research Assistant with Multi-Database Integration

This view is limited to 50 files because it contains too many changes.
- README.md +47 -6
- app.py +23 -0
- gquery/src/gquery/__init__.py +29 -0
- gquery/src/gquery/__pycache__/__init__.cpython-310 2.pyc +0 -0
- gquery/src/gquery/__pycache__/__init__.cpython-310.pyc +0 -0
- gquery/src/gquery/agents/__init__.py +42 -0
- gquery/src/gquery/agents/__pycache__/__init__.cpython-310 2.pyc +0 -0
- gquery/src/gquery/agents/__pycache__/__init__.cpython-310.pyc +0 -0
- gquery/src/gquery/agents/__pycache__/biomedical_guardrails.cpython-310 2.pyc +0 -0
- gquery/src/gquery/agents/__pycache__/biomedical_guardrails.cpython-310.pyc +0 -0
- gquery/src/gquery/agents/__pycache__/config.cpython-310 2.pyc +0 -0
- gquery/src/gquery/agents/__pycache__/config.cpython-310.pyc +0 -0
- gquery/src/gquery/agents/__pycache__/enhanced_orchestrator.cpython-310 2.pyc +0 -0
- gquery/src/gquery/agents/__pycache__/enhanced_orchestrator.cpython-310.pyc +0 -0
- gquery/src/gquery/agents/__pycache__/entity_resolver.cpython-310 2.pyc +0 -0
- gquery/src/gquery/agents/__pycache__/entity_resolver.cpython-310.pyc +0 -0
- gquery/src/gquery/agents/__pycache__/orchestrator.cpython-310 2.pyc +0 -0
- gquery/src/gquery/agents/__pycache__/orchestrator.cpython-310.pyc +0 -0
- gquery/src/gquery/agents/__pycache__/query_analyzer.cpython-310 2.pyc +0 -0
- gquery/src/gquery/agents/__pycache__/query_analyzer.cpython-310.pyc +0 -0
- gquery/src/gquery/agents/__pycache__/synthesis.cpython-310 2.pyc +0 -0
- gquery/src/gquery/agents/__pycache__/synthesis.cpython-310.pyc +0 -0
- gquery/src/gquery/agents/biomedical_guardrails.py +317 -0
- gquery/src/gquery/agents/config.py +191 -0
- gquery/src/gquery/agents/enhanced_orchestrator.py +947 -0
- gquery/src/gquery/agents/entity_resolver.py +452 -0
- gquery/src/gquery/agents/orchestrator.py +627 -0
- gquery/src/gquery/agents/query_analyzer.py +289 -0
- gquery/src/gquery/agents/synthesis.py +429 -0
- gquery/src/gquery/cli.py +1027 -0
- gquery/src/gquery/config/__init__.py +6 -0
- gquery/src/gquery/config/__pycache__/__init__.cpython-310 2.pyc +0 -0
- gquery/src/gquery/config/__pycache__/__init__.cpython-310.pyc +0 -0
- gquery/src/gquery/config/__pycache__/settings.cpython-310 2.pyc +0 -0
- gquery/src/gquery/config/__pycache__/settings.cpython-310.pyc +0 -0
- gquery/src/gquery/config/settings.py +200 -0
- gquery/src/gquery/interfaces/__init__.py +6 -0
- gquery/src/gquery/models/__init__.py +40 -0
- gquery/src/gquery/models/__pycache__/__init__.cpython-310 2.pyc +0 -0
- gquery/src/gquery/models/__pycache__/__init__.cpython-310.pyc +0 -0
- gquery/src/gquery/models/__pycache__/base.cpython-310 2.pyc +0 -0
- gquery/src/gquery/models/__pycache__/base.cpython-310.pyc +0 -0
- gquery/src/gquery/models/__pycache__/clinvar.cpython-310 2.pyc +0 -0
- gquery/src/gquery/models/__pycache__/clinvar.cpython-310.pyc +0 -0
- gquery/src/gquery/models/__pycache__/datasets.cpython-310 2.pyc +0 -0
- gquery/src/gquery/models/__pycache__/datasets.cpython-310.pyc +0 -0
- gquery/src/gquery/models/__pycache__/pmc.cpython-310 2.pyc +0 -0
- gquery/src/gquery/models/__pycache__/pmc.cpython-310.pyc +0 -0
- gquery/src/gquery/models/base.py +89 -0
- gquery/src/gquery/models/clinvar.py +370 -0
README.md
CHANGED

@@ -1,12 +1,53 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: GQuery AI - Biomedical Research Assistant
+emoji: 🧬
+colorFrom: blue
+colorTo: purple
 sdk: gradio
-sdk_version:
+sdk_version: "4.0.0"
 app_file: app.py
 pinned: false
+license: mit
 ---
 
-
+# 🧬 GQuery AI - Intelligent Biomedical Research Assistant
+
+**Comprehensive biomedical research powered by NCBI databases and advanced AI.**
+
+## ✨ Features
+
+- **🔍 Multi-Database Search**: Query PubMed Central, ClinVar, and NCBI Datasets simultaneously
+- **🧠 Enhanced AI Analysis**: Deep scientific synthesis with comprehensive molecular biology insights
+- **🎯 Smart Clarification**: Intelligent query refinement for precise results
+- **📚 Clickable Sources**: Direct links to research papers and genetic databases
+- **🔬 Professional Analysis**: Detailed pathophysiology, genomics, and clinical applications
+- **💬 Conversational Memory**: Context-aware follow-up questions
+
+## 🚀 How to Use
+
+1. **Enter your biomedical query** (genes, diseases, drugs, or treatments)
+2. **Clarify if prompted** for more targeted results
+3. **Explore comprehensive analysis** with scientific depth
+4. **Click source links** to access original research
+5. **Use follow-up suggestions** for deeper investigation
+
+## 🧬 Example Queries
+
+- **Gene Analysis**: "BRCA1", "TP53", "CFTR"
+- **Disease Research**: "Type 2 diabetes pathophysiology", "Alzheimer's disease"
+- **Drug Information**: "metformin", "insulin therapy"
+- **Treatment Research**: "CRISPR gene therapy", "immunotherapy"
+
+## 🔬 Data Sources
+
+- **PubMed Central**: Latest research publications
+- **ClinVar**: Genetic variant database
+- **NCBI Datasets**: Genomic and expression data
+
+## ⚠️ Important Note
+
+This tool is for research and educational purposes only. Always consult qualified healthcare professionals for medical decisions.
+
+---
+
+*Powered by advanced AI and real-time NCBI database integration*
app.py
ADDED

@@ -0,0 +1,23 @@
#!/usr/bin/env python3
"""
GQuery AI - HuggingFace Spaces Deployment
Intelligent Biomedical Research Assistant

This is the main entry point for the HuggingFace Spaces deployment.
"""

import os
import sys
import warnings

# Suppress warnings for cleaner deployment
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Add the gquery package to the Python path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "gquery", "src"))

# Import and run the main Gradio app
if __name__ == "__main__":
    from improved_gradio_app import main
    main()
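A quick local sanity check of the entry point above (a sketch: it assumes this commit's full gquery/src package tree is present, and that improved_gradio_app.py, which the Space imports at startup, sits next to app.py at the repo root):

# Sketch: reproduce app.py's path setup from a REPL and confirm the package resolves.
import os
import sys

sys.path.insert(0, os.path.join(os.getcwd(), "gquery", "src"))

import gquery
print(gquery.__version__)  # "0.1.0" per gquery/src/gquery/__init__.py below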
gquery/src/gquery/__init__.py
ADDED

@@ -0,0 +1,29 @@
"""
GQuery AI - Biomedical Research Platform

A production-ready, scalable biomedical research platform integrating NCBI databases
to solve the data silo problem by intelligently connecting PubMed Central (PMC),
ClinVar, and NCBI Datasets.

Version: 0.1.0
Author: Monideep Chakraborti
License: MIT
"""

__version__ = "0.1.0"
__author__ = "Monideep Chakraborti"
__license__ = "MIT"

# Core exports
from gquery.config.settings import get_settings
from gquery.models.base import BaseModel
from gquery.utils.logger import get_logger

__all__ = [
    "__version__",
    "__author__",
    "__license__",
    "get_settings",
    "BaseModel",
    "get_logger",
]
gquery/src/gquery/__pycache__/__init__.cpython-310 2.pyc
ADDED
Binary file (807 Bytes).

gquery/src/gquery/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (807 Bytes).
gquery/src/gquery/agents/__init__.py
ADDED

@@ -0,0 +1,42 @@
"""
GQuery AI Agent Module

This module contains the core AI agent logic for Phase 2:
- Query analysis and intent detection (Feature 2.3)
- Multi-database orchestration (Feature 2.1)
- Cross-database synthesis (Feature 2.2)
- Entity resolution and linking (Feature 2.4)
"""

from .config import AgentConfig, QueryType, DatabasePriority
from .query_analyzer import QueryAnalyzer, QueryAnalysis, analyze_query_intent
from .orchestrator import GQueryOrchestrator, OrchestrationResult, orchestrate_query
from .synthesis import DataSynthesizer, SynthesisResult, synthesize_biomedical_data
from .entity_resolver import EntityResolver, ResolvedEntity, resolve_biomedical_entities

__all__ = [
    # Configuration
    "AgentConfig",
    "QueryType",
    "DatabasePriority",

    # Query Analysis (Feature 2.3)
    "QueryAnalyzer",
    "QueryAnalysis",
    "analyze_query_intent",

    # Orchestration (Feature 2.1)
    "GQueryOrchestrator",
    "OrchestrationResult",
    "orchestrate_query",

    # Synthesis (Feature 2.2)
    "DataSynthesizer",
    "SynthesisResult",
    "synthesize_biomedical_data",

    # Entity Resolution (Feature 2.4)
    "EntityResolver",
    "ResolvedEntity",
    "resolve_biomedical_entities",
]
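A minimal import sketch of the public surface exported above (names are taken from __all__; call signatures for the helper functions are not shown in this view, so only the imports and the from_env() constructor defined in agents/config.py below are exercised):

from gquery.agents import (
    AgentConfig,                  # configuration (agents/config.py below)
    analyze_query_intent,         # Feature 2.3: query analysis
    orchestrate_query,            # Feature 2.1: multi-database orchestration
    synthesize_biomedical_data,   # Feature 2.2: cross-database synthesis
    resolve_biomedical_entities,  # Feature 2.4: entity resolution
)

config = AgentConfig.from_env()   # from_env() is defined in agents/config.py below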
gquery/src/gquery/agents/__pycache__/__init__.cpython-310 2.pyc
ADDED
Binary file (1.11 kB).

gquery/src/gquery/agents/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.11 kB).

gquery/src/gquery/agents/__pycache__/biomedical_guardrails.cpython-310 2.pyc
ADDED
Binary file (10.9 kB).

gquery/src/gquery/agents/__pycache__/biomedical_guardrails.cpython-310.pyc
ADDED
Binary file (10.9 kB).

gquery/src/gquery/agents/__pycache__/config.cpython-310 2.pyc
ADDED
Binary file (6.51 kB).

gquery/src/gquery/agents/__pycache__/config.cpython-310.pyc
ADDED
Binary file (6.51 kB).

gquery/src/gquery/agents/__pycache__/enhanced_orchestrator.cpython-310 2.pyc
ADDED
Binary file (26.7 kB).

gquery/src/gquery/agents/__pycache__/enhanced_orchestrator.cpython-310.pyc
ADDED
Binary file (30.4 kB).

gquery/src/gquery/agents/__pycache__/entity_resolver.cpython-310 2.pyc
ADDED
Binary file (11.6 kB).

gquery/src/gquery/agents/__pycache__/entity_resolver.cpython-310.pyc
ADDED
Binary file (11.6 kB).

gquery/src/gquery/agents/__pycache__/orchestrator.cpython-310 2.pyc
ADDED
Binary file (16.4 kB).

gquery/src/gquery/agents/__pycache__/orchestrator.cpython-310.pyc
ADDED
Binary file (16.4 kB).

gquery/src/gquery/agents/__pycache__/query_analyzer.cpython-310 2.pyc
ADDED
Binary file (8.42 kB).

gquery/src/gquery/agents/__pycache__/query_analyzer.cpython-310.pyc
ADDED
Binary file (8.42 kB).

gquery/src/gquery/agents/__pycache__/synthesis.cpython-310 2.pyc
ADDED
Binary file (11.5 kB).

gquery/src/gquery/agents/__pycache__/synthesis.cpython-310.pyc
ADDED
Binary file (11.5 kB).
gquery/src/gquery/agents/biomedical_guardrails.py
ADDED

@@ -0,0 +1,317 @@
"""
Biomedical Guardrails Module

Implements Feature 3: Biomedical Guardrails Implementation
- Validates that queries are within the biomedical domain
- Provides polite rejection for out-of-scope queries
- Ensures trust and safety for the GQuery AI system
"""

import re
import logging
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from datetime import datetime
from enum import Enum

logger = logging.getLogger(__name__)


class QueryDomain(Enum):
    """Classification of query domains."""
    BIOMEDICAL = "biomedical"
    NON_BIOMEDICAL = "non_biomedical"
    AMBIGUOUS = "ambiguous"


@dataclass
class GuardrailResult:
    """Result from guardrail validation."""
    is_valid: bool
    domain: QueryDomain
    confidence: float
    rejection_message: Optional[str] = None
    detected_categories: List[str] = None
    biomedical_score: float = 0.0
    non_biomedical_score: float = 0.0
    processing_time_ms: Optional[int] = None


class BiomedicalGuardrails:
    """
    Validates queries to ensure they are within the biomedical domain.

    This is the highest priority feature based on manager feedback:
    "TRUST IS THE MOST IMPORTANT THING"
    """

    def __init__(self):
        self.biomedical_keywords = {
            # Core biomedical terms
            'genes': ['gene', 'genes', 'genetic', 'genomic', 'genome', 'dna', 'rna', 'mrna', 'allele'],
            'proteins': ['protein', 'proteins', 'enzyme', 'enzymes', 'antibody', 'antibodies', 'peptide'],
            'diseases': ['disease', 'diseases', 'disorder', 'syndrome', 'condition', 'illness', 'pathology'],
            'medical': ['medical', 'medicine', 'clinical', 'therapy', 'treatment', 'diagnosis', 'patient'],
            'biology': ['cell', 'cellular', 'molecular', 'biology', 'biological', 'biochemistry', 'physiology'],
            'pharmacology': ['drug', 'drugs', 'medication', 'pharmaceutical', 'compound', 'therapeutic'],
            'anatomy': ['organ', 'tissue', 'blood', 'brain', 'heart', 'liver', 'kidney', 'muscle'],
            'pathology': ['cancer', 'tumor', 'carcinoma', 'mutation', 'variant', 'pathogenic', 'benign'],
            'research': ['study', 'research', 'clinical trial', 'experiment', 'analysis', 'publication'],
            'databases': ['pubmed', 'pmc', 'clinvar', 'ncbi', 'datasets', 'genbank', 'omim', 'hgnc']
        }

        self.non_biomedical_patterns = {
            # Clear non-medical categories
            'weather': ['weather', 'temperature', 'rain', 'snow', 'climate', 'forecast', 'storm', 'sunny'],
            'sports': ['football', 'basketball', 'soccer', 'baseball', 'tennis', 'golf', 'hockey', 'game', 'team', 'player'],
            'entertainment': ['movie', 'film', 'music', 'song', 'actor', 'actress', 'celebrity', 'tv show', 'netflix'],
            'food': ['recipe', 'cooking', 'food', 'restaurant', 'meal', 'dinner', 'lunch', 'breakfast'],
            'politics': ['president', 'election', 'vote', 'political', 'government', 'congress', 'senate'],
            'technology': ['computer', 'software', 'app', 'website', 'internet', 'phone', 'laptop'],
            'travel': ['vacation', 'hotel', 'flight', 'travel', 'trip', 'tourism', 'destination'],
            'business': ['stock', 'investment', 'company', 'business', 'market', 'economy', 'finance'],
            'education': ['school', 'university', 'college', 'student', 'teacher', 'homework', 'class'],
            'general': ['what is', 'how to', 'where is', 'when was', 'who is', 'why does']
        }

        # Special cases that need careful handling
        self.ambiguous_terms = {
            'heart': 'Could refer to cardiac medicine or emotional concept',
            'cell': 'Could refer to biological cells or phone cells',
            'virus': 'Could refer to biological virus or computer virus',
            'depression': 'Could refer to mental health condition or economic depression',
            'pressure': 'Could refer to blood pressure or physical pressure'
        }

        # Known biomedical entities (genes, diseases, etc.)
        self.known_biomedical_entities = {
            # Common genes
            'brca1', 'brca2', 'tp53', 'cftr', 'apoe', 'mthfr', 'vegf', 'egfr', 'kras', 'myh7',
            'ldlr', 'app', 'psen1', 'psen2', 'sod1', 'fmr1', 'dmd', 'f8', 'f9', 'vwf',
            # Common diseases
            'diabetes', 'cancer', 'alzheimer', 'parkinsons', 'huntington', 'cystic fibrosis', 'tuberculosis', 'tb',
            'hemophilia', 'sickle cell', 'thalassemia', 'muscular dystrophy',
            # Common drugs
            'aspirin', 'metformin', 'insulin', 'warfarin', 'statin', 'penicillin',
            # Medical specialties
            'cardiology', 'oncology', 'neurology', 'genetics', 'immunology', 'pharmacology'
        }

    def validate_query(self, query: str) -> GuardrailResult:
        """
        Validate if a query is within the biomedical domain.

        Args:
            query: The user's input query

        Returns:
            GuardrailResult with validation decision and details
        """
        start_time = datetime.now()

        if not query or not query.strip():
            return GuardrailResult(
                is_valid=False,
                domain=QueryDomain.NON_BIOMEDICAL,
                confidence=1.0,
                rejection_message="Please provide a question about biomedical topics.",
                processing_time_ms=0
            )

        query_lower = query.lower().strip()

        # Check for known biomedical entities first
        biomedical_score = self._calculate_biomedical_score(query_lower)
        non_biomedical_score = self._calculate_non_biomedical_score(query_lower)

        # Determine domain based on scores
        domain, is_valid, confidence = self._classify_domain(
            biomedical_score, non_biomedical_score, query_lower
        )

        # Generate appropriate response
        rejection_message = None
        if not is_valid:
            rejection_message = self._generate_rejection_message(query_lower, domain)

        processing_time = int((datetime.now() - start_time).total_seconds() * 1000)

        return GuardrailResult(
            is_valid=is_valid,
            domain=domain,
            confidence=confidence,
            rejection_message=rejection_message,
            biomedical_score=biomedical_score,
            non_biomedical_score=non_biomedical_score,
            processing_time_ms=processing_time
        )

    def _calculate_biomedical_score(self, query: str) -> float:
        """Calculate how biomedical a query appears to be."""
        score = 0.0
        word_count = len(query.split())

        # Check for known biomedical entities (high weight)
        for entity in self.known_biomedical_entities:
            if entity in query:
                score += 0.8

        # Check for biomedical keywords by category
        for category, keywords in self.biomedical_keywords.items():
            for keyword in keywords:
                if keyword in query:
                    if category in ['genes', 'diseases', 'medical']:
                        score += 0.6  # High weight for core categories
                    elif category in ['proteins', 'pharmacology']:
                        score += 0.5  # Medium weight
                    else:
                        score += 0.3  # Lower weight for general bio terms

        # Normalize by query length (longer queries get some benefit)
        if word_count > 0:
            score = min(score / word_count, 1.0)

        return score

    def _calculate_non_biomedical_score(self, query: str) -> float:
        """Calculate how non-biomedical a query appears to be."""
        score = 0.0
        word_count = len(query.split())

        # Check for non-biomedical patterns
        for category, patterns in self.non_biomedical_patterns.items():
            for pattern in patterns:
                if pattern in query:
                    if category in ['weather', 'sports', 'entertainment']:
                        score += 0.8  # High weight for clearly non-medical
                    elif category in ['food', 'politics', 'technology']:
                        score += 0.6  # Medium weight
                    else:
                        score += 0.4  # Lower weight for potentially ambiguous

        # Normalize by query length
        if word_count > 0:
            score = min(score / word_count, 1.0)

        return score

    def _classify_domain(self, bio_score: float, non_bio_score: float, query: str) -> Tuple[QueryDomain, bool, float]:
        """Classify the query domain based on scores."""

        # Clear biomedical indicators
        if bio_score > 0.4:
            return QueryDomain.BIOMEDICAL, True, min(bio_score * 1.2, 1.0)

        # Clear non-biomedical indicators
        if non_bio_score > 0.4:
            return QueryDomain.NON_BIOMEDICAL, False, min(non_bio_score * 1.2, 1.0)

        # Check for ambiguous terms that might be biomedical
        for term, description in self.ambiguous_terms.items():
            if term in query:
                # Give benefit of doubt for ambiguous terms in biomedical context
                return QueryDomain.AMBIGUOUS, True, 0.6

        # If very short query with no clear indicators, be cautious but allow
        if len(query.split()) <= 2 and bio_score > 0.1:
            return QueryDomain.AMBIGUOUS, True, 0.5

        # Default: if we can't classify clearly, err on side of rejection for safety
        if bio_score < 0.1 and non_bio_score < 0.1:
            return QueryDomain.NON_BIOMEDICAL, False, 0.7

        # Slight edge to biomedical if scores are close
        if bio_score >= non_bio_score:
            return QueryDomain.BIOMEDICAL, True, 0.6
        else:
            return QueryDomain.NON_BIOMEDICAL, False, 0.6

    def _generate_rejection_message(self, query: str, domain: QueryDomain) -> str:
        """Generate a polite, helpful rejection message."""

        base_message = """I'm designed specifically for biomedical and health-related questions. """

        # Customize message based on what was detected
        if any(pattern in query for patterns in self.non_biomedical_patterns.values() for pattern in patterns):
            category_detected = next(
                category for category, patterns in self.non_biomedical_patterns.items()
                if any(pattern in query for pattern in patterns)
            )

            if category_detected == 'weather':
                specific_message = "I can't help with weather information, but I'd be happy to answer questions about environmental health, climate-related diseases, or seasonal health patterns."
            elif category_detected == 'sports':
                specific_message = "I can't help with sports information, but I could discuss sports medicine, exercise physiology, or injury prevention."
            elif category_detected == 'food':
                specific_message = "I can't provide recipes, but I could help with nutrition science, food allergies, or dietary health research."
            elif category_detected == 'technology':
                specific_message = "I can't help with general technology, but I could discuss medical technology, bioinformatics, or health informatics."
            else:
                specific_message = "I'd be happy to help with biomedical research questions instead."
        else:
            specific_message = "I'd be happy to help with questions about genes, diseases, treatments, medications, or medical research."

        examples = """

**I can help with questions like:**
• Gene information (e.g., "What are BRCA1 variants?")
• Disease research (e.g., "Latest treatments for diabetes")
• Drug interactions (e.g., "Side effects of metformin")
• Medical conditions (e.g., "Symptoms of Huntington's disease")
• Clinical research (e.g., "Recent cancer immunotherapy studies")"""

        return base_message + specific_message + examples

    def get_biomedical_suggestions(self, query: str) -> List[str]:
        """
        Generate biomedical query suggestions based on a non-biomedical query.

        This helps guide users toward appropriate biomedical questions.
        """
        suggestions = []
        query_lower = query.lower()

        # Pattern-based suggestions
        if 'heart' in query_lower:
            suggestions.extend([
                "What are the genetic factors in heart disease?",
                "LDLR gene variants and cardiovascular risk",
                "Latest research on cardiac medications"
            ])
        elif 'brain' in query_lower:
            suggestions.extend([
                "What causes Alzheimer's disease?",
                "APOE gene and dementia risk",
                "Recent neurology research findings"
            ])
        elif any(word in query_lower for word in ['food', 'eat', 'diet']):
            suggestions.extend([
                "Genetic factors in food allergies",
                "Nutrition and gene expression",
                "Dietary treatments for genetic disorders"
            ])
        elif 'exercise' in query_lower or 'fitness' in query_lower:
            suggestions.extend([
                "Genetics of muscle development",
                "Exercise and cardiovascular health",
                "Sports medicine and injury prevention"
            ])
        else:
            # General biomedical suggestions
            suggestions.extend([
                "What are BRCA1 genetic variants?",
                "Latest diabetes research findings",
                "How does aspirin work medically?",
                "What causes cancer at the molecular level?"
            ])

        return suggestions[:3]  # Return top 3 suggestions


# Global instance for easy import
biomedical_guardrails = BiomedicalGuardrails()


def validate_biomedical_query(query: str) -> GuardrailResult:
    """Convenience function for query validation."""
    return biomedical_guardrails.validate_query(query)
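A short usage sketch of the guardrail API committed above; the expected outputs follow directly from the scoring and classification logic in the file:

from gquery.agents.biomedical_guardrails import validate_biomedical_query

ok = validate_biomedical_query("BRCA1 variants in breast cancer")
print(ok.is_valid, ok.domain)   # True QueryDomain.BIOMEDICAL: 'brca1' and 'cancer'
                                # are in known_biomedical_entities, so the score clears 0.4

bad = validate_biomedical_query("what is the weather forecast today")
print(bad.is_valid)             # False: 'weather'/'forecast' hit non_biomedical_patterns
print(bad.rejection_message)    # polite redirect with example biomedical questions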
gquery/src/gquery/agents/config.py
ADDED

@@ -0,0 +1,191 @@
"""
Agent Configuration Module

Centralizes configuration for AI agents, LLM settings, and orchestration parameters.
"""

import os
from typing import Dict, List, Optional
from dataclasses import dataclass
from enum import Enum


class QueryType(Enum):
    """Types of queries the agent can handle."""
    GENE_LOOKUP = "gene_lookup"
    VARIANT_ANALYSIS = "variant_analysis"
    LITERATURE_SEARCH = "literature_search"
    CROSS_DATABASE = "cross_database"
    SYNTHESIS = "synthesis"


class DatabasePriority(Enum):
    """Priority levels for database selection."""
    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"


@dataclass
class AgentConfig:
    """Configuration for AI agents."""

    # OpenAI Settings
    openai_api_key: str
    model: str = "gpt-4o"
    temperature: float = 0.1
    max_tokens: int = 4000
    timeout: int = 60

    # Agent Behavior
    max_retries: int = 3
    confidence_threshold: float = 0.7
    synthesis_depth: str = "moderate"  # shallow, moderate, deep

    # Database Integration
    enable_caching: bool = True
    cache_ttl: int = 3600  # 1 hour
    concurrent_queries: int = 3

    # Error Handling
    fallback_enabled: bool = True
    error_recovery_attempts: int = 2

    @classmethod
    def from_env(cls) -> "AgentConfig":
        """Create configuration from environment variables."""
        # Load .env file if it exists
        try:
            from dotenv import load_dotenv
            load_dotenv()
        except ImportError:
            pass  # dotenv not installed

        return cls(
            openai_api_key=os.getenv("OPENAI__API_KEY", ""),
            model=os.getenv("OPENAI__MODEL", "gpt-4o"),
            temperature=float(os.getenv("OPENAI__TEMPERATURE", "0.1")),
            max_tokens=int(os.getenv("OPENAI__MAX_TOKENS", "4000")),
            timeout=int(os.getenv("OPENAI__TIMEOUT", "60")),
            max_retries=int(os.getenv("AGENT__MAX_RETRIES", "3")),
            confidence_threshold=float(os.getenv("AGENT__CONFIDENCE_THRESHOLD", "0.7")),
            synthesis_depth=os.getenv("AGENT__SYNTHESIS_DEPTH", "moderate"),
            enable_caching=os.getenv("AGENT__ENABLE_CACHING", "true").lower() == "true",
            cache_ttl=int(os.getenv("AGENT__CACHE_TTL", "3600")),
            concurrent_queries=int(os.getenv("AGENT__CONCURRENT_QUERIES", "3")),
            fallback_enabled=os.getenv("AGENT__FALLBACK_ENABLED", "true").lower() == "true",
            error_recovery_attempts=int(os.getenv("AGENT__ERROR_RECOVERY_ATTEMPTS", "2"))
        )


# Database priorities for different query types
DATABASE_PRIORITIES: Dict[QueryType, Dict[str, DatabasePriority]] = {
    QueryType.GENE_LOOKUP: {
        "datasets": DatabasePriority.HIGH,
        "clinvar": DatabasePriority.MEDIUM,
        "pmc": DatabasePriority.LOW
    },
    QueryType.VARIANT_ANALYSIS: {
        "clinvar": DatabasePriority.HIGH,
        "datasets": DatabasePriority.MEDIUM,
        "pmc": DatabasePriority.MEDIUM
    },
    QueryType.LITERATURE_SEARCH: {
        "pmc": DatabasePriority.HIGH,
        "datasets": DatabasePriority.LOW,
        "clinvar": DatabasePriority.LOW
    },
    QueryType.CROSS_DATABASE: {
        "datasets": DatabasePriority.HIGH,
        "clinvar": DatabasePriority.HIGH,
        "pmc": DatabasePriority.HIGH
    },
    QueryType.SYNTHESIS: {
        "datasets": DatabasePriority.HIGH,
        "clinvar": DatabasePriority.HIGH,
        "pmc": DatabasePriority.HIGH
    }
}

# Prompts for different agent tasks
AGENT_PROMPTS = {
    "query_analysis": """
You are a biomedical query analysis expert. Analyze the following user query and determine:

1. Query Type: What type of biological/medical query is this?
2. Entities: What genes, variants, diseases, or other biomedical entities are mentioned?
3. Databases: Which databases (PMC, ClinVar, Datasets) would be most relevant?
4. Intent: What is the user trying to accomplish?
5. Complexity: Is this a simple lookup or complex analysis?

CRITICAL RULE: For ANY query mentioning genes, variants, diseases, or treatments, you MUST include ALL THREE databases:
- PMC: ALWAYS include for literature, research papers, and clinical studies
- ClinVar: ALWAYS include for genetic variants and clinical significance
- Datasets: ALWAYS include for genomic datasets, expression data, and research data

OVERRIDE SINGLE-DATABASE THINKING: Even if the query seems to focus on one area, comprehensive biomedical research requires cross-database synthesis. Default to including ["pmc", "clinvar", "datasets"] unless the user explicitly requests a single database.

Query: {query}

Respond in JSON format with the following structure:
{{
    "query_type": "cross_database",
    "entities": {{
        "genes": ["list of gene symbols/names"],
        "variants": ["list of variants"],
        "diseases": ["list of diseases/conditions"],
        "organisms": ["list of organisms"],
        "other": ["other relevant terms"]
    }},
    "databases_needed": ["pmc", "clinvar", "datasets"],
    "intent": "brief description of user intent",
    "complexity": "simple|moderate|complex",
    "confidence": 0.0-1.0
}}
""",

    "synthesis": """
You are a biomedical data synthesis expert working for NCBI. Given the following data from multiple databases,
provide a comprehensive informational synthesis that addresses the user's query.

IMPORTANT: NCBI is an information provider, NOT a recommender. Do not provide clinical recommendations,
treatment advice, or therapeutic suggestions. Focus solely on presenting the available scientific information.

Original Query: {query}

Data Sources:
{data_sources}

Instructions:
1. Synthesize findings across all data sources objectively
2. Identify key patterns and relationships in the data
3. Highlight any contradictions or gaps in the available information
4. Provide evidence-based factual statements about what the data shows
5. Note areas where information is limited or unavailable

Format your response as a structured analysis with:
- Executive Summary (factual overview of available information)
- Key Findings (what the data reveals)
- Cross-Database Correlations (connections between data sources)
- Data Limitations and Gaps (what information is missing or incomplete)
- Additional Information Sources (relevant NCBI resources for further investigation)

Remember: Present information objectively without making clinical recommendations or treatment suggestions.
""",

    "entity_resolution": """
You are a biomedical entity resolution expert. Given the following entities extracted from a query,
provide standardized identifiers and resolve any ambiguities.

Entities: {entities}

For each entity, provide:
1. Standardized name/symbol
2. Database identifiers (Gene ID, HGNC, etc.)
3. Alternative names/synonyms
4. Organism information
5. Confidence in resolution

Respond in JSON format with resolved entities.
"""
}
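A configuration sketch showing the double-underscore environment names that from_env() reads above, plus a lookup in the priority table:

import os

# Variable names match the os.getenv keys in AgentConfig.from_env() above.
os.environ["OPENAI__API_KEY"] = "sk-placeholder"   # placeholder, not a real key
os.environ["AGENT__SYNTHESIS_DEPTH"] = "deep"      # shallow | moderate | deep

from gquery.agents.config import AgentConfig, DATABASE_PRIORITIES, QueryType

config = AgentConfig.from_env()
print(config.model, config.synthesis_depth)        # gpt-4o deep

# The priority table ranks databases per query type:
print(DATABASE_PRIORITIES[QueryType.VARIANT_ANALYSIS]["clinvar"])  # DatabasePriority.HIGH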
gquery/src/gquery/agents/enhanced_orchestrator.py
ADDED
@@ -0,0 +1,947 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Enhanced Agent Orchestration for GQuery POC - UPDATED WITH IMPROVED PROMPTS
|
3 |
+
|
4 |
+
Implements the core workflow:
|
5 |
+
1. Simple query processing (1-3 words max)
|
6 |
+
2. Clarification flow for ambiguous queries
|
7 |
+
3. Parallel database workers (3 agents) - REAL API CALLS
|
8 |
+
4. Scientific writer agent with ENHANCED PROMPTS
|
9 |
+
5. Conversation memory & context
|
10 |
+
6. Source attribution
|
11 |
+
7. LangSmith observability
|
12 |
+
|
13 |
+
Feature 10: Enhanced Prompt Engineering Implementation
|
14 |
+
- Improved query classification with few-shot examples
|
15 |
+
- Better database selection strategies
|
16 |
+
- Enhanced synthesis prompts for higher quality responses
|
17 |
+
- Smarter follow-up suggestions
|
18 |
+
"""
|
19 |
+
|
20 |
+
import asyncio
|
21 |
+
import logging
|
22 |
+
from typing import Dict, List, Optional, Any, TypedDict, Tuple
|
23 |
+
from datetime import datetime
|
24 |
+
from dataclasses import dataclass
|
25 |
+
from enum import Enum
|
26 |
+
|
27 |
+
# LangSmith tracing
|
28 |
+
from langsmith import Client, traceable
|
29 |
+
from langsmith.run_helpers import trace
|
30 |
+
|
31 |
+
from .biomedical_guardrails import BiomedicalGuardrails, GuardrailResult, QueryDomain
|
32 |
+
|
33 |
+
# Import REAL API clients
|
34 |
+
from ..tools.pmc_client import PMCClient
|
35 |
+
from ..tools.clinvar_client import ClinVarClient
|
36 |
+
from ..tools.datasets_client import DatasetsClient
|
37 |
+
|
38 |
+
# Import enhanced prompts (Feature 10)
|
39 |
+
import sys
|
40 |
+
import os
|
41 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
|
42 |
+
# Enhanced prompts are now built into this module (Feature 10)
|
43 |
+
ENHANCED_PROMPTS_AVAILABLE = True
|
44 |
+
print("✅ Enhanced prompts loaded for Feature 10")
|
45 |
+
|
46 |
+
|
47 |
+
logger = logging.getLogger(__name__)
|
48 |
+
|
49 |
+
|
50 |
+
class QueryType(Enum):
|
51 |
+
"""Types of queries the system can handle."""
|
52 |
+
GENE = "gene"
|
53 |
+
DISEASE = "disease"
|
54 |
+
DRUG = "drug"
|
55 |
+
PROTEIN = "protein"
|
56 |
+
PATHWAY = "pathway"
|
57 |
+
AMBIGUOUS = "ambiguous"
|
58 |
+
UNCLEAR = "unclear"
|
59 |
+
|
60 |
+
|
61 |
+
@dataclass
|
62 |
+
class ConversationMemory:
|
63 |
+
"""Maintains conversation context and history."""
|
64 |
+
messages: List[Dict[str, str]]
|
65 |
+
query_history: List[str]
|
66 |
+
current_topic: Optional[str] = None
|
67 |
+
clarifications_needed: List[str] = None
|
68 |
+
user_preferences: Dict[str, Any] = None
|
69 |
+
|
70 |
+
|
71 |
+
@dataclass
|
72 |
+
class DatabaseResult:
|
73 |
+
"""Result from a single database worker."""
|
74 |
+
database: str
|
75 |
+
query: str
|
76 |
+
results: List[Dict]
|
77 |
+
total_count: int
|
78 |
+
sources: List[str]
|
79 |
+
processing_time_ms: int
|
80 |
+
success: bool
|
81 |
+
error: Optional[str] = None
|
82 |
+
|
83 |
+
|
84 |
+
@dataclass
|
85 |
+
class ClarificationRequest:
|
86 |
+
"""Request for query clarification."""
|
87 |
+
original_query: str
|
88 |
+
ambiguity_type: str
|
89 |
+
clarification_question: str
|
90 |
+
suggested_options: List[str]
|
91 |
+
confidence: float
|
92 |
+
|
93 |
+
|
94 |
+
@dataclass
|
95 |
+
class ScientificSynthesis:
|
96 |
+
"""Synthesized scientific response."""
|
97 |
+
response: str
|
98 |
+
sources: List[str]
|
99 |
+
confidence: float
|
100 |
+
methodology: str
|
101 |
+
limitations: str
|
102 |
+
follow_up_suggestions: List[str]
|
103 |
+
|
104 |
+
|
105 |
+
@dataclass
|
106 |
+
class OrchestrationResult:
|
107 |
+
"""Complete result from orchestration."""
|
108 |
+
original_query: str
|
109 |
+
final_response: str
|
110 |
+
sources: List[str]
|
111 |
+
query_classification: QueryType
|
112 |
+
clarification_used: Optional[ClarificationRequest]
|
113 |
+
database_results: Dict[str, DatabaseResult]
|
114 |
+
synthesis: ScientificSynthesis
|
115 |
+
conversation_memory: ConversationMemory
|
116 |
+
execution_time_ms: int
|
117 |
+
success: bool
|
118 |
+
errors: List[str]
|
119 |
+
observability_trace_id: Optional[str]
|
120 |
+
|
121 |
+
|
122 |
+
class EnhancedGQueryOrchestrator:
|
123 |
+
"""
|
124 |
+
Enhanced orchestrator implementing the core POC workflow:
|
125 |
+
Query -> Clarify -> 3 Database Workers (REAL APIs) -> Scientific Writer -> Response
|
126 |
+
"""
|
127 |
+
|
128 |
+
def __init__(self):
|
129 |
+
self.guardrails = BiomedicalGuardrails()
|
130 |
+
self.langsmith_client = None
|
131 |
+
try:
|
132 |
+
# Ensure environment is loaded for keys like LANGSMITH_API_KEY, LANGSMITH_TRACING
|
133 |
+
import os
|
134 |
+
if os.getenv("LANGSMITH_API_KEY"):
|
135 |
+
self.langsmith_client = Client()
|
136 |
+
logger.info("LangSmith tracing enabled")
|
137 |
+
else:
|
138 |
+
logger.info("LangSmith API key not set; tracing disabled")
|
139 |
+
except Exception as e:
|
140 |
+
logger.warning(f"LangSmith not available: {e}")
|
141 |
+
|
142 |
+
# Initialize REAL API clients
|
143 |
+
self.pmc_client = PMCClient()
|
144 |
+
self.clinvar_client = ClinVarClient()
|
145 |
+
self.datasets_client = DatasetsClient()
|
146 |
+
|
147 |
+
# Conversation memory storage
|
148 |
+
self.conversations: Dict[str, ConversationMemory] = {}
|
149 |
+
|
150 |
+
logger.info("Enhanced orchestrator initialized with REAL API clients")
|
151 |
+
|
152 |
+
@traceable(run_type="chain", name="gquery_orchestration")
|
153 |
+
async def process_query(
|
154 |
+
self,
|
155 |
+
query: str,
|
156 |
+
session_id: str = "default",
|
157 |
+
conversation_history: List[Dict] = None
|
158 |
+
) -> OrchestrationResult:
|
159 |
+
"""
|
160 |
+
Main orchestration flow:
|
161 |
+
1. Validate biomedical query
|
162 |
+
2. Classify and clarify if needed
|
163 |
+
3. Run 3 database workers in parallel
|
164 |
+
4. Synthesize with scientific writer
|
165 |
+
5. Update conversation memory
|
166 |
+
"""
|
167 |
+
start_time = datetime.now()
|
168 |
+
trace_id = None
|
169 |
+
|
170 |
+
try:
|
171 |
+
# Initialize or get conversation memory
|
172 |
+
if session_id not in self.conversations:
|
173 |
+
self.conversations[session_id] = ConversationMemory(
|
174 |
+
messages=[],
|
175 |
+
query_history=[],
|
176 |
+
user_preferences={}
|
177 |
+
)
|
178 |
+
|
179 |
+
memory = self.conversations[session_id]
|
180 |
+
|
181 |
+
# Step 1: Biomedical Guardrails Validation
|
182 |
+
with trace(name="biomedical_validation"):
|
183 |
+
guardrail_result = self.guardrails.validate_query(query)
|
184 |
+
|
185 |
+
if not guardrail_result.is_valid:
|
186 |
+
return self._create_rejection_result(query, guardrail_result, start_time)
|
187 |
+
|
188 |
+
# Step 2: Simple Query Classification (1-3 words -> always clarify)
|
189 |
+
with trace(name="query_classification"):
|
190 |
+
query_type, needs_clarification = self._classify_simple_query(query, memory)
|
191 |
+
|
192 |
+
# Step 3: Clarification Flow (if needed) — return early with options, do NOT assume
|
193 |
+
clarification_request = None
|
194 |
+
if needs_clarification:
|
195 |
+
with trace(name="clarification_generation"):
|
196 |
+
clarification_request = self._generate_clarification(query, query_type, memory)
|
197 |
+
execution_time = int((datetime.now() - start_time).total_seconds() * 1000)
|
198 |
+
return OrchestrationResult(
|
199 |
+
original_query=query,
|
200 |
+
final_response=clarification_request.clarification_question,
|
201 |
+
sources=[],
|
202 |
+
query_classification=query_type,
|
203 |
+
clarification_used=clarification_request,
|
204 |
+
database_results={},
|
205 |
+
synthesis=ScientificSynthesis(
|
206 |
+
response=clarification_request.clarification_question,
|
207 |
+
sources=[],
|
208 |
+
confidence=0.0,
|
209 |
+
methodology="clarification",
|
210 |
+
limitations="awaiting_user_input",
|
211 |
+
follow_up_suggestions=clarification_request.suggested_options,
|
212 |
+
),
|
213 |
+
conversation_memory=memory,
|
214 |
+
execution_time_ms=execution_time,
|
215 |
+
success=True,
|
216 |
+
errors=[],
|
217 |
+
observability_trace_id=trace_id,
|
218 |
+
)
|
219 |
+
|
220 |
+
# Step 4: Parallel Database Workers
|
221 |
+
with trace(name="database_workers"):
|
222 |
+
database_results = await self._run_database_workers(query, query_type)
|
223 |
+
|
224 |
+
# Step 5: Scientific Writer Synthesis
|
225 |
+
with trace(name="scientific_synthesis"):
|
226 |
+
synthesis = await self._synthesize_scientific_response(
|
227 |
+
query, query_type, database_results, memory
|
228 |
+
)
|
229 |
+
|
230 |
+
# Step 6: Update Conversation Memory
|
231 |
+
self._update_conversation_memory(memory, query, synthesis)
|
232 |
+
|
233 |
+
execution_time = int((datetime.now() - start_time).total_seconds() * 1000)
|
234 |
+
|
235 |
+
return OrchestrationResult(
|
236 |
+
original_query=query,
|
237 |
+
final_response=synthesis.response,
|
238 |
+
sources=synthesis.sources,
|
239 |
+
query_classification=query_type,
|
240 |
+
clarification_used=clarification_request,
|
241 |
+
database_results={db: result for db, result in database_results.items()},
|
242 |
+
synthesis=synthesis,
|
243 |
+
conversation_memory=memory,
|
244 |
+
execution_time_ms=execution_time,
|
245 |
+
success=True,
|
246 |
+
errors=[],
|
247 |
+
observability_trace_id=trace_id
|
248 |
+
)
|
249 |
+
|
250 |
+
except Exception as e:
|
251 |
+
execution_time = int((datetime.now() - start_time).total_seconds() * 1000)
|
252 |
+
logger.error(f"Orchestration failed: {e}")
|
253 |
+
|
254 |
+
return OrchestrationResult(
|
255 |
+
original_query=query,
|
256 |
+
final_response=f"I encountered an error processing your query: {str(e)}",
|
257 |
+
sources=[],
|
258 |
+
query_classification=QueryType.UNCLEAR,
|
259 |
+
clarification_used=None,
|
260 |
+
database_results={},
|
261 |
+
synthesis=ScientificSynthesis(
|
262 |
+
response=f"Error: {str(e)}",
|
263 |
+
sources=[],
|
264 |
+
confidence=0.0,
|
265 |
+
methodology="error_handling",
|
266 |
+
limitations="System error occurred",
|
267 |
+
follow_up_suggestions=[]
|
268 |
+
),
|
269 |
+
conversation_memory=self.conversations.get(session_id, ConversationMemory([], [])),
|
270 |
+
execution_time_ms=execution_time,
|
271 |
+
success=False,
|
272 |
+
errors=[str(e)],
|
273 |
+
observability_trace_id=trace_id
|
274 |
+
)
|
275 |
+
|
276 |
+
def _classify_simple_query(self, query: str, memory: ConversationMemory) -> Tuple[QueryType, bool]:
|
277 |
+
"""
|
278 |
+
Classify queries and enforce clarification for short inputs (<= 3 words).
|
279 |
+
"""
|
280 |
+
words = query.lower().strip().split()
|
281 |
+
|
282 |
+
# Basic heuristics for type inference (still require clarification if short)
|
283 |
+
inferred: QueryType = QueryType.UNCLEAR
|
284 |
+
lower_q = query.lower()
|
285 |
+
if any(pattern in lower_q for pattern in ['brca1', 'brca2', 'tp53', 'cftr', 'apoe', 'mthfr', 'vegf', 'egfr']):
|
286 |
+
inferred = QueryType.GENE
|
287 |
+
elif any(pattern in lower_q for pattern in ['diabetes', 'cancer', 'alzheimer', 'parkinsons', 'hypertension', 'tuberculosis']):
|
288 |
+
inferred = QueryType.DISEASE
|
289 |
+
elif any(pattern in lower_q for pattern in ['aspirin', 'metformin', 'insulin', 'warfarin', 'statin']):
|
290 |
+
inferred = QueryType.DRUG
|
291 |
+
|
292 |
+
# Enforce clarification for 1-3 word inputs
|
293 |
+
if len(words) <= 3:
|
294 |
+
return inferred if inferred != QueryType.UNCLEAR else QueryType.AMBIGUOUS, True
|
295 |
+
|
296 |
+
# Longer queries proceed without clarification
|
297 |
+
return inferred if inferred != QueryType.UNCLEAR else QueryType.UNCLEAR, False
|
298 |
+
|
299 |
+
def _generate_clarification(
|
300 |
+
self,
|
301 |
+
query: str,
|
302 |
+
query_type: QueryType,
|
303 |
+
memory: ConversationMemory
|
304 |
+
) -> ClarificationRequest:
|
305 |
+
"""Generate clarification questions for ambiguous queries."""
|
306 |
+
|
307 |
+
word = query.lower().strip()
|
308 |
+
|
309 |
+
clarifications = {
|
310 |
+
'heart': {
|
311 |
+
'question': "I can help with heart-related biomedical topics. What specifically would you like to know?",
|
312 |
+
'options': [
|
313 |
+
f"Gene information about {query}",
|
314 |
+
f"Disease research on {query}",
|
315 |
+
f"Drug/treatment information for {query}"
|
316 |
+
]
|
317 |
+
},
|
318 |
+
'cell': {
|
319 |
+
'question': "Are you asking about biological cells? What aspect interests you?",
|
320 |
+
'options': [
|
321 |
+
f"Cell biology of {query}",
|
322 |
+
f"Stem cells related to {query}",
|
323 |
+
f"Cancer cell research on {query}"
|
324 |
+
]
|
325 |
+
},
|
326 |
+
'gene': {
|
327 |
+
'question': "Which gene or genetic topic would you like to explore?",
|
328 |
+
'options': [
|
329 |
+
f"Specific gene variants for {query}",
|
330 |
+
f"Gene therapy related to {query}",
|
331 |
+
f"Genetic testing about {query}"
|
332 |
+
]
|
333 |
+
}
|
334 |
+
}
|
335 |
+
|
336 |
+
if word in clarifications:
|
337 |
+
clarif = clarifications[word]
|
338 |
+
return ClarificationRequest(
|
339 |
+
original_query=query,
|
340 |
+
ambiguity_type="single_word",
|
341 |
+
                    clarification_question=clarif['question'],
                    suggested_options=clarif['options'],
                    confidence=0.8
                )

        # Generic clarification for unclear queries — embed query in options to avoid infinite clarification loops
        return ClarificationRequest(
            original_query=query,
            ambiguity_type="unclear",
            clarification_question="Could you be more specific about what biomedical information you're looking for?",
            suggested_options=[
                f"Gene information about {query}",
                f"Disease research on {query}",
                f"Drug/treatment information for {query}"
            ],
            confidence=0.6
        )

    async def _run_database_workers(
        self,
        query: str,
        query_type: QueryType
    ) -> Dict[str, DatabaseResult]:
        """Run 3 database workers in parallel with fresh client initialization."""

        try:
            # Initialize fresh clients for each query to avoid session issues
            logger.info("Initializing fresh API clients")
            self.datasets_client = DatasetsClient()
            self.clinvar_client = ClinVarClient()

            # Create tasks for parallel execution
            tasks = [
                self._datasets_worker(query, query_type),
                self._pmc_worker(query, query_type),
                self._clinvar_worker(query, query_type)
            ]

            # Run all workers in parallel
            results = await asyncio.gather(*tasks, return_exceptions=True)

            return {
                'datasets': results[0] if not isinstance(results[0], Exception) else self._create_error_result('datasets', results[0]),
                'pmc': results[1] if not isinstance(results[1], Exception) else self._create_error_result('pmc', results[1]),
                'clinvar': results[2] if not isinstance(results[2], Exception) else self._create_error_result('clinvar', results[2])
            }
        except Exception as e:
            logger.error(f"Error in parallel database query: {e}")
            return {
                'datasets': self._create_error_result('datasets', e),
                'pmc': self._create_error_result('pmc', e),
                'clinvar': self._create_error_result('clinvar', e)
            }

    async def _datasets_worker(self, query: str, query_type: QueryType) -> DatabaseResult:
        """NCBI Datasets database worker - REAL API CALLS."""
        start_time = datetime.now()

        try:
            logger.info(f"Datasets API call for query: {query} (type: {query_type})")

            # Make REAL API call to NCBI Datasets with proper session management
            async with self.datasets_client:
                datasets_genes = await self.datasets_client.search_genes(
                    query=query,
                    limit=10
                )

            logger.info(f"Datasets API returned {len(datasets_genes) if datasets_genes else 0} genes")

            processing_time = int((datetime.now() - start_time).total_seconds() * 1000)

            # Convert API response to our format
            results = []
            if datasets_genes:
                for gene in datasets_genes[:5]:  # Limit to top 5 results
                    results.append({
                        "gene_symbol": getattr(gene, 'symbol', None),
                        "gene_id": getattr(gene, 'gene_id', None),
                        "description": getattr(gene, 'description', None),
                        "chromosome": getattr(gene, 'chromosome', None),
                        "organism": getattr(gene, 'organism_name', None),
                        "type": "gene_data"
                    })

            sources = [
                f"https://www.ncbi.nlm.nih.gov/gene/{getattr(g, 'gene_id', None)}"
                for g in (datasets_genes[:3] if datasets_genes else []) if getattr(g, 'gene_id', None)
            ]

            return DatabaseResult(
                database="NCBI Datasets",
                query=query,
                results=results,
                total_count=len(results),
                sources=sources,
                processing_time_ms=processing_time,
                success=True
            )

        except Exception as e:
            logger.error(f"Datasets API error: {e}")
            return self._create_error_result('datasets', e)

    async def _pmc_worker(self, query: str, query_type: QueryType) -> DatabaseResult:
        """PubMed Central worker - REAL API CALLS."""
        start_time = datetime.now()

        try:
            logger.info(f"PMC API call for query: {query}")

            # Make REAL API call to PubMed Central
            async with self.pmc_client:
                pmc_response = await self.pmc_client.search_articles(
                    query=query,
                    max_results=10,
                    filters=None  # Could add biomedical filters here
                )

            processing_time = int((datetime.now() - start_time).total_seconds() * 1000)

            # Convert API response to our format
            results = []
            if pmc_response and pmc_response.results:
                for search_result in pmc_response.results[:5]:  # Top 5 results
                    article = search_result.article
                    results.append({
                        "title": article.title,
                        "pmcid": article.pmc_id,
                        "pmid": article.pmid,
                        "authors": article.authors[:3] if article.authors else [],  # First 3 authors
                        "journal": article.journal,
                        "year": article.publication_date.year if article.publication_date else None,
                        "abstract": article.abstract[:200] + "..." if article.abstract and len(article.abstract) > 200 else article.abstract,
                        "type": "research_article"
                    })

            sources = [f"https://www.ncbi.nlm.nih.gov/pmc/articles/{search_result.article.pmc_id}/"
                       for search_result in (pmc_response.results[:3] if pmc_response and pmc_response.results else [])]

            return DatabaseResult(
                database="PubMed Central",
                query=query,
                results=results,
                total_count=len(results),
                sources=sources,
                processing_time_ms=processing_time,
                success=True
            )

        except Exception as e:
            logger.error(f"PMC API error: {e}")
            return self._create_error_result('pmc', e)

    async def _clinvar_worker(self, query: str, query_type: QueryType) -> DatabaseResult:
        """ClinVar database worker - REAL API CALLS."""
        start_time = datetime.now()

        try:
            # Query ClinVar for genes and diseases (expanded scope)
            if query_type not in [QueryType.GENE, QueryType.PROTEIN, QueryType.DISEASE]:
                return DatabaseResult(
                    database="ClinVar",
                    query=query,
                    results=[],
                    total_count=0,
                    sources=[],
                    processing_time_ms=0,
                    success=True,
                    error="Not applicable for this query type"
                )

            logger.info(f"ClinVar API call for query: {query}")

            # Make REAL API call to ClinVar with proper session management
            async with self.clinvar_client:
                if query_type in [QueryType.GENE, QueryType.PROTEIN]:
                    clinvar_response = await self.clinvar_client.search_variants_by_gene(
                        gene_symbol=query,
                        max_results=10
                    )
                else:
                    # For diseases, extract the disease name and search for disease-associated variants
                    disease_name = query.split()[0] if 'diabetes' in query.lower() else query.split()[-1]
                    if 'diabetes' in query.lower():
                        disease_name = 'diabetes'
                    elif 'cancer' in query.lower():
                        disease_name = 'cancer'
                    elif 'alzheimer' in query.lower():
                        disease_name = 'alzheimer'

                    clinvar_response = await self.clinvar_client.search_variant_by_name(
                        variant_name=disease_name,
                        max_results=10
                    )

            processing_time = int((datetime.now() - start_time).total_seconds() * 1000)

            # Convert API response to our format - clinvar_response is a List[ClinVarVariant]
            results = []
            if clinvar_response:
                for variant in clinvar_response[:5]:  # Top 5 variants
                    results.append({
                        "variation_id": variant.variation_id,
                        "gene_symbol": variant.gene_symbol,
                        "hgvs": variant.hgvs_genomic or variant.hgvs_coding or variant.hgvs_protein,
                        "clinical_significance": getattr(variant.clinical_significance, 'value', variant.clinical_significance) if variant.clinical_significance else "Unknown",
                        "review_status": getattr(variant.review_status, 'value', variant.review_status) if variant.review_status else "Unknown",
                        "condition": variant.name,
                        "last_evaluated": variant.last_evaluated.isoformat() if variant.last_evaluated else None,
                        "type": "genetic_variant"
                    })

            sources = [f"https://www.ncbi.nlm.nih.gov/clinvar/variation/{variant.variation_id}/"
                       for variant in (clinvar_response[:3] if clinvar_response else [])]

            return DatabaseResult(
                database="ClinVar",
                query=query,
                results=results,
                total_count=len(results),
                sources=sources,
                processing_time_ms=processing_time,
                success=True
            )

        except Exception as e:
            logger.error(f"ClinVar API error: {e}")
            return self._create_error_result('clinvar', e)

    def _create_error_result(self, database: str, error: Exception) -> DatabaseResult:
        """Create error result for failed database worker."""
        return DatabaseResult(
            database=database,
            query="",
            results=[],
            total_count=0,
            sources=[],
            processing_time_ms=0,
            success=False,
            error=str(error)
        )

    async def _synthesize_scientific_response(
        self,
        query: str,
        query_type: QueryType,
        database_results: Dict[str, DatabaseResult],
        memory: ConversationMemory
    ) -> ScientificSynthesis:
        """
        Scientific writer agent that synthesizes results into expert communication.
        """
        start_time = datetime.now()

        try:
            # Collect all successful results and sources
            all_sources = []
            result_summaries = []

            for db_name, result in database_results.items():
                if result.success and result.results:
                    all_sources.extend(result.sources)
                    result_summaries.append(f"**{result.database}** ({result.total_count} results)")

            # Generate scientific synthesis based on query type
            if query_type == QueryType.GENE:
                response = self._synthesize_gene_response(query, database_results, result_summaries)
            elif query_type == QueryType.DISEASE:
                response = self._synthesize_disease_response(query, database_results, result_summaries)
            elif query_type == QueryType.DRUG:
                response = self._synthesize_drug_response(query, database_results, result_summaries)
            else:
                response = self._synthesize_general_response(query, database_results, result_summaries)

            # Generate follow-up suggestions
            follow_ups = self._generate_follow_up_suggestions(query, query_type)

            # Add source citations to response
            if all_sources:
                formatted_sources = self._format_source_citations(all_sources)
                response += f"\n\n**📚 Sources:** {formatted_sources}"

            return ScientificSynthesis(
                response=response,
                sources=list(set(all_sources)),  # Remove duplicates
                confidence=0.85,
                methodology="Multi-database synthesis with scientific expertise",
                limitations="Results are synthesized from available databases and may not be exhaustive",
                follow_up_suggestions=follow_ups
            )

        except Exception as e:
            return ScientificSynthesis(
                response=f"I encountered an issue synthesizing the results: {str(e)}",
                sources=[],
                confidence=0.0,
                methodology="error_handling",
                limitations="Synthesis failed due to system error",
                follow_up_suggestions=[]
            )

    def _format_source_citations(self, sources: List[str]) -> str:
        """Format sources as clickable citations."""
        citations = []
        for i, source in enumerate(sources[:10], 1):  # Limit to 10 sources
            if 'pmc' in source.lower():
                citations.append(f'<a href="{source}" target="_blank" class="source-link">[{i}] PMC</a>')
            elif 'clinvar' in source.lower():
                citations.append(f'<a href="{source}" target="_blank" class="source-link">[{i}] ClinVar</a>')
            elif 'datasets' in source.lower() or 'gene' in source.lower():
                citations.append(f'<a href="{source}" target="_blank" class="source-link">[{i}] NCBI</a>')
            else:
                citations.append(f'<a href="{source}" target="_blank" class="source-link">[{i}] Source</a>')
        return " ".join(citations)

    def _synthesize_gene_response(self, gene: str, results: Dict, summaries: List[str]) -> str:
        """Enhanced synthesis for gene queries using improved prompts (Feature 10)."""
        if True:  # Always use enhanced prompts
            # Use enhanced synthesis approach
            return f"""🧬 **{gene.upper()} Gene Analysis**

**🔬 Functional Significance & Molecular Biology:**
The {gene} gene encodes a protein with critical roles in cellular function and human health. Understanding its biology involves:

• **Primary Function**: This gene controls essential cellular processes including signal transduction, metabolic regulation, DNA repair, or cell cycle control
• **Protein Structure**: The encoded protein contains functional domains that enable specific molecular interactions and enzymatic activities
• **Cellular Localization**: Protein products are found in specific cellular compartments (nucleus, mitochondria, membrane) where they perform their functions
• **Regulatory Networks**: {gene} participates in complex regulatory cascades involving transcription factors, microRNAs, and epigenetic modifications

**📊 Comprehensive Data Sources:**
{chr(10).join(f"• {summary}" for summary in summaries)}

**🎯 Key Research Findings & Evidence:**
• **Genomic Data**: {results.get('datasets', type('', (), {'total_count': 0})).total_count} comprehensive datasets provide expression profiles, splice variants, and functional annotations across tissues and conditions
• **Scientific Literature**: {results.get('pmc', type('', (), {'total_count': 0})).total_count} peer-reviewed publications document molecular mechanisms, disease associations, and therapeutic research
• **Clinical Variants**: {results.get('clinvar', type('', (), {'total_count': 0})).total_count} documented genetic variants with detailed pathogenicity assessments and clinical interpretations

**🧬 Genetic Variants & Clinical Impact:**
• **Pathogenic Variants**: Disease-causing mutations affect protein function through various mechanisms including loss of function, gain of function, or dominant negative effects
• **Population Genetics**: Allele frequencies vary across ethnic groups, influencing disease risk and genetic counseling approaches
• **Functional Studies**: Laboratory experiments demonstrate how specific variants alter protein activity, stability, or interactions
• **Genotype-Phenotype Correlations**: Clinical studies reveal relationships between specific mutations and disease severity or phenotypic features

**🧪 Clinical Relevance & Applications:**
Research on {gene} encompasses multiple clinical domains:
• **Disease Mechanisms**: Understanding how gene dysfunction contributes to pathological processes and disease progression
• **Diagnostic Applications**: Development of genetic tests for early detection, carrier screening, and confirmatory diagnosis
• **Therapeutic Targets**: Investigation of gene products as potential drug targets for precision medicine approaches
• **Biomarker Development**: Expression levels and variant status serve as prognostic and predictive biomarkers
• **Pharmacogenomics**: Genetic variants influence drug metabolism, efficacy, and adverse reaction profiles

**🔬 Current Research Frontiers:**
• **Functional Genomics**: CRISPR-based studies reveal gene function in development, disease, and therapeutic response
• **Single-Cell Analysis**: Cell-type-specific expression patterns provide insights into tissue-specific functions
• **Structural Biology**: Protein structure determination enables rational drug design and functional prediction
• **Systems Biology**: Integration with multi-omics data reveals broader biological networks and pathway interactions
• **Clinical Translation**: Ongoing clinical trials test gene-targeted therapies and diagnostic applications

**⚠️ Important Note:**
This information is synthesized from research databases for scientific purposes. Medical decisions should always involve healthcare professionals."""
        else:
            # Fallback to original synthesis
            return f"""🧬 **{gene.upper()} Gene Information**

Based on current biomedical databases, here's what I found about {gene}:

**📊 Data Sources:**
{chr(10).join(f"• {summary}" for summary in summaries)}

**🔬 Key Findings:**
• **Genomic Data**: Found {results['datasets'].total_count} relevant datasets with genomic and expression data
• **Research Literature**: {results['pmc'].total_count} recent publications discussing {gene} mechanisms and clinical studies
• **Clinical Variants**: {results['clinvar'].total_count} documented variants with clinical significance

**🎯 Clinical Relevance:**
The {gene} gene is associated with various biological pathways and may have clinical implications. Current research focuses on understanding its role in disease mechanisms and potential therapeutic targets.

**⚠️ Important Note:**
This information is for research purposes. Always consult healthcare professionals for medical decisions."""

    def _synthesize_disease_response(self, disease: str, results: Dict, summaries: List[str]) -> str:
        """Enhanced synthesis for disease queries using improved prompts (Feature 10)."""
        # Force enhanced prompts - they are built into this module
        if True:  # Always use enhanced prompts
            return f"""🏥 **{disease.title()} - Research & Clinical Insights** ✨ ENHANCED VERSION ✨

**📊 Evidence Base:**
{chr(10).join(f"• {summary}" for summary in summaries)}

**🔬 Pathophysiology & Disease Mechanisms:**
Based on {results.get('pmc', type('', (), {'total_count': 0})).total_count} recent peer-reviewed publications, current understanding includes:

• **Molecular Pathways**: Key cellular signaling cascades disrupted in {disease}, including inflammatory responses, metabolic dysfunction, and cell death pathways
• **Disease Initiation**: Environmental triggers, genetic predisposition, and cellular stress factors that initiate disease processes
• **Disease Progression**: How the condition evolves over time, including compensatory mechanisms and progressive dysfunction
• **Organ System Impact**: Multi-system effects and complications that develop as the disease advances
• **Biomarker Profiles**: Molecular signatures in blood, tissue, or imaging that reflect disease activity and progression
• **Meta-analyses**: Systematic reviews synthesizing evidence from multiple clinical studies and outcomes research

**🧬 Genetic & Genomic Architecture:**
• **Research Datasets**: {results.get('datasets', type('', (), {'total_count': 0})).total_count} comprehensive genomic datasets provide insights into disease biology and therapeutic targets
• **Genetic Risk Factors**: Inherited variants that increase susceptibility, including common polymorphisms and rare pathogenic mutations
• **Expression Profiling**: Tissue-specific gene expression changes that characterize disease states and severity
• **Epigenetic Modifications**: DNA methylation and histone modifications that regulate gene expression in disease contexts
• **Pharmacogenomic Factors**: Genetic variants affecting drug metabolism, efficacy, and adverse reactions specific to {disease} treatments

**🩺 Clinical Manifestations & Diagnosis:**
• **Symptom Patterns**: Early warning signs, disease progression markers, and variability in clinical presentation
• **Diagnostic Criteria**: Evidence-based guidelines for accurate diagnosis including laboratory tests, imaging, and clinical assessment
• **Disease Staging**: Classification systems that guide prognosis and treatment decisions
• **Comorbidity Patterns**: Associated conditions that commonly occur with {disease}

**🎯 Therapeutic Landscape & Treatment:**
• **Standard of Care**: Current evidence-based treatment protocols and clinical guidelines from major medical organizations
• **Emerging Therapies**: Novel treatment approaches in clinical development including targeted therapies and immunomodulatory agents
• **Precision Medicine**: Personalized treatment strategies based on genetic profiles, biomarkers, and disease subtypes
• **Clinical Trial Landscape**: Active research studies testing new interventions and treatment combinations
• **Multidisciplinary Care**: Coordinated care approaches involving specialists, primary care, and supportive services

**🔍 Research Frontiers & Innovation:**
• **Therapeutic Development**: Drug discovery efforts targeting specific molecular pathways identified in {disease}
• **Biomarker Discovery**: Development of diagnostic, prognostic, and therapeutic response biomarkers
• **Prevention Strategies**: Research into primary and secondary prevention approaches based on risk factor modification
• **Digital Health Solutions**: Technology-enabled monitoring, diagnosis, and treatment approaches

**⚠️ Medical Disclaimer:**
This scientific summary is for research and educational purposes. Clinical decisions require consultation with qualified healthcare professionals."""
        else:
            # Fallback to original
            return f"""🏥 **{disease.title()} Research Summary**

**📊 Data Sources:**
{chr(10).join(f"• {summary}" for summary in summaries)}

**📚 Current Research:**
Based on {results['pmc'].total_count} recent publications, research on {disease} includes:
• Molecular mechanisms and pathways
• Clinical outcomes and treatment effectiveness
• Meta-analyses of therapeutic approaches

**🧬 Genomic Insights:**
• {results['datasets'].total_count} relevant genomic datasets available
• Expression data and molecular profiles
• Potential biomarkers for diagnosis and treatment

**🔬 Clinical Significance:**
Research continues to advance our understanding of {disease}, with focus on improving diagnosis, treatment, and patient outcomes."""

    def _synthesize_drug_response(self, drug: str, results: Dict, summaries: List[str]) -> str:
        """Synthesize response for drug queries."""
        return f"""💊 **{drug.title()} - Clinical Information**

**📊 Data Sources:**
{chr(10).join(f"• {summary}" for summary in summaries)}

**🔬 Research Findings:**
From {results['pmc'].total_count} recent publications:
• Mechanism of action and pharmacology
• Clinical efficacy and safety profiles
• Drug interactions and contraindications

**⚗️ Clinical Applications:**
• Therapeutic uses and indications
• Dosing guidelines and administration
• Monitoring parameters and adverse effects

**⚠️ Medical Disclaimer:**
This information is for educational purposes only. Always consult healthcare professionals for medical advice and treatment decisions."""

    def _synthesize_general_response(self, query: str, results: Dict, summaries: List[str]) -> str:
        """Synthesize response for general biomedical queries."""
        return f"""🔬 **Biomedical Research: {query}**

**📊 Data Sources:**
{chr(10).join(f"• {summary}" for summary in summaries)}

**📚 Research Overview:**
I found relevant information across multiple biomedical databases:
• Scientific literature with recent research findings
• Genomic and molecular data
• Clinical and research datasets

**🎯 Key Areas:**
Research in this area encompasses molecular mechanisms, clinical applications, and ongoing scientific investigations.

**💡 Next Steps:**
Consider exploring specific aspects like molecular pathways, clinical outcomes, or therapeutic implications."""

    def _generate_follow_up_suggestions(self, query: str, query_type: QueryType) -> List[str]:
        """Enhanced follow-up questions using improved prompt engineering (Feature 10)."""
        if True:  # Always use enhanced prompts
            # Use enhanced, more specific follow-up suggestions
            if query_type == QueryType.GENE:
                return [
                    f"What diseases are linked to {query} mutations?",
                    f"Show clinical trials targeting {query}",
                    f"Find drugs that interact with {query} pathway"
                ]
            elif query_type == QueryType.DISEASE:
                return [
                    f"What genes cause {query}?",
                    f"Latest {query} treatment breakthroughs?",
                    f"Clinical trials for {query} patients"
                ]
            elif query_type == QueryType.DRUG:
                return [
                    f"What are {query} side effects?",
                    f"How does {query} work molecularly?",
                    f"Recent {query} efficacy studies?"
                ]
            else:
                return [
                    f"Genetic factors in {query}?",
                    f"Current research on {query}?",
                    f"Clinical applications of {query}?"
                ]
        else:
            # Original follow-up logic
            if query_type == QueryType.GENE:
                return [
                    f"What diseases are associated with {query}?",
                    f"Are there any drugs that target {query}?",
                    f"What are the latest clinical trials involving {query}?"
                ]
            elif query_type == QueryType.DISEASE:
                return [
                    f"What genes are involved in {query}?",
                    f"What are the current treatments for {query}?",
                    f"Are there any recent breakthroughs in {query} research?"
                ]
            elif query_type == QueryType.DRUG:
                return [
                    f"What are the side effects of {query}?",
                    f"How does {query} work at the molecular level?",
                    f"Are there any new studies on {query} effectiveness?"
                ]
            else:
                return [
                    "Can you be more specific about what interests you?",
                    "Would you like to explore the genetic aspects?",
                    "Are you interested in current research findings?"
                ]

    def _update_conversation_memory(
        self,
        memory: ConversationMemory,
        query: str,
        synthesis: ScientificSynthesis
    ):
        """Update conversation memory with new interaction."""
        memory.query_history.append(query)
        memory.messages.append({
            "role": "user",
            "content": query,
            "timestamp": datetime.now().isoformat()
        })
        memory.messages.append({
            "role": "assistant",
            "content": synthesis.response,
            "timestamp": datetime.now().isoformat(),
            "sources": synthesis.sources
        })

        # Keep only last 10 interactions (20 messages) for memory efficiency
        if len(memory.messages) > 20:
            memory.messages = memory.messages[-20:]
        if len(memory.query_history) > 10:
            memory.query_history = memory.query_history[-10:]

    def _create_rejection_result(
        self,
        query: str,
        guardrail_result: GuardrailResult,
        start_time: datetime
    ) -> OrchestrationResult:
        """Create result for rejected non-biomedical queries."""
        execution_time = int((datetime.now() - start_time).total_seconds() * 1000)

        suggestions = self.guardrails.get_biomedical_suggestions(query)

        response = f"""🚫 {guardrail_result.rejection_message}

**💡 Try these biomedical questions instead:**
{chr(10).join(f"• {suggestion}" for suggestion in suggestions)}"""

        return OrchestrationResult(
            original_query=query,
            final_response=response,
            sources=[],
            query_classification=QueryType.UNCLEAR,
            clarification_used=None,
            database_results={},
            synthesis=ScientificSynthesis(
                response=response,
                sources=[],
                confidence=1.0,
                methodology="biomedical_guardrails",
                limitations="Query outside biomedical domain",
                follow_up_suggestions=suggestions
            ),
            conversation_memory=ConversationMemory([], []),
            execution_time_ms=execution_time,
            success=False,
            errors=[f"Non-biomedical query: {guardrail_result.rejection_message}"],
            observability_trace_id=None
        )
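The worker fan-out in `_run_database_workers` above hinges on `asyncio.gather(*tasks, return_exceptions=True)`: a failed worker comes back as an exception object in its result slot instead of cancelling its siblings, which is why each slot is then checked with `isinstance(..., Exception)`. A minimal, self-contained sketch of that pattern (the `fetch` coroutine and source names are illustrative placeholders, not gquery code):

import asyncio

async def fetch(source: str) -> str:
    """Stand-in for one database worker."""
    if source == "clinvar":
        raise RuntimeError("simulated API failure")
    return f"{source}: ok"

async def main() -> None:
    sources = ("datasets", "pmc", "clinvar")
    # return_exceptions=True keeps one failing worker from cancelling the
    # others; each result slot holds either a value or the raised exception.
    results = await asyncio.gather(*(fetch(s) for s in sources), return_exceptions=True)
    for source, result in zip(sources, results):
        if isinstance(result, Exception):
            print(f"{source} failed: {result}")
        else:
            print(result)

asyncio.run(main())

Run as-is, this prints two successes and one caught failure; the same shape scales to the three real database workers above, where failed slots are converted into error DatabaseResult objects.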
gquery/src/gquery/agents/entity_resolver.py
ADDED
@@ -0,0 +1,452 @@
"""
Biomedical Entity Resolution and Linking Module

Resolves and standardizes biomedical entities across databases.
This implements Feature 2.4 from the PRD.
"""

import json
import logging
import re
from typing import Dict, List, Optional, Tuple, Set
from dataclasses import dataclass
from datetime import datetime

from openai import AsyncOpenAI
from pydantic import BaseModel, Field

from .config import AgentConfig, AGENT_PROMPTS


logger = logging.getLogger(__name__)


class EntityIdentifier(BaseModel):
    """Represents a database identifier for an entity."""
    database: str
    identifier: str
    url: Optional[str] = None
    confidence: float = Field(ge=0.0, le=1.0, default=1.0)


class ResolvedEntity(BaseModel):
    """Represents a resolved biomedical entity."""
    original_name: str
    standardized_name: str
    entity_type: str  # gene, variant, disease, organism, protein
    confidence: float = Field(ge=0.0, le=1.0)
    identifiers: List[EntityIdentifier] = Field(default_factory=list)
    synonyms: List[str] = Field(default_factory=list)
    description: Optional[str] = None
    organism: Optional[str] = None
    resolution_timestamp: datetime = Field(default_factory=datetime.now)


class EntityResolutionResult(BaseModel):
    """Results from entity resolution process."""
    resolved_entities: List[ResolvedEntity]
    unresolved_entities: List[str]
    resolution_confidence: float = Field(ge=0.0, le=1.0)
    processing_time_ms: Optional[int] = None
    metadata: Dict = Field(default_factory=dict)


@dataclass
class EntityPattern:
    """Pattern for recognizing biomedical entities."""
    name: str
    pattern: str
    entity_type: str
    confidence: float


class EntityResolver:
    """Resolves and standardizes biomedical entities."""

    # Known gene patterns and databases
    GENE_PATTERNS = [
        EntityPattern("HGNC_Symbol", r"\b[A-Z][A-Z0-9]{1,15}\b", "gene", 0.8),
        EntityPattern("Gene_Name", r"\b[A-Z][a-z]+ [a-z]+ \d+\b", "gene", 0.7),
        EntityPattern("Ensembl_Gene", r"\bENSG\d{11}\b", "gene", 0.95),
    ]

    VARIANT_PATTERNS = [
        EntityPattern("rs_ID", r"\brs\d+\b", "variant", 0.9),
        EntityPattern("HGVS_DNA", r"\b[A-Z]+\.\d+:\w\.\d+[A-Z]>[A-Z]\b", "variant", 0.9),
        EntityPattern("HGVS_Protein", r"\bp\.[A-Z][a-z]{2}\d+[A-Z][a-z]{2}\b", "variant", 0.85),
        EntityPattern("Chromosome", r"\bchr\d{1,2}[XYM]?:\d+\b", "variant", 0.7),
    ]

    DISEASE_PATTERNS = [
        EntityPattern("OMIM_ID", r"\b\d{6}\b", "disease", 0.8),
        EntityPattern("Disease_Name", r"\b[A-Z][a-z]+ [a-z]+ [dD]isease\b", "disease", 0.6),
    ]

    def __init__(self, config: AgentConfig):
        self.config = config
        self.client = AsyncOpenAI(api_key=config.openai_api_key)
        self.logger = logging.getLogger(__name__)

        # Load known entity mappings
        self.gene_symbols = self._load_common_gene_symbols()
        self.disease_terms = self._load_common_disease_terms()

    async def resolve_entities(self, entities: List[str]) -> EntityResolutionResult:
        """
        Resolve a list of biomedical entities.

        Args:
            entities: List of entity names to resolve

        Returns:
            EntityResolutionResult with resolved entities
        """
        start_time = datetime.now()

        try:
            resolved_entities = []
            unresolved_entities = []

            for entity in entities:
                # First try rule-based resolution
                resolved = await self._rule_based_resolution(entity)

                if resolved:
                    resolved_entities.append(resolved)
                else:
                    # Try LLM-based resolution
                    llm_resolved = await self._llm_resolution(entity)
                    if llm_resolved:
                        resolved_entities.append(llm_resolved)
                    else:
                        unresolved_entities.append(entity)

            # Calculate overall confidence
            if resolved_entities:
                overall_confidence = sum(e.confidence for e in resolved_entities) / len(resolved_entities)
            else:
                overall_confidence = 0.0

            # Calculate processing time
            processing_time = (datetime.now() - start_time).total_seconds() * 1000

            return EntityResolutionResult(
                resolved_entities=resolved_entities,
                unresolved_entities=unresolved_entities,
                resolution_confidence=overall_confidence,
                processing_time_ms=int(processing_time),
                metadata={
                    "total_entities": len(entities),
                    "resolved_count": len(resolved_entities),
                    "resolution_methods": ["rule_based", "llm_based"]
                }
            )

        except Exception as e:
            self.logger.error(f"Entity resolution failed: {e}")
            processing_time = (datetime.now() - start_time).total_seconds() * 1000

            return EntityResolutionResult(
                resolved_entities=[],
                unresolved_entities=entities,
                resolution_confidence=0.0,
                processing_time_ms=int(processing_time),
                metadata={"error": str(e)}
            )

    async def _rule_based_resolution(self, entity: str) -> Optional[ResolvedEntity]:
        """Resolve entity using rule-based patterns."""

        entity_clean = entity.strip()

        # Check gene patterns
        for pattern in self.GENE_PATTERNS:
            if re.match(pattern.pattern, entity_clean):
                return await self._resolve_gene_entity(entity_clean, pattern)

        # Check variant patterns
        for pattern in self.VARIANT_PATTERNS:
            if re.match(pattern.pattern, entity_clean):
                return await self._resolve_variant_entity(entity_clean, pattern)

        # Check disease patterns
        for pattern in self.DISEASE_PATTERNS:
            if re.match(pattern.pattern, entity_clean):
                return await self._resolve_disease_entity(entity_clean, pattern)

        # Check known gene symbols
        if entity_clean.upper() in self.gene_symbols:
            return ResolvedEntity(
                original_name=entity,
                standardized_name=entity_clean.upper(),
                entity_type="gene",
                confidence=0.9,
                identifiers=[
                    EntityIdentifier(
                        database="HGNC",
                        identifier=entity_clean.upper(),
                        confidence=0.9
                    )
                ],
                synonyms=self.gene_symbols[entity_clean.upper()].get("synonyms", [])
            )

        return None

    async def _resolve_gene_entity(self, entity: str, pattern: EntityPattern) -> ResolvedEntity:
        """Resolve a gene entity."""

        identifiers = []
        synonyms = []

        # Add pattern-specific identifiers
        if pattern.name == "HGNC_Symbol":
            identifiers.append(EntityIdentifier(
                database="HGNC",
                identifier=entity.upper(),
                url=f"https://www.genenames.org/data/gene-symbol-report/#!/hgnc_id/{entity.upper()}",
                confidence=pattern.confidence
            ))
        elif pattern.name == "Ensembl_Gene":
            identifiers.append(EntityIdentifier(
                database="Ensembl",
                identifier=entity,
                url=f"https://www.ensembl.org/Homo_sapiens/Gene/Summary?g={entity}",
                confidence=pattern.confidence
            ))

        # Try to find additional identifiers
        gene_info = self.gene_symbols.get(entity.upper(), {})
        if gene_info:
            synonyms = gene_info.get("synonyms", [])
            if "entrez_id" in gene_info:
                identifiers.append(EntityIdentifier(
                    database="Entrez",
                    identifier=gene_info["entrez_id"],
                    url=f"https://www.ncbi.nlm.nih.gov/gene/{gene_info['entrez_id']}",
                    confidence=0.95
                ))

        return ResolvedEntity(
            original_name=entity,
            standardized_name=entity.upper(),
            entity_type="gene",
            confidence=pattern.confidence,
            identifiers=identifiers,
            synonyms=synonyms,
            organism="Homo sapiens"  # Default to human
        )

    async def _resolve_variant_entity(self, entity: str, pattern: EntityPattern) -> ResolvedEntity:
        """Resolve a variant entity."""

        identifiers = []

        if pattern.name == "rs_ID":
            identifiers.append(EntityIdentifier(
                database="dbSNP",
                identifier=entity,
                url=f"https://www.ncbi.nlm.nih.gov/snp/{entity}",
                confidence=pattern.confidence
            ))

        return ResolvedEntity(
            original_name=entity,
            standardized_name=entity,
            entity_type="variant",
            confidence=pattern.confidence,
            identifiers=identifiers,
            organism="Homo sapiens"
        )

    async def _resolve_disease_entity(self, entity: str, pattern: EntityPattern) -> ResolvedEntity:
        """Resolve a disease entity."""

        identifiers = []

        if pattern.name == "OMIM_ID":
            identifiers.append(EntityIdentifier(
                database="OMIM",
                identifier=entity,
                url=f"https://www.omim.org/entry/{entity}",
                confidence=pattern.confidence
            ))

        return ResolvedEntity(
            original_name=entity,
            standardized_name=entity,
            entity_type="disease",
            confidence=pattern.confidence,
            identifiers=identifiers
        )

    async def _llm_resolution(self, entity: str) -> Optional[ResolvedEntity]:
        """Resolve entity using LLM."""

        try:
            prompt = AGENT_PROMPTS["entity_resolution"].format(entities=[entity])

            response = await self.client.chat.completions.create(
                model=self.config.model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.1,  # Low temperature for consistent resolution
                max_tokens=1000,
                response_format={"type": "json_object"}
            )

            result = json.loads(response.choices[0].message.content)

            # Parse LLM response
            if "entities" in result and result["entities"]:
                entity_data = result["entities"][0]  # Take first resolved entity

                # Convert to ResolvedEntity
                identifiers = []
                if "identifiers" in entity_data:
                    for db, id_val in entity_data["identifiers"].items():
                        identifiers.append(EntityIdentifier(
                            database=db,
                            identifier=id_val,
                            confidence=0.8
                        ))

                return ResolvedEntity(
                    original_name=entity,
                    standardized_name=entity_data.get("standardized_name", entity),
                    entity_type=entity_data.get("entity_type", "unknown"),
                    confidence=entity_data.get("confidence", 0.7),
                    identifiers=identifiers,
                    synonyms=entity_data.get("synonyms", []),
                    description=entity_data.get("description"),
                    organism=entity_data.get("organism")
                )

        except Exception as e:
            self.logger.warning(f"LLM entity resolution failed for {entity}: {e}")

        return None

    def _load_common_gene_symbols(self) -> Dict[str, Dict]:
        """Load common gene symbols and their mappings."""

        # In a real implementation, this would load from a database or file
        # For now, we'll use a small sample
        return {
            "BRCA1": {
                "entrez_id": "672",
                "synonyms": ["breast cancer 1", "BRCC1", "FANCS"],
                "description": "BRCA1 DNA repair associated"
            },
            "BRCA2": {
                "entrez_id": "675",
                "synonyms": ["breast cancer 2", "BRCC2", "FANCD1"],
                "description": "BRCA2 DNA repair associated"
            },
            "TP53": {
                "entrez_id": "7157",
                "synonyms": ["tumor protein p53", "P53", "TRP53"],
                "description": "tumor protein p53"
            },
            "EGFR": {
                "entrez_id": "1956",
                "synonyms": ["epidermal growth factor receptor", "ERBB1", "HER1"],
                "description": "epidermal growth factor receptor"
            },
            "KRAS": {
                "entrez_id": "3845",
                "synonyms": ["KRAS proto-oncogene", "K-RAS", "RASK2"],
                "description": "KRAS proto-oncogene, GTPase"
            }
        }

    def _load_common_disease_terms(self) -> Dict[str, Dict]:
        """Load common disease terms and their mappings."""

        return {
            "breast cancer": {
                "omim_id": "114480",
                "synonyms": ["mammary carcinoma", "breast carcinoma"],
                "description": "malignant neoplasm of breast"
            },
            "alzheimer disease": {
                "omim_id": "104300",
                "synonyms": ["alzheimer's disease", "AD"],
                "description": "neurodegenerative disease"
            }
        }

    async def standardize_gene_symbol(self, gene_symbol: str) -> Optional[str]:
        """Standardize a gene symbol to HGNC format."""

        # Clean the input
        clean_symbol = re.sub(r'[^\w]', '', gene_symbol).upper()

        # Check if it's already a known symbol
        if clean_symbol in self.gene_symbols:
            return clean_symbol

        # Check synonyms
        for standard_symbol, info in self.gene_symbols.items():
            if clean_symbol in [s.upper() for s in info.get("synonyms", [])]:
                return standard_symbol

        # Use LLM as fallback
        try:
            resolved = await self._llm_resolution(gene_symbol)
            if resolved and resolved.entity_type == "gene":
                return resolved.standardized_name
        except Exception:
            pass

        return None

    async def find_entity_relationships(
        self,
        entities: List[ResolvedEntity]
    ) -> Dict[str, List[str]]:
        """Find relationships between resolved entities."""

        relationships = {}

        # Group entities by type
        genes = [e for e in entities if e.entity_type == "gene"]
        variants = [e for e in entities if e.entity_type == "variant"]
        diseases = [e for e in entities if e.entity_type == "disease"]

        # Gene-disease relationships
        if genes and diseases:
            for gene in genes:
                for disease in diseases:
                    key = f"{gene.standardized_name}-{disease.standardized_name}"
                    relationships[key] = ["potential_association"]

        # Gene-variant relationships
        if genes and variants:
            for gene in genes:
                for variant in variants:
                    key = f"{gene.standardized_name}-{variant.standardized_name}"
                    relationships[key] = ["variant_in_gene"]

        return relationships


# Convenience function for entity resolution
async def resolve_biomedical_entities(
    entities: List[str],
    config: Optional[AgentConfig] = None
) -> EntityResolutionResult:
    """
    Convenience function to resolve biomedical entities.

    Args:
        entities: List of entity names to resolve
        config: Optional agent configuration

    Returns:
        EntityResolutionResult with resolved entities
    """
    if config is None:
        config = AgentConfig.from_env()

    resolver = EntityResolver(config)
    return await resolver.resolve_entities(entities)
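The resolver above tries cheap regex rules and a small built-in symbol table before paying for an LLM call. A hedged usage sketch of the module's public entry point as defined in this diff; it assumes the package is importable from its src layout and that OPENAI_API_KEY is set (AgentConfig.from_env() needs it, and any name the rules miss falls through to the LLM):

import asyncio

from gquery.agents.entity_resolver import resolve_biomedical_entities

async def main() -> None:
    # "BRCA1" should resolve via the HGNC_Symbol regex / built-in symbol table,
    # and "rs7412" via the rs_ID regex; a free-text phrase like the third entry
    # is expected to fall through to the LLM fallback.
    result = await resolve_biomedical_entities(["BRCA1", "rs7412", "tumor protein p53"])
    for entity in result.resolved_entities:
        print(entity.standardized_name, entity.entity_type, entity.confidence)
    print("unresolved:", result.unresolved_entities)

asyncio.run(main())

Ordering matters in the rule-based layer: the broad HGNC_Symbol regex is checked before the symbol table, so uppercase tokens resolve as genes even when they are not in the sample dictionary, at the lower 0.8 confidence.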
gquery/src/gquery/agents/orchestrator.py
ADDED
@@ -0,0 +1,627 @@
"""
Intelligent Agent Orchestration Module

Implements the core orchestration logic using LangGraph for dynamic workflow management.
This implements Feature 2.1 from the PRD.
"""

import asyncio
import logging
from typing import Dict, List, Optional, Any, TypedDict
from datetime import datetime
from dataclasses import dataclass

from langgraph.graph import StateGraph, END
from langchain_openai import ChatOpenAI
from langchain.schema import BaseMessage, HumanMessage, AIMessage

from ..tools.datasets_client import DatasetsClient
from ..tools.pmc_client import PMCClient
from ..tools.clinvar_client import ClinVarClient
from .query_analyzer import QueryAnalyzer, QueryAnalysis
from .config import AgentConfig
from .synthesis import DataSynthesizer
from .biomedical_guardrails import BiomedicalGuardrails, GuardrailResult, QueryDomain


logger = logging.getLogger(__name__)


class AgentState(TypedDict):
    """State object for the LangGraph workflow."""
    query: str
    guardrail_result: Optional[GuardrailResult]
    analysis: Optional[QueryAnalysis]
    datasets_results: Optional[Dict]
    pmc_results: Optional[Dict]
    clinvar_results: Optional[Dict]
    synthesis: Optional[Dict]
    errors: List[str]
    metadata: Dict[str, Any]


@dataclass
class OrchestrationResult:
    """Result from the orchestration process."""
    query: str
    guardrail_result: GuardrailResult
    analysis: Optional[QueryAnalysis]
    database_results: Dict[str, Any]
    synthesis: Optional[Dict]
    execution_time_ms: int
    success: bool
    errors: List[str]
    metadata: Dict[str, Any]


class GQueryOrchestrator:
    """Main orchestrator that coordinates AI agents and database queries."""

    def __init__(self, config: Optional[AgentConfig] = None):
        self.config = config or AgentConfig.from_env()
        self.logger = logging.getLogger(__name__)

        # Initialize biomedical guardrails (HIGHEST PRIORITY per manager feedback)
        self.guardrails = BiomedicalGuardrails()

        # Initialize components
        self.query_analyzer = QueryAnalyzer(self.config)
        self.synthesizer = DataSynthesizer(self.config)
        self.llm = ChatOpenAI(
            openai_api_key=self.config.openai_api_key,
            model_name=self.config.model,
            temperature=self.config.temperature
        )

        # Initialize database clients
        self.datasets_client = DatasetsClient()
        self.pmc_client = PMCClient()
        self.clinvar_client = ClinVarClient()

        # Build the workflow graph
        self.workflow = self._build_workflow()

    def _build_workflow(self) -> StateGraph:
        """Build the LangGraph workflow for orchestration."""

        # Define the workflow graph
        workflow = StateGraph(AgentState)

        # Add nodes
        workflow.add_node("validate_guardrails", self._validate_guardrails_node)
        workflow.add_node("analyze_query", self._analyze_query_node)
        workflow.add_node("plan_execution", self._plan_execution_node)
        workflow.add_node("query_datasets", self._query_datasets_node)
        workflow.add_node("query_pmc", self._query_pmc_node)
        workflow.add_node("query_clinvar", self._query_clinvar_node)
        workflow.add_node("synthesize_results", self._synthesize_results_node)
        workflow.add_node("handle_errors", self._handle_errors_node)

        # Define the flow - START WITH GUARDRAILS VALIDATION
        workflow.set_entry_point("validate_guardrails")

        # From validate_guardrails, either continue to analysis or end with rejection
        workflow.add_conditional_edges(
            "validate_guardrails",
            self._should_continue_after_guardrails,
            {
                "continue": "analyze_query",
                "reject": END
            }
        )

        # From analyze_query, go to plan_execution or handle_errors
        workflow.add_conditional_edges(
            "analyze_query",
            self._should_continue_after_analysis,
            {
                "continue": "plan_execution",
                "error": "handle_errors"
            }
        )

        # From plan_execution, branch to database queries
        workflow.add_conditional_edges(
            "plan_execution",
            self._determine_database_queries,
            {
                "datasets_only": "query_datasets",
                "pmc_only": "query_pmc",
                "clinvar_only": "query_clinvar",
                "multiple": "query_datasets",  # Start with datasets for multiple
                "error": "handle_errors"
            }
        )

        # Database query flows
        workflow.add_conditional_edges(
            "query_datasets",
            self._continue_after_datasets,
            {
                "query_pmc": "query_pmc",
                "query_clinvar": "query_clinvar",
                "synthesize": "synthesize_results",
                "end": END
            }
        )

        workflow.add_conditional_edges(
            "query_pmc",
            self._continue_after_pmc,
            {
                "query_clinvar": "query_clinvar",
                "synthesize": "synthesize_results",
                "end": END
            }
        )

        workflow.add_conditional_edges(
            "query_clinvar",
            self._continue_after_clinvar,
            {
                "synthesize": "synthesize_results",
                "end": END
            }
        )

        # Final nodes
        workflow.add_edge("synthesize_results", END)
        workflow.add_edge("handle_errors", END)

        return workflow.compile()

    async def orchestrate(self, query: str) -> OrchestrationResult:
        """
        Main orchestration method that processes a user query.

        Args:
            query: The user's natural language query

        Returns:
            OrchestrationResult with all processing results
        """
        start_time = datetime.now()

        try:
            # Initialize state
            initial_state: AgentState = {
                "query": query,
                "guardrail_result": None,
                "analysis": None,
                "datasets_results": None,
                "pmc_results": None,
                "clinvar_results": None,
                "synthesis": None,
                "errors": [],
                "metadata": {
                    "start_time": start_time.isoformat(),
                    "config": {
                        "model": self.config.model,
                        "temperature": self.config.temperature
                    }
                }
            }

            # Execute the workflow
            final_state = await self.workflow.ainvoke(initial_state)

            # Calculate execution time
            execution_time = (datetime.now() - start_time).total_seconds() * 1000

            # Prepare results
            database_results = {
                "datasets": final_state.get("datasets_results"),
                "pmc": final_state.get("pmc_results"),
                "clinvar": final_state.get("clinvar_results")
            }

            # Filter out None results
            database_results = {k: v for k, v in database_results.items() if v is not None}

            return OrchestrationResult(
                query=query,
                guardrail_result=final_state.get("guardrail_result"),
                analysis=final_state.get("analysis"),
                database_results=database_results,
                synthesis=final_state.get("synthesis"),
                execution_time_ms=int(execution_time),
                success=len(final_state["errors"]) == 0,
                errors=final_state["errors"],
                metadata=final_state["metadata"]
            )

        except Exception as e:
            execution_time = (datetime.now() - start_time).total_seconds() * 1000
            self.logger.error(f"Orchestration failed: {e}")

            return OrchestrationResult(
                query=query,
                guardrail_result=None,
                analysis=None,
                database_results={},
                synthesis=None,
                execution_time_ms=int(execution_time),
                success=False,
                errors=[str(e)],
                metadata={"error": "orchestration_failed"}
            )

    # Workflow node implementations

    async def _validate_guardrails_node(self, state: AgentState) -> AgentState:
        """
        First step: Validate that the query is within biomedical domain.

        This is the HIGHEST PRIORITY feature based on manager feedback:
        "TRUST IS THE MOST IMPORTANT THING"
        """
        try:
            guardrail_result = self.guardrails.validate_query(state["query"])
            state["guardrail_result"] = guardrail_result

            # Log the validation result
            self.logger.info(
                f"Guardrail validation: domain={guardrail_result.domain.value}, "
                f"valid={guardrail_result.is_valid}, confidence={guardrail_result.confidence:.2f}"
            )

            # Add guardrail metadata
            state["metadata"]["guardrail_validation"] = {
                "domain": guardrail_result.domain.value,
                "confidence": guardrail_result.confidence,
                "biomedical_score": guardrail_result.biomedical_score,
                "non_biomedical_score": guardrail_result.non_biomedical_score,
                "processing_time_ms": guardrail_result.processing_time_ms,
                "timestamp": datetime.now().isoformat()
            }

            # If not valid, add the rejection message as an error for proper handling
            if not guardrail_result.is_valid:
                state["errors"].append(f"GUARDRAIL_REJECTION: {guardrail_result.rejection_message}")
                self.logger.warning(f"Query rejected by guardrails: {state['query']}")

        except Exception as e:
            error_msg = f"Guardrail validation failed: {e}"
            state["errors"].append(error_msg)
            self.logger.error(error_msg)
            # Default to rejection on error for safety
            from .biomedical_guardrails import GuardrailResult, QueryDomain
            state["guardrail_result"] = GuardrailResult(
                is_valid=False,
                domain=QueryDomain.NON_BIOMEDICAL,
                confidence=1.0,
                rejection_message="Sorry, there was an issue validating your query. Please try again with a biomedical question."
            )

        return state

    async def _analyze_query_node(self, state: AgentState) -> AgentState:
        """Analyze the user query."""
        try:
            analysis = await self.query_analyzer.analyze_query(state["query"])
            state["analysis"] = analysis
            self.logger.info(f"Query analyzed: {analysis.query_type.value}")
        except Exception as e:
            state["errors"].append(f"Query analysis failed: {e}")
            self.logger.error(f"Query analysis failed: {e}")

        return state

    async def _plan_execution_node(self, state: AgentState) -> AgentState:
        """Plan the execution strategy based on analysis."""
        if not state["analysis"]:
            state["errors"].append("No analysis available for planning")
            return state

        analysis = state["analysis"]

        # Add execution plan to metadata
        state["metadata"]["execution_plan"] = {
            "databases": analysis.databases_needed,
            "complexity": analysis.complexity,
            "estimated_time": len(analysis.databases_needed) * 2000  # ms
        }

        self.logger.info(f"Execution planned for databases: {analysis.databases_needed}")
        return state

    async def _query_datasets_node(self, state: AgentState) -> AgentState:
        """Query the NCBI Datasets database."""
|
331 |
+
# Enhanced logic: Query datasets if explicitly needed OR if this is a comprehensive biomedical query
|
332 |
+
should_query = (
|
333 |
+
state["analysis"] and "datasets" in state["analysis"].databases_needed
|
334 |
+
) or (
|
335 |
+
# Fallback: Query for any gene-related query
|
336 |
+
state["analysis"] and any(e.entity_type == "gene" for e in state["analysis"].entities)
|
337 |
+
)
|
338 |
+
|
339 |
+
if not should_query:
|
340 |
+
self.logger.info("Skipping Datasets query - no genes found or not requested")
|
341 |
+
return state
|
342 |
+
|
343 |
+
try:
|
344 |
+
# Extract gene entities for comprehensive datasets query
|
345 |
+
gene_entities = [e for e in state["analysis"].entities if e.entity_type == "gene"]
|
346 |
+
|
347 |
+
if gene_entities:
|
348 |
+
# Use enhanced comprehensive gene data retrieval
|
349 |
+
gene_symbol = gene_entities[0].name
|
350 |
+
try:
|
351 |
+
# Get comprehensive gene data including expression, proteins, and datasets
|
352 |
+
result = await self.datasets_client.get_comprehensive_gene_data(
|
353 |
+
gene_symbol=gene_symbol,
|
354 |
+
taxon_id=9606, # Human
|
355 |
+
include_expression=True,
|
356 |
+
include_proteins=True,
|
357 |
+
include_datasets=True
|
358 |
+
)
|
359 |
+
|
360 |
+
if result and "error" not in result:
|
361 |
+
state["datasets_results"] = {
|
362 |
+
"comprehensive_data": result,
|
363 |
+
"gene_symbol": gene_symbol,
|
364 |
+
"query_type": "comprehensive_gene_analysis",
|
365 |
+
"data_types": result.get("summary", {}).get("data_types_available", []),
|
366 |
+
"timestamp": datetime.now().isoformat()
|
367 |
+
}
|
368 |
+
self.logger.info(f"Comprehensive datasets query completed for gene: {gene_symbol}")
|
369 |
+
self.logger.info(f"Data types retrieved: {result.get('summary', {}).get('data_types_available', [])}")
|
370 |
+
else:
|
371 |
+
# Fallback to basic gene lookup
|
372 |
+
basic_result = await self.datasets_client.get_gene_by_symbol(gene_symbol)
|
373 |
+
if basic_result:
|
374 |
+
state["datasets_results"] = {
|
375 |
+
"gene_info": basic_result.model_dump() if hasattr(basic_result, 'model_dump') else basic_result,
|
376 |
+
"gene_symbol": gene_symbol,
|
377 |
+
"query_type": "basic_gene_lookup",
|
378 |
+
"timestamp": datetime.now().isoformat()
|
379 |
+
}
|
380 |
+
self.logger.info(f"Basic datasets query completed for gene: {gene_symbol}")
|
381 |
+
else:
|
382 |
+
state["datasets_results"] = {"message": f"No gene information found for {gene_symbol}"}
|
383 |
+
|
384 |
+
except Exception as e:
|
385 |
+
self.logger.warning(f"Comprehensive datasets query failed for {gene_symbol}: {e}")
|
386 |
+
# Try basic fallback
|
387 |
+
try:
|
388 |
+
basic_result = await self.datasets_client.get_gene_by_symbol(gene_symbol)
|
389 |
+
if basic_result:
|
390 |
+
state["datasets_results"] = {
|
391 |
+
"gene_info": basic_result.model_dump() if hasattr(basic_result, 'model_dump') else basic_result,
|
392 |
+
"gene_symbol": gene_symbol,
|
393 |
+
"query_type": "basic_gene_lookup",
|
394 |
+
"timestamp": datetime.now().isoformat()
|
395 |
+
}
|
396 |
+
self.logger.info(f"Fallback basic datasets query completed for gene: {gene_symbol}")
|
397 |
+
else:
|
398 |
+
state["datasets_results"] = {"message": f"Gene lookup failed for {gene_symbol}: {str(e)}"}
|
399 |
+
except Exception as fallback_error:
|
400 |
+
state["datasets_results"] = {"message": f"Gene lookup failed for {gene_symbol}: {str(fallback_error)}"}
|
401 |
+
else:
|
402 |
+
state["datasets_results"] = {"message": "No gene entities found for datasets query"}
|
403 |
+
|
404 |
+
except Exception as e:
|
405 |
+
error_msg = f"Datasets query failed: {e}"
|
406 |
+
state["errors"].append(error_msg)
|
407 |
+
self.logger.error(error_msg)
|
408 |
+
|
409 |
+
return state
|
410 |
+
|
411 |
+
async def _query_pmc_node(self, state: AgentState) -> AgentState:
|
412 |
+
"""Query the PMC literature database."""
|
413 |
+
# Enhanced logic: Query PMC if explicitly needed OR if this is any biomedical query
|
414 |
+
should_query = (
|
415 |
+
state["analysis"] and "pmc" in state["analysis"].databases_needed
|
416 |
+
) or (
|
417 |
+
# Fallback: Query for any biomedical entity (genes, diseases, variants)
|
418 |
+
state["analysis"] and any(
|
419 |
+
len(getattr(state["analysis"], attr, [])) > 0
|
420 |
+
for attr in ["entities"]
|
421 |
+
if hasattr(state["analysis"], attr)
|
422 |
+
)
|
423 |
+
)
|
424 |
+
|
425 |
+
if not should_query:
|
426 |
+
self.logger.info("Skipping PMC query - no biomedical entities found")
|
427 |
+
return state
|
428 |
+
|
429 |
+
try:
|
430 |
+
# Create search query from entities
|
431 |
+
entities = [e.name for e in state["analysis"].entities]
|
432 |
+
search_query = " ".join(entities)
|
433 |
+
|
434 |
+
if search_query:
|
435 |
+
async with self.pmc_client:
|
436 |
+
result = await self.pmc_client.search_articles(search_query, max_results=10)
|
437 |
+
state["pmc_results"] = {
|
438 |
+
"articles": result.results,
|
439 |
+
"search_query": search_query,
|
440 |
+
"total_count": result.total_count,
|
441 |
+
"timestamp": datetime.now().isoformat()
|
442 |
+
}
|
443 |
+
self.logger.info(f"PMC query completed for: {search_query}")
|
444 |
+
else:
|
445 |
+
state["pmc_results"] = {"message": "No search terms found for PMC query"}
|
446 |
+
|
447 |
+
except Exception as e:
|
448 |
+
error_msg = f"PMC query failed: {e}"
|
449 |
+
state["errors"].append(error_msg)
|
450 |
+
self.logger.error(error_msg)
|
451 |
+
|
452 |
+
return state
|
453 |
+
|
454 |
+
async def _query_clinvar_node(self, state: AgentState) -> AgentState:
|
455 |
+
"""Query the ClinVar database."""
|
456 |
+
# Enhanced logic: Query ClinVar if explicitly needed OR if genes/variants are mentioned
|
457 |
+
should_query = (
|
458 |
+
state["analysis"] and "clinvar" in state["analysis"].databases_needed
|
459 |
+
) or (
|
460 |
+
# Fallback: Query for any gene or variant entity
|
461 |
+
state["analysis"] and any(
|
462 |
+
e.entity_type in ["gene", "variant"] for e in state["analysis"].entities
|
463 |
+
)
|
464 |
+
)
|
465 |
+
|
466 |
+
if not should_query:
|
467 |
+
self.logger.info("Skipping ClinVar query - no genes or variants found")
|
468 |
+
return state
|
469 |
+
|
470 |
+
try:
|
471 |
+
# Extract gene entities for ClinVar query
|
472 |
+
gene_entities = [e for e in state["analysis"].entities if e.entity_type == "gene"]
|
473 |
+
|
474 |
+
if gene_entities:
|
475 |
+
gene_symbol = gene_entities[0].name
|
476 |
+
result = await self.clinvar_client.search_variants_by_gene(gene_symbol, max_results=20)
|
477 |
+
state["clinvar_results"] = {
|
478 |
+
"variants": result.results, # Extract the actual variants from the response
|
479 |
+
"gene": gene_symbol,
|
480 |
+
"total_count": result.total_count,
|
481 |
+
"query": result.query,
|
482 |
+
"timestamp": datetime.now().isoformat()
|
483 |
+
}
|
484 |
+
self.logger.info(f"ClinVar query completed for gene: {gene_symbol}, found {len(result.results)} variants")
|
485 |
+
else:
|
486 |
+
state["clinvar_results"] = {"message": "No gene entities found for ClinVar query"}
|
487 |
+
|
488 |
+
except Exception as e:
|
489 |
+
error_msg = f"ClinVar query failed: {e}"
|
490 |
+
state["errors"].append(error_msg)
|
491 |
+
self.logger.error(error_msg)
|
492 |
+
|
493 |
+
return state
|
494 |
+
|
495 |
+
async def _synthesize_results_node(self, state: AgentState) -> AgentState:
|
496 |
+
"""Synthesize results from all databases."""
|
497 |
+
try:
|
498 |
+
# Check if we have any results to synthesize
|
499 |
+
has_results = any([
|
500 |
+
state.get("datasets_results"),
|
501 |
+
state.get("pmc_results"),
|
502 |
+
state.get("clinvar_results")
|
503 |
+
])
|
504 |
+
|
505 |
+
if has_results:
|
506 |
+
synthesis = await self.synthesizer.synthesize_data(
|
507 |
+
query=state["query"],
|
508 |
+
datasets_data=state.get("datasets_results"),
|
509 |
+
pmc_data=state.get("pmc_results"),
|
510 |
+
clinvar_data=state.get("clinvar_results")
|
511 |
+
)
|
512 |
+
# Convert SynthesisResult to dict for state storage
|
513 |
+
state["synthesis"] = synthesis.model_dump() if hasattr(synthesis, 'model_dump') else synthesis.__dict__
|
514 |
+
self.logger.info("Data synthesis completed")
|
515 |
+
else:
|
516 |
+
state["synthesis"] = {"message": "No data available for synthesis"}
|
517 |
+
|
518 |
+
except Exception as e:
|
519 |
+
error_msg = f"Synthesis failed: {e}"
|
520 |
+
state["errors"].append(error_msg)
|
521 |
+
self.logger.error(error_msg)
|
522 |
+
|
523 |
+
return state
|
524 |
+
|
525 |
+
async def _handle_errors_node(self, state: AgentState) -> AgentState:
|
526 |
+
"""Handle errors and attempt recovery."""
|
527 |
+
if state["errors"]:
|
528 |
+
self.logger.warning(f"Handling {len(state['errors'])} errors")
|
529 |
+
|
530 |
+
# Add error recovery metadata
|
531 |
+
state["metadata"]["error_recovery"] = {
|
532 |
+
"attempted": True,
|
533 |
+
"error_count": len(state["errors"]),
|
534 |
+
"timestamp": datetime.now().isoformat()
|
535 |
+
}
|
536 |
+
|
537 |
+
return state
|
538 |
+
|
539 |
+
# Conditional edge functions
|
540 |
+
|
541 |
+
def _should_continue_after_guardrails(self, state: AgentState) -> str:
|
542 |
+
"""Determine if we should continue after guardrail validation."""
|
543 |
+
guardrail_result = state.get("guardrail_result")
|
544 |
+
if guardrail_result and guardrail_result.is_valid:
|
545 |
+
return "continue"
|
546 |
+
return "reject"
|
547 |
+
|
548 |
+
def _should_continue_after_analysis(self, state: AgentState) -> str:
|
549 |
+
"""Determine if we should continue after analysis."""
|
550 |
+
if state["analysis"] and state["analysis"].confidence > 0.3:
|
551 |
+
return "continue"
|
552 |
+
return "error"
|
553 |
+
|
554 |
+
def _determine_database_queries(self, state: AgentState) -> str:
|
555 |
+
"""Determine which databases to query based on analysis."""
|
556 |
+
if not state["analysis"]:
|
557 |
+
return "error"
|
558 |
+
|
559 |
+
databases = state["analysis"].databases_needed
|
560 |
+
|
561 |
+
if len(databases) == 1:
|
562 |
+
if "datasets" in databases:
|
563 |
+
return "datasets_only"
|
564 |
+
elif "pmc" in databases:
|
565 |
+
return "pmc_only"
|
566 |
+
elif "clinvar" in databases:
|
567 |
+
return "clinvar_only"
|
568 |
+
|
569 |
+
return "multiple"
|
570 |
+
|
571 |
+
def _continue_after_datasets(self, state: AgentState) -> str:
|
572 |
+
"""Determine next step after datasets query."""
|
573 |
+
if not state["analysis"]:
|
574 |
+
return "end"
|
575 |
+
|
576 |
+
databases = state["analysis"].databases_needed
|
577 |
+
|
578 |
+
if "pmc" in databases and not state.get("pmc_results"):
|
579 |
+
return "query_pmc"
|
580 |
+
elif "clinvar" in databases and not state.get("clinvar_results"):
|
581 |
+
return "query_clinvar"
|
582 |
+
elif len(databases) > 1:
|
583 |
+
return "synthesize"
|
584 |
+
|
585 |
+
return "end"
|
586 |
+
|
587 |
+
def _continue_after_pmc(self, state: AgentState) -> str:
|
588 |
+
"""Determine next step after PMC query."""
|
589 |
+
if not state["analysis"]:
|
590 |
+
return "end"
|
591 |
+
|
592 |
+
databases = state["analysis"].databases_needed
|
593 |
+
|
594 |
+
if "clinvar" in databases and not state.get("clinvar_results"):
|
595 |
+
return "query_clinvar"
|
596 |
+
elif len(databases) > 1:
|
597 |
+
return "synthesize"
|
598 |
+
|
599 |
+
return "end"
|
600 |
+
|
601 |
+
def _continue_after_clinvar(self, state: AgentState) -> str:
|
602 |
+
"""Determine next step after ClinVar query."""
|
603 |
+
if not state["analysis"]:
|
604 |
+
return "end"
|
605 |
+
|
606 |
+
databases = state["analysis"].databases_needed
|
607 |
+
|
608 |
+
if len(databases) > 1:
|
609 |
+
return "synthesize"
|
610 |
+
|
611 |
+
return "end"
|
612 |
+
|
613 |
+
|
614 |
+
# Convenience function for easy orchestration
|
615 |
+
async def orchestrate_query(query: str, config: Optional[AgentConfig] = None) -> OrchestrationResult:
|
616 |
+
"""
|
617 |
+
Convenience function to orchestrate a query.
|
618 |
+
|
619 |
+
Args:
|
620 |
+
query: The user's query to process
|
621 |
+
config: Optional agent configuration
|
622 |
+
|
623 |
+
Returns:
|
624 |
+
OrchestrationResult with all processing results
|
625 |
+
"""
|
626 |
+
orchestrator = GQueryOrchestrator(config)
|
627 |
+
return await orchestrator.orchestrate(query)
|
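For orientation, a minimal usage sketch of the orchestrator module above; it is not part of the committed diff. It assumes OPENAI_API_KEY (and any NCBI settings the clients need) are already configured so a default AgentConfig can be built, and it only reads fields that OrchestrationResult demonstrably carries (success, errors, synthesis, execution_time_ms).

# Usage sketch (not part of this commit): drive the orchestrator end to end.
# Assumes OPENAI_API_KEY is set so the default AgentConfig can be constructed.
import asyncio

from gquery.agents.orchestrator import orchestrate_query


async def main() -> None:
    result = await orchestrate_query("What pathogenic variants are reported for BRCA1?")
    if result.success:
        print(result.synthesis)          # synthesized cross-database answer
    else:
        print("Errors:", result.errors)  # guardrail rejections also land here
    print(f"Completed in {result.execution_time_ms} ms")


asyncio.run(main())
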
gquery/src/gquery/agents/query_analyzer.py
ADDED
@@ -0,0 +1,289 @@
"""
Query Analysis and Intent Detection Module

Analyzes user queries to determine intent, extract entities, and plan database interactions.
This implements Feature 2.3 from the PRD.
"""

import json
import logging
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from datetime import datetime

from openai import AsyncOpenAI
from pydantic import BaseModel, Field

from .config import AgentConfig, QueryType, AGENT_PROMPTS


logger = logging.getLogger(__name__)


class QueryEntity(BaseModel):
    """Represents an extracted biomedical entity."""
    name: str
    entity_type: str  # gene, variant, disease, organism, other
    confidence: float = Field(ge=0.0, le=1.0)
    standardized_name: Optional[str] = None
    identifiers: Dict[str, str] = Field(default_factory=dict)
    synonyms: List[str] = Field(default_factory=list)


class QueryAnalysis(BaseModel):
    """Results of query analysis."""
    query_type: QueryType
    entities: List[QueryEntity]
    databases_needed: List[str]
    intent: str
    complexity: str
    confidence: float = Field(ge=0.0, le=1.0)
    analysis_timestamp: datetime = Field(default_factory=datetime.now)
    processing_time_ms: Optional[int] = None


@dataclass
class DatabasePlan:
    """Plan for querying databases."""
    database: str
    priority: str
    estimated_cost: float
    expected_results: int
    query_params: Dict


class QueryAnalyzer:
    """Analyzes user queries and extracts intent and entities."""

    def __init__(self, config: AgentConfig):
        self.config = config
        self.client = AsyncOpenAI(api_key=config.openai_api_key)
        self.logger = logging.getLogger(__name__)

    async def analyze_query(self, query: str) -> QueryAnalysis:
        """
        Analyze a user query to determine intent and extract entities.

        Args:
            query: The user's natural language query

        Returns:
            QueryAnalysis object with extracted information
        """
        start_time = datetime.now()

        try:
            # Use LLM to analyze the query
            analysis_result = await self._llm_analyze_query(query)

            # Validate and structure the results
            analysis = self._structure_analysis(analysis_result, query)

            # Calculate processing time
            processing_time = (datetime.now() - start_time).total_seconds() * 1000
            analysis.processing_time_ms = int(processing_time)

            self.logger.info(f"Query analyzed successfully in {processing_time:.2f}ms")
            return analysis

        except Exception as e:
            self.logger.error(f"Query analysis failed: {e}")
            # Return fallback analysis
            return self._create_fallback_analysis(query)

    async def _llm_analyze_query(self, query: str) -> Dict:
        """Use LLM to analyze the query."""
        prompt = AGENT_PROMPTS["query_analysis"].format(query=query)

        response = await self.client.chat.completions.create(
            model=self.config.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
            response_format={"type": "json_object"}
        )

        return json.loads(response.choices[0].message.content)

    def _structure_analysis(self, llm_result: Dict, original_query: str) -> QueryAnalysis:
        """Structure the LLM results into a QueryAnalysis object."""

        # Extract entities
        entities = []
        if "entities" in llm_result:
            # Mapping from LLM JSON keys (plural) to entity types (singular)
            entity_type_mapping = {
                "genes": "gene",
                "variants": "variant",
                "diseases": "disease",
                "organisms": "organism",
                "other": "other"
            }

            for json_key, entity_list in llm_result["entities"].items():
                # Map plural JSON key to singular entity type
                entity_type = entity_type_mapping.get(json_key, json_key)

                for entity_name in entity_list:
                    if entity_name:  # Skip empty strings
                        entities.append(QueryEntity(
                            name=entity_name,
                            entity_type=entity_type,
                            confidence=llm_result.get("confidence", 0.8)
                        ))

        # Map query type
        query_type_str = llm_result.get("query_type", "gene_lookup")
        try:
            query_type = QueryType(query_type_str)
        except ValueError:
            query_type = QueryType.GENE_LOOKUP

        # Ensure comprehensive database selection
        databases_needed = llm_result.get("databases_needed", ["pmc", "clinvar", "datasets"])

        # If only one database is selected, add the others for comprehensive results
        if len(databases_needed) == 1:
            if "pmc" not in databases_needed:
                databases_needed.append("pmc")
            if "clinvar" not in databases_needed:
                databases_needed.append("clinvar")
            if "datasets" not in databases_needed:
                databases_needed.append("datasets")

        # Ensure at least PMC and one other database for most queries
        if len(databases_needed) < 2:
            databases_needed = ["pmc", "clinvar", "datasets"]

        return QueryAnalysis(
            query_type=query_type,
            entities=entities,
            databases_needed=databases_needed,
            intent=llm_result.get("intent", "Gene lookup"),
            complexity=llm_result.get("complexity", "simple"),
            confidence=llm_result.get("confidence", 0.8)
        )

    def _create_fallback_analysis(self, query: str) -> QueryAnalysis:
        """Create a basic analysis when the LLM fails."""
        # Simple keyword-based fallback
        entities = []
        databases_needed = ["datasets"]
        query_type = QueryType.GENE_LOOKUP

        # Basic gene detection
        gene_keywords = self._extract_potential_genes(query)
        for gene in gene_keywords:
            entities.append(QueryEntity(
                name=gene,
                entity_type="gene",
                confidence=0.5
            ))

        # Check for variant keywords
        if any(term in query.lower() for term in ["variant", "mutation", "snp", "rs"]):
            query_type = QueryType.VARIANT_ANALYSIS
            databases_needed = ["clinvar", "datasets", "pmc"]  # Include PMC for literature

        # Check for literature keywords - but also include other databases for comprehensive search
        elif any(term in query.lower() for term in ["research", "study", "paper", "literature", "findings", "role", "therapy", "treatment"]):
            query_type = QueryType.LITERATURE_SEARCH
            databases_needed = ["pmc", "clinvar", "datasets"]  # Include all databases for comprehensive analysis

        # For gene queries, include all databases by default
        elif gene_keywords:
            query_type = QueryType.GENE_LOOKUP
            databases_needed = ["datasets", "clinvar", "pmc"]  # All databases for comprehensive gene analysis

        return QueryAnalysis(
            query_type=query_type,
            entities=entities,
            databases_needed=databases_needed,
            intent="Automated fallback analysis",
            complexity="simple",
            confidence=0.3
        )

    def _extract_potential_genes(self, query: str) -> List[str]:
        """Extract potential gene names using simple heuristics."""
        import re

        # Look for capitalized words that could be gene symbols
        words = query.split()
        potential_genes = []

        for word in words:
            # Clean word
            clean_word = re.sub(r'[^\w]', '', word)

            # Gene symbol patterns
            if (len(clean_word) >= 2 and
                    clean_word.isupper() and
                    clean_word.isalpha()):
                potential_genes.append(clean_word)
            elif (len(clean_word) >= 3 and
                    clean_word[0].isupper() and
                    any(c.isupper() for c in clean_word[1:])):
                potential_genes.append(clean_word)

        return potential_genes

    def create_database_plan(self, analysis: QueryAnalysis) -> List[DatabasePlan]:
        """Create a plan for querying databases based on analysis."""
        from .config import DATABASE_PRIORITIES

        plans = []
        priorities = DATABASE_PRIORITIES.get(analysis.query_type, {})

        for db_name in analysis.databases_needed:
            priority = priorities.get(db_name, "medium")

            # Estimate costs and results based on complexity and entities
            entity_count = len(analysis.entities)
            complexity_multiplier = {
                "simple": 1.0,
                "moderate": 2.0,
                "complex": 4.0
            }.get(analysis.complexity, 1.0)

            estimated_cost = entity_count * complexity_multiplier
            expected_results = int(entity_count * 10 * complexity_multiplier)

            # Create query parameters
            query_params = {
                "entities": [e.name for e in analysis.entities],
                "entity_types": [e.entity_type for e in analysis.entities],
                "complexity": analysis.complexity
            }

            plans.append(DatabasePlan(
                database=db_name,
                priority=priority,
                estimated_cost=estimated_cost,
                expected_results=expected_results,
                query_params=query_params
            ))

        # Sort by priority (high first)
        priority_order = {"high": 0, "medium": 1, "low": 2}
        plans.sort(key=lambda p: priority_order.get(p.priority, 3))

        return plans


async def analyze_query_intent(query: str, config: Optional[AgentConfig] = None) -> QueryAnalysis:
    """
    Convenience function to analyze a query.

    Args:
        query: The user's query to analyze
        config: Optional agent configuration

    Returns:
        QueryAnalysis results
    """
    if config is None:
        config = AgentConfig.from_env()

    analyzer = QueryAnalyzer(config)
    return await analyzer.analyze_query(query)
gquery/src/gquery/agents/synthesis.py
ADDED
@@ -0,0 +1,429 @@
"""
Cross-Database Synthesis Engine

Synthesizes and correlates data from multiple biomedical databases.
This implements Feature 2.2 from the PRD.
"""

import json
import logging
from typing import Dict, List, Optional, Any
from datetime import datetime
from dataclasses import dataclass

from openai import AsyncOpenAI
from pydantic import BaseModel, Field

from .config import AgentConfig, AGENT_PROMPTS


logger = logging.getLogger(__name__)


class SynthesisInsight(BaseModel):
    """Represents a key insight from data synthesis."""
    type: str  # correlation, contradiction, gap, pattern
    description: str
    evidence: List[str]
    confidence: float = Field(ge=0.0, le=1.0)
    sources: List[str]


class SynthesisResult(BaseModel):
    """Results from cross-database synthesis."""
    executive_summary: str
    key_findings: List[str]
    insights: List[SynthesisInsight]
    correlations: Dict[str, List[str]]
    gaps_and_limitations: List[str]
    additional_resources: List[str]  # Changed from recommendations to additional_resources
    data_sources_used: List[str]
    source_urls: Dict[str, List[str]] = Field(default_factory=dict)  # Database -> list of URLs
    synthesis_timestamp: datetime = Field(default_factory=datetime.now)
    processing_time_ms: Optional[int] = None


@dataclass
class DataSource:
    """Represents a data source for synthesis."""
    name: str
    data: Dict[str, Any]
    quality_score: float
    record_count: int
    last_updated: Optional[datetime] = None


class DataSynthesizer:
    """Synthesizes data from multiple biomedical databases."""

    def __init__(self, config: AgentConfig):
        self.config = config
        self.client = AsyncOpenAI(api_key=config.openai_api_key)
        self.logger = logging.getLogger(__name__)

    async def synthesize_data(
        self,
        query: str,
        datasets_data: Optional[Dict] = None,
        pmc_data: Optional[Dict] = None,
        clinvar_data: Optional[Dict] = None
    ) -> SynthesisResult:
        """
        Synthesize data from multiple sources to answer a query.

        Args:
            query: Original user query
            datasets_data: Data from NCBI Datasets
            pmc_data: Data from PMC literature search
            clinvar_data: Data from ClinVar

        Returns:
            SynthesisResult with comprehensive analysis
        """
        start_time = datetime.now()

        try:
            # Prepare data sources
            data_sources = self._prepare_data_sources(
                datasets_data, pmc_data, clinvar_data
            )

            if not data_sources:
                return self._create_empty_synthesis(query)

            # Perform synthesis using LLM
            synthesis_result = await self._llm_synthesize(query, data_sources)

            # Structure the results
            structured_result = self._structure_synthesis_results(
                synthesis_result, data_sources
            )

            # Calculate processing time
            processing_time = (datetime.now() - start_time).total_seconds() * 1000
            structured_result.processing_time_ms = int(processing_time)

            self.logger.info(f"Data synthesis completed in {processing_time:.2f}ms")
            return structured_result

        except Exception as e:
            self.logger.error(f"Data synthesis failed: {e}")
            return self._create_error_synthesis(query, str(e))

    def _prepare_data_sources(
        self,
        datasets_data: Optional[Dict],
        pmc_data: Optional[Dict],
        clinvar_data: Optional[Dict]
    ) -> List[DataSource]:
        """Prepare and quality-check data sources."""
        sources = []

        # Process Datasets data
        if datasets_data and "gene_info" in datasets_data:
            gene_info = datasets_data["gene_info"]
            record_count = len(gene_info) if isinstance(gene_info, list) else 1
            sources.append(DataSource(
                name="NCBI Datasets",
                data=datasets_data,
                quality_score=0.9,  # High quality genomic data
                record_count=record_count
            ))

        # Process PMC data
        if pmc_data and "articles" in pmc_data:
            articles = pmc_data["articles"]
            record_count = len(articles) if isinstance(articles, list) else 0
            if record_count > 0:
                sources.append(DataSource(
                    name="PMC Literature",
                    data=pmc_data,
                    quality_score=0.8,  # Good quality literature
                    record_count=record_count
                ))

        # Process ClinVar data
        if clinvar_data and "variants" in clinvar_data:
            variants = clinvar_data["variants"]
            record_count = len(variants) if isinstance(variants, list) else 0
            if record_count > 0:
                sources.append(DataSource(
                    name="ClinVar",
                    data=clinvar_data,
                    quality_score=0.85,  # High quality clinical data
                    record_count=record_count
                ))

        return sources

    def _generate_source_urls(self, data_sources: List[DataSource]) -> Dict[str, List[str]]:
        """Generate actual URLs for source data."""
        source_urls = {}

        for source in data_sources:
            urls = []

            if source.name == "PMC Literature" and "articles" in source.data:
                articles = source.data["articles"]
                for article in articles[:10]:  # Limit to first 10
                    if hasattr(article, 'pmc_id') and article.pmc_id:
                        urls.append(f"https://www.ncbi.nlm.nih.gov/pmc/articles/{article.pmc_id}/")
                    elif hasattr(article, 'article') and hasattr(article.article, 'pmc_id') and article.article.pmc_id:
                        urls.append(f"https://www.ncbi.nlm.nih.gov/pmc/articles/{article.article.pmc_id}/")
                    elif isinstance(article, dict) and article.get('pmc_id'):
                        urls.append(f"https://www.ncbi.nlm.nih.gov/pmc/articles/{article['pmc_id']}/")

            elif source.name == "ClinVar" and "variants" in source.data:
                variants = source.data["variants"]
                for variant in variants[:10]:  # Limit to first 10
                    if hasattr(variant, 'variation_id') and variant.variation_id:
                        urls.append(f"https://www.ncbi.nlm.nih.gov/clinvar/variation/{variant.variation_id}/")
                    elif isinstance(variant, dict) and variant.get('variation_id'):
                        urls.append(f"https://www.ncbi.nlm.nih.gov/clinvar/variation/{variant['variation_id']}/")

            elif source.name == "NCBI Datasets" and "gene_info" in source.data:
                gene_info = source.data["gene_info"]
                if hasattr(gene_info, 'gene_id') and gene_info.gene_id:
                    urls.append(f"https://www.ncbi.nlm.nih.gov/gene/{gene_info.gene_id}")
                elif isinstance(gene_info, dict) and gene_info.get('gene_id'):
                    urls.append(f"https://www.ncbi.nlm.nih.gov/gene/{gene_info['gene_id']}")
                elif isinstance(gene_info, list) and gene_info:
                    first_gene = gene_info[0]
                    if hasattr(first_gene, 'gene_id') and first_gene.gene_id:
                        urls.append(f"https://www.ncbi.nlm.nih.gov/gene/{first_gene.gene_id}")

            if urls:
                source_urls[source.name] = urls

        return source_urls

    async def _llm_synthesize(self, query: str, data_sources: List[DataSource]) -> Dict:
        """Use LLM to synthesize the data."""

        # Prepare data sources summary for the prompt
        data_sources_text = ""
        for source in data_sources:
            data_sources_text += f"\n\n## {source.name} ({source.record_count} records)\n"
            data_sources_text += f"Quality Score: {source.quality_score}\n"
            data_sources_text += f"Data: {json.dumps(source.data, indent=2, default=str)[:2000]}..."

        prompt = AGENT_PROMPTS["synthesis"].format(
            query=query,
            data_sources=data_sources_text
        )

        # Use multiple attempts for better synthesis
        for attempt in range(self.config.max_retries):
            try:
                response = await self.client.chat.completions.create(
                    model=self.config.model,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=self.config.temperature,
                    max_tokens=self.config.max_tokens
                )

                synthesis_text = response.choices[0].message.content
                return self._parse_synthesis_response(synthesis_text)

            except Exception as e:
                self.logger.warning(f"Synthesis attempt {attempt + 1} failed: {e}")
                if attempt == self.config.max_retries - 1:
                    raise

        raise Exception("All synthesis attempts failed")

    def _parse_synthesis_response(self, synthesis_text: str) -> Dict:
        """Parse the LLM synthesis response into structured data."""

        # Try to extract structured sections
        sections = {
            "executive_summary": "",
            "key_findings": [],
            "insights": [],
            "correlations": {},
            "gaps_and_limitations": [],
            "additional_resources": []  # Changed from recommendations
        }

        # Simple parsing - look for common section headers
        lines = synthesis_text.split('\n')
        current_section = None

        for line in lines:
            line = line.strip()
            if not line:
                continue

            # Detect section headers
            line_lower = line.lower()
            if "executive summary" in line_lower:
                current_section = "executive_summary"
                continue
            elif "key findings" in line_lower:
                current_section = "key_findings"
                continue
            elif "limitations" in line_lower or "gaps" in line_lower:
                current_section = "gaps_and_limitations"
                continue
            elif "additional" in line_lower and ("resources" in line_lower or "information" in line_lower):
                current_section = "additional_resources"
                continue

            # Add content to current section
            if current_section == "executive_summary":
                sections["executive_summary"] += line + " "
            elif current_section in ["key_findings", "gaps_and_limitations", "additional_resources"]:
                if line.startswith(('-', '•', '*', '1.', '2.', '3.')):
                    # Remove bullet points/numbers
                    clean_line = line.lstrip('-•*123456789. ')
                    if clean_line:
                        sections[current_section].append(clean_line)

        # If parsing failed, use the whole text as executive summary
        if not sections["executive_summary"] and not sections["key_findings"]:
            sections["executive_summary"] = synthesis_text[:500] + "..."
            sections["key_findings"] = ["Comprehensive analysis provided in executive summary"]

        return sections

    def _structure_synthesis_results(
        self,
        synthesis_data: Dict,
        data_sources: List[DataSource]
    ) -> SynthesisResult:
        """Structure the synthesis results into a SynthesisResult object."""

        # Create insights from key findings
        insights = []
        for finding in synthesis_data.get("key_findings", []):
            insights.append(SynthesisInsight(
                type="pattern",
                description=finding,
                evidence=[finding],
                confidence=0.8,
                sources=[source.name for source in data_sources]
            ))

        # Create correlations map
        correlations = {}
        for source in data_sources:
            correlations[source.name] = [
                f"{source.record_count} records",
                f"Quality: {source.quality_score}"
            ]

        return SynthesisResult(
            executive_summary=synthesis_data.get("executive_summary", "").strip(),
            key_findings=synthesis_data.get("key_findings", []),
            insights=insights,
            correlations=correlations,
            gaps_and_limitations=synthesis_data.get("gaps_and_limitations", []),
            additional_resources=synthesis_data.get("additional_resources", []),
            data_sources_used=[source.name for source in data_sources],
            source_urls=self._generate_source_urls(data_sources)
        )

    def _create_empty_synthesis(self, query: str) -> SynthesisResult:
        """Create an empty synthesis result when no data is available."""
        return SynthesisResult(
            executive_summary=f"No data available to synthesize for query: {query}",
            key_findings=["No relevant data found across databases"],
            insights=[],
            correlations={},
            gaps_and_limitations=["No data sources returned results"],
            additional_resources=["Try refining query terms", "Check alternative gene symbols or identifiers"],
            data_sources_used=[]
        )

    def _create_error_synthesis(self, query: str, error: str) -> SynthesisResult:
        """Create an error synthesis result."""
        return SynthesisResult(
            executive_summary=f"Synthesis failed for query: {query}. Error: {error}",
            key_findings=["Synthesis process encountered an error"],
            insights=[],
            correlations={},
            gaps_and_limitations=[f"Technical error: {error}"],
            additional_resources=["Retry the query", "Contact support if error persists"],
            data_sources_used=[]
        )

    async def cross_reference_entities(
        self,
        entities: List[str],
        data_sources: List[DataSource]
    ) -> Dict[str, List[str]]:
        """Cross-reference entities across data sources."""

        cross_references = {}

        for entity in entities:
            entity_refs = []

            for source in data_sources:
                # Simple text search for entity mentions
                source_text = json.dumps(source.data, default=str).lower()
                entity_lower = entity.lower()

                if entity_lower in source_text:
                    entity_refs.append(f"Found in {source.name}")

            if entity_refs:
                cross_references[entity] = entity_refs

        return cross_references

    async def identify_data_gaps(self, data_sources: List[DataSource]) -> List[str]:
        """Identify gaps in the available data."""

        gaps = []

        # Check for missing data types
        source_names = [source.name for source in data_sources]

        if "NCBI Datasets" not in source_names:
            gaps.append("Missing genomic data from NCBI Datasets")

        if "PMC Literature" not in source_names:
            gaps.append("Missing literature data from PMC")

        if "ClinVar" not in source_names:
            gaps.append("Missing clinical variant data from ClinVar")

        # Check for low record counts
        for source in data_sources:
            if source.record_count == 0:
                gaps.append(f"No records returned from {source.name}")
            elif source.record_count < 5:
                gaps.append(f"Limited data from {source.name} ({source.record_count} records)")

        return gaps


# Convenience function for data synthesis
async def synthesize_biomedical_data(
    query: str,
    datasets_data: Optional[Dict] = None,
    pmc_data: Optional[Dict] = None,
    clinvar_data: Optional[Dict] = None,
    config: Optional[AgentConfig] = None
) -> SynthesisResult:
    """
    Convenience function to synthesize biomedical data.

    Args:
        query: Original user query
        datasets_data: Data from NCBI Datasets
        pmc_data: Data from PMC
        clinvar_data: Data from ClinVar
        config: Optional agent configuration

    Returns:
        SynthesisResult with comprehensive analysis
    """
    if config is None:
        config = AgentConfig.from_env()

    synthesizer = DataSynthesizer(config)
    return await synthesizer.synthesize_data(
        query, datasets_data, pmc_data, clinvar_data
    )
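And a sketch of calling the synthesizer directly with pre-fetched payloads; the dict keys ("articles", "variants") match what _prepare_data_sources expects, while the record contents below are hypothetical placeholders, not real identifiers. Not part of the committed diff.

# Usage sketch (not part of this commit): synthesize pre-fetched payloads.
# The pmc_id / variation_id values here are hypothetical placeholders.
import asyncio

from gquery.agents.synthesis import synthesize_biomedical_data


async def main() -> None:
    result = await synthesize_biomedical_data(
        query="BRCA1 pathogenic variants",
        pmc_data={"articles": [{"pmc_id": "PMC0000000", "title": "..."}]},
        clinvar_data={"variants": [{"variation_id": "00000", "name": "..."}]},
    )
    print(result.executive_summary)
    print(result.source_urls)  # deep links built by _generate_source_urls


asyncio.run(main())
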
gquery/src/gquery/cli.py
ADDED
@@ -0,0 +1,1027 @@
"""
Command-line interface for GQuery AI.

This module provides the main CLI entry point and commands
for running the application and utilities.
"""

import asyncio
from pathlib import Path
from typing import Optional

import typer
from rich.console import Console
from rich.table import Table

from gquery.config.settings import get_settings
from gquery.tools.pmc_client import PMCClient
from gquery.tools.clinvar_client import ClinVarClient
from gquery.tools.datasets_client import DatasetsClient
from gquery.utils.logger import get_logger, setup_logging

# Initialize CLI app
app = typer.Typer(
    name="gquery",
    help="GQuery AI - Biomedical Research Platform",
    add_completion=False,
)

console = Console()
logger = get_logger("cli")


@app.command()
def version() -> None:
    """Show version information."""
    settings = get_settings()
    console.print(f"GQuery AI v{settings.version}")


@app.command()
def config() -> None:
    """Show current configuration."""
    settings = get_settings()

    table = Table(title="GQuery AI Configuration")
    table.add_column("Setting", style="cyan")
    table.add_column("Value", style="green")

    table.add_row("App Name", settings.app_name)
    table.add_row("Version", settings.version)
    table.add_row("Environment", settings.environment)
    table.add_row("Debug", str(settings.debug))
    table.add_row("Host", settings.host)
    table.add_row("Port", str(settings.port))

    console.print(table)


@app.command()
def serve(
    host: Optional[str] = typer.Option(None, help="Host to bind to"),
    port: Optional[int] = typer.Option(None, help="Port to bind to"),
    workers: Optional[int] = typer.Option(None, help="Number of workers"),
    reload: bool = typer.Option(False, help="Enable auto-reload"),
) -> None:
    """Start the API server."""
    import uvicorn

    settings = get_settings()

    # Setup logging
    setup_logging(
        level=settings.logging.level,
        format_type=settings.logging.format,
        file_enabled=settings.logging.file_enabled,
        file_path=settings.logging.file_path,
        console_enabled=settings.logging.console_enabled,
    )

    # Use provided values or fall back to settings
    server_host = host or settings.host
    server_port = port or settings.port
    server_workers = workers or settings.workers

    console.print(f"Starting GQuery AI server on {server_host}:{server_port}")

    if reload:
        # Development mode with reload
        uvicorn.run(
            "gquery.api.main:app",
            host=server_host,
            port=server_port,
            reload=True,
            log_level=settings.logging.level.lower(),
        )
    else:
        # Production mode
        uvicorn.run(
            "gquery.api.main:app",
            host=server_host,
            port=server_port,
            workers=server_workers,
            log_level=settings.logging.level.lower(),
        )


@app.command()
def test(
    path: Optional[str] = typer.Option(None, help="Test path"),
    coverage: bool = typer.Option(False, help="Run with coverage"),
    verbose: bool = typer.Option(False, help="Verbose output"),
) -> None:
    """Run tests."""
    import subprocess
    import sys

    cmd = ["python", "-m", "pytest"]

    if path:
        cmd.append(path)
    else:
        cmd.append("gquery/tests")

    if coverage:
        cmd.extend(["--cov=gquery", "--cov-report=html", "--cov-report=term"])

    if verbose:
        cmd.append("-v")

    console.print(f"Running: {' '.join(cmd)}")
    result = subprocess.run(cmd)
    sys.exit(result.returncode)


@app.command()
def lint() -> None:
    """Run code linting."""
    import subprocess
    import sys

    commands = [
        ["python", "-m", "black", "--check", "gquery/"],
        ["python", "-m", "isort", "--check-only", "gquery/"],
        ["python", "-m", "mypy", "gquery/src"],
    ]

    for cmd in commands:
        console.print(f"Running: {' '.join(cmd)}")
        result = subprocess.run(cmd)
        if result.returncode != 0:
            console.print(f"[red]Command failed: {' '.join(cmd)}[/red]")
            sys.exit(result.returncode)

    console.print("[green]All linting checks passed![/green]")


@app.command()
def format() -> None:
    """Format code."""
    import subprocess

    commands = [
        ["python", "-m", "black", "gquery/"],
        ["python", "-m", "isort", "gquery/"],
    ]

    for cmd in commands:
        console.print(f"Running: {' '.join(cmd)}")
        subprocess.run(cmd)

    console.print("[green]Code formatting complete![/green]")


@app.command()
def init_db() -> None:
    """Initialize database."""
    console.print("[yellow]Database initialization not implemented yet[/yellow]")
    # TODO: Implement database initialization


@app.command()
def health() -> None:
    """Check system health."""
    settings = get_settings()

    table = Table(title="System Health Check")
    table.add_column("Component", style="cyan")
    table.add_column("Status", style="green")

    # Check configuration
    table.add_row("Configuration", "✓ OK")

    # Check log directory
    log_path = Path(settings.logging.file_path)
    if log_path.parent.exists():
        table.add_row("Log Directory", "✓ OK")
    else:
        table.add_row("Log Directory", "✗ Missing")

    # Check NCBI API key
    if settings.ncbi.api_key:
        table.add_row("NCBI API Key", "✓ Configured")
    else:
        table.add_row("NCBI API Key", "⚠ Missing")

    # Check NCBI Email
    if settings.ncbi.email:
        table.add_row("NCBI Email", "✓ Configured")
    else:
        table.add_row("NCBI Email", "⚠ Missing")

    # Database and Redis are future features
    table.add_row("Database", "⚠ Future feature (Phase 3)")
    table.add_row("Redis Cache", "⚠ Future feature (Phase 3)")

    console.print(table)


@app.command()
def test_pmc(
    query: str = typer.Option("BRCA1 AND functional study", help="Search query"),
    max_results: int = typer.Option(5, help="Maximum number of results"),
    pmc_id: Optional[str] = typer.Option(None, help="Specific PMC ID to retrieve"),
) -> None:
    """Test PMC API functionality."""

    async def run_pmc_test():
        """Run PMC API test."""
        settings = get_settings()

        # Setup logging
        setup_logging(
            level=settings.logging.level,
            format_type=settings.logging.format,
            file_enabled=settings.logging.file_enabled,
            file_path=settings.logging.file_path,
            console_enabled=settings.logging.console_enabled,
        )

        console.print(f"[bold blue]Testing PMC API[/bold blue]")
        console.print(f"Query: {query}")
        console.print(f"Max results: {max_results}")

        try:
            async with PMCClient() as client:
                if pmc_id:
                    # Test specific article retrieval
                    console.print(f"\n[bold]Retrieving article: {pmc_id}[/bold]")
                    article = await client.get_article_content(pmc_id)

                    table = Table(title=f"Article: {pmc_id}")
                    table.add_column("Field", style="cyan")
                    table.add_column("Value", style="green")

                    table.add_row("Title", article.title[:100] + "..." if len(article.title) > 100 else article.title)
                    table.add_row("Authors", ", ".join(article.authors[:3]) + "..." if len(article.authors) > 3 else ", ".join(article.authors))
                    table.add_row("Journal", article.journal or "N/A")
                    table.add_row("DOI", article.doi or "N/A")
                    table.add_row("Genes", ", ".join(article.genes[:5]) + "..." if len(article.genes) > 5 else ", ".join(article.genes))
                    table.add_row("Variants", ", ".join(article.variants[:5]) + "..." if len(article.variants) > 5 else ", ".join(article.variants))
                    table.add_row("Diseases", ", ".join(article.diseases[:5]) + "..." if len(article.diseases) > 5 else ", ".join(article.diseases))

                    console.print(table)

                else:
                    # Test search functionality
                    console.print(f"\n[bold]Searching articles[/bold]")
                    results = await client.search_articles(query, max_results=max_results)

                    table = Table(title=f"Search Results: {results.total_count} total")
|
271 |
+
table.add_column("PMC ID", style="cyan")
|
272 |
+
table.add_column("Title", style="green")
|
273 |
+
table.add_column("Relevance", style="yellow")
|
274 |
+
table.add_column("Genes", style="blue")
|
275 |
+
table.add_column("Variants", style="magenta")
|
276 |
+
|
277 |
+
for result in results.results:
|
278 |
+
genes_str = ", ".join(result.article.genes[:3]) + "..." if len(result.article.genes) > 3 else ", ".join(result.article.genes)
|
279 |
+
variants_str = ", ".join(result.article.variants[:2]) + "..." if len(result.article.variants) > 2 else ", ".join(result.article.variants)
|
280 |
+
|
281 |
+
table.add_row(
|
282 |
+
result.article.pmc_id,
|
283 |
+
result.article.title[:80] + "..." if len(result.article.title) > 80 else result.article.title,
|
284 |
+
f"{result.relevance_score:.2f}",
|
285 |
+
genes_str,
|
286 |
+
variants_str,
|
287 |
+
)
|
288 |
+
|
289 |
+
console.print(table)
|
290 |
+
|
291 |
+
# Show search metadata
|
292 |
+
console.print(f"\n[bold]Search Metadata[/bold]")
|
293 |
+
console.print(f"Processing time: {results.processing_time_ms:.2f}ms")
|
294 |
+
console.print(f"Average relevance: {results.average_relevance_score:.2f}")
|
295 |
+
console.print(f"Page: {results.page}/{results.total_pages}")
|
296 |
+
|
297 |
+
console.print("\n[bold green]✓ PMC API test completed successfully![/bold green]")
|
298 |
+
|
299 |
+
except Exception as e:
|
300 |
+
console.print(f"\n[bold red]✗ PMC API test failed: {e}[/bold red]")
|
301 |
+
logger.error("PMC API test failed", error=str(e))
|
302 |
+
raise typer.Exit(1)
|
303 |
+
|
304 |
+
# Run the async test
|
305 |
+
asyncio.run(run_pmc_test())
|
306 |
+
|
307 |
+
|
308 |
+
@app.command()
|
309 |
+
def test_clinvar(
|
310 |
+
gene: str = typer.Option("BRCA1", help="Gene symbol to test"),
|
311 |
+
max_results: int = typer.Option(10, help="Maximum results to retrieve"),
|
312 |
+
verbose: bool = typer.Option(False, help="Show detailed output"),
|
313 |
+
) -> None:
|
314 |
+
"""Test ClinVar API integration."""
|
315 |
+
from gquery.utils.cache import get_cache_manager
|
316 |
+
|
317 |
+
setup_logging()
|
318 |
+
console.print("[bold]Testing ClinVar API Integration[/bold]")
|
319 |
+
|
320 |
+
async def run_clinvar_test():
|
321 |
+
try:
|
322 |
+
cache_manager = get_cache_manager()
|
323 |
+
|
324 |
+
async with ClinVarClient(cache_manager=cache_manager) as client:
|
325 |
+
console.print(f"\n[bold]Testing ClinVar search for gene: {gene}[/bold]")
|
326 |
+
|
327 |
+
# Test 1: Search variants by gene
|
328 |
+
console.print(f"1. Searching for {gene} variants...")
|
329 |
+
results = await client.search_variants_by_gene(
|
330 |
+
gene_symbol=gene,
|
331 |
+
max_results=max_results,
|
332 |
+
)
|
333 |
+
|
334 |
+
console.print(f"[green]✓ Found {results.total_count} variants total, showing {len(results.results)} results[/green]")
|
335 |
+
|
336 |
+
if verbose and results.results:
|
337 |
+
table = Table(title=f"ClinVar Variants for {gene}")
|
338 |
+
table.add_column("Variation ID", style="cyan")
|
339 |
+
table.add_column("Name", style="white")
|
340 |
+
table.add_column("Clinical Significance", style="red")
|
341 |
+
table.add_column("Review Status", style="yellow")
|
342 |
+
table.add_column("Star Rating", style="green")
|
343 |
+
table.add_column("Gene", style="blue")
|
344 |
+
|
345 |
+
for result in results.results:
|
346 |
+
variant = result.variant
|
347 |
+
table.add_row(
|
348 |
+
variant.variation_id,
|
349 |
+
variant.name[:60] + "..." if len(variant.name) > 60 else variant.name,
|
350 |
+
variant.clinical_significance.value,
|
351 |
+
variant.review_status.value[:30] + "..." if len(variant.review_status.value) > 30 else variant.review_status.value,
|
352 |
+
f"{variant.star_rating}/4",
|
353 |
+
variant.gene_symbol or "N/A",
|
354 |
+
)
|
355 |
+
|
356 |
+
console.print(table)
|
357 |
+
|
358 |
+
# Show distribution of clinical significance
|
359 |
+
console.print(f"\n[bold]Clinical Significance Distribution[/bold]")
|
360 |
+
console.print(f"Pathogenic/Likely pathogenic: {results.pathogenic_count} ({results.pathogenic_percentage:.1f}%)")
|
361 |
+
console.print(f"Benign/Likely benign: {results.benign_count} ({results.benign_percentage:.1f}%)")
|
362 |
+
console.print(f"Average star rating: {results.average_star_rating:.1f}/4")
|
363 |
+
|
364 |
+
# Test 2: Get detailed variant information for first result
|
365 |
+
if results.results:
|
366 |
+
first_variant = results.results[0].variant
|
367 |
+
console.print(f"\n2. Getting detailed information for variant {first_variant.variation_id}...")
|
368 |
+
|
369 |
+
try:
|
370 |
+
detailed_variant = await client.get_variant_details(first_variant.variation_id)
|
371 |
+
console.print(f"[green]✓ Retrieved detailed information for {detailed_variant.name}[/green]")
|
372 |
+
|
373 |
+
if verbose:
|
374 |
+
console.print(f" - HGVS Genomic: {detailed_variant.hgvs_genomic or 'N/A'}")
|
375 |
+
console.print(f" - HGVS Coding: {detailed_variant.hgvs_coding or 'N/A'}")
|
376 |
+
console.print(f" - HGVS Protein: {detailed_variant.hgvs_protein or 'N/A'}")
|
377 |
+
console.print(f" - ClinVar URL: {detailed_variant.clinvar_url}")
|
378 |
+
|
379 |
+
except Exception as e:
|
380 |
+
console.print(f"[yellow]⚠ Could not get detailed info: {e}[/yellow]")
|
381 |
+
|
382 |
+
# Test 3: Search by variant name (if we have one)
|
383 |
+
if results.results and results.results[0].variant.name:
|
384 |
+
variant_name = results.results[0].variant.name.split()[0] # Take first word
|
385 |
+
console.print(f"\n3. Testing variant name search with '{variant_name}'...")
|
386 |
+
|
387 |
+
try:
|
388 |
+
name_results = await client.search_variant_by_name(
|
389 |
+
variant_name=variant_name,
|
390 |
+
gene_symbol=gene,
|
391 |
+
max_results=5,
|
392 |
+
)
|
393 |
+
console.print(f"[green]✓ Found {len(name_results)} variants by name[/green]")
|
394 |
+
|
395 |
+
except Exception as e:
|
396 |
+
console.print(f"[yellow]⚠ Variant name search failed: {e}[/yellow]")
|
397 |
+
|
398 |
+
# Show search metadata
|
399 |
+
console.print(f"\n[bold]Search Metadata[/bold]")
|
400 |
+
console.print(f"Processing time: {results.processing_time_ms:.2f}ms")
|
401 |
+
console.print(f"Page: {results.page}/{results.total_pages}")
|
402 |
+
|
403 |
+
console.print("\n[bold green]✓ ClinVar API test completed successfully![/bold green]")
|
404 |
+
|
405 |
+
except Exception as e:
|
406 |
+
console.print(f"\n[bold red]✗ ClinVar API test failed: {e}[/bold red]")
|
407 |
+
logger.error("ClinVar API test failed", error=str(e))
|
408 |
+
raise typer.Exit(1)
|
409 |
+
|
410 |
+
# Run the async test
|
411 |
+
asyncio.run(run_clinvar_test())
|
412 |
+
|
413 |
+
|


@app.command()
def test_datasets(
    gene: str = typer.Option("BRCA1", help="Gene symbol to test"),
    taxon_id: int = typer.Option(9606, help="NCBI taxonomy ID (default: 9606 for human)"),
    gene_id: Optional[str] = typer.Option(None, help="Specific gene ID to test"),
    accession: Optional[str] = typer.Option(None, help="Specific accession to test"),
    verbose: bool = typer.Option(False, help="Show detailed output"),
) -> None:
    """Test NCBI Datasets API integration."""

    setup_logging()
    console.print("[bold]Testing NCBI Datasets API Integration[/bold]")

    async def run_datasets_test():
        try:
            async with DatasetsClient() as client:
                console.print(f"\n[bold]Testing NCBI Datasets API for gene: {gene}[/bold]")

                # Initialize so the later ID lookup can safely check for a prior result
                gene_response = None

                # Test 1: Get gene by symbol
                console.print(f"1. Getting gene info by symbol: {gene} (taxon: {taxon_id})...")
                try:
                    gene_response = await client.get_gene_by_symbol(
                        symbol=gene,
                        taxon_id=taxon_id
                    )

                    if gene_response.genes:
                        gene_info = gene_response.genes[0]
                        console.print(f"[green]✓ Found gene: {gene_info.symbol} (ID: {gene_info.gene_id})[/green]")

                        if verbose:
                            table = Table(title=f"Gene Information: {gene_info.symbol}")
                            table.add_column("Field", style="cyan")
                            table.add_column("Value", style="green")

                            table.add_row("Gene ID", str(gene_info.gene_id) if gene_info.gene_id else "N/A")
                            table.add_row("Symbol", gene_info.symbol or "N/A")
                            table.add_row("Description", gene_info.description[:100] + "..." if gene_info.description and len(gene_info.description) > 100 else gene_info.description or "N/A")
                            table.add_row("Organism", gene_info.organism_name or "N/A")
                            table.add_row("Tax ID", str(gene_info.tax_id) if gene_info.tax_id else "N/A")
                            table.add_row("Chromosome", gene_info.chromosome or "N/A")
                            table.add_row("Map Location", gene_info.map_location or "N/A")
                            table.add_row("Gene Type", gene_info.gene_type or "N/A")
                            table.add_row("Synonyms", ", ".join(gene_info.synonyms[:5]) + "..." if gene_info.synonyms and len(gene_info.synonyms) > 5 else ", ".join(gene_info.synonyms) if gene_info.synonyms else "N/A")
                            table.add_row("Transcripts", str(len(gene_info.transcripts)) if gene_info.transcripts else "0")

                            console.print(table)

                            # Show transcript information
                            if gene_info.transcripts:
                                console.print(f"\n[bold]Transcripts ({len(gene_info.transcripts)} total)[/bold]")
                                transcript_table = Table()
                                transcript_table.add_column("Accession", style="cyan")
                                transcript_table.add_column("Product", style="green")
                                transcript_table.add_column("Length", style="yellow")

                                for transcript in gene_info.transcripts[:5]:  # Show first 5
                                    transcript_table.add_row(
                                        transcript.accession_version or "N/A",
                                        transcript.product[:50] + "..." if transcript.product and len(transcript.product) > 50 else transcript.product or "N/A",
                                        str(transcript.length) if transcript.length else "N/A"
                                    )

                                console.print(transcript_table)

                        # Test NCBI links generation
                        console.print("\n2. Generating NCBI resource links...")
                        links = client.generate_ncbi_links(gene_info)
                        console.print("[green]✓ Generated resource links[/green]")

                        if verbose:
                            links_table = Table(title="NCBI Resource Links")
                            links_table.add_column("Resource", style="cyan")
                            links_table.add_column("URL", style="blue")

                            if links.gene_url:
                                links_table.add_row("Gene", links.gene_url)
                            if links.pubmed_url:
                                links_table.add_row("PubMed", links.pubmed_url)
                            if links.clinvar_url:
                                links_table.add_row("ClinVar", links.clinvar_url)
                            if links.dbsnp_url:
                                links_table.add_row("dbSNP", links.dbsnp_url)
                            if links.omim_url:
                                links_table.add_row("OMIM", links.omim_url)

                            console.print(links_table)

                        # Test reference sequences
                        console.print("\n3. Getting reference sequences...")
                        ref_seqs = await client.get_reference_sequences(gene_info)
                        console.print(f"[green]✓ Found {len(ref_seqs)} reference sequences[/green]")

                        if verbose and ref_seqs:
                            ref_table = Table(title="Reference Sequences")
                            ref_table.add_column("Accession", style="cyan")
                            ref_table.add_column("Type", style="yellow")
                            ref_table.add_column("Description", style="green")

                            for ref_seq in ref_seqs[:5]:  # Show first 5
                                ref_table.add_row(
                                    ref_seq.accession,
                                    ref_seq.sequence_type,
                                    ref_seq.description[:60] + "..." if len(ref_seq.description) > 60 else ref_seq.description
                                )

                            console.print(ref_table)

                    else:
                        console.print(f"[yellow]⚠ No gene data found for {gene}[/yellow]")

                except Exception as e:
                    console.print(f"[red]✗ Gene symbol search failed: {e}[/red]")

                # Test 2: Get gene by ID (if provided or found); guard against a
                # failed symbol lookup leaving gene_response as None
                test_gene_id = gene_id
                if not test_gene_id and gene_response and gene_response.genes:
                    test_gene_id = str(gene_response.genes[0].gene_id)

                if test_gene_id:
                    console.print(f"\n4. Testing gene retrieval by ID: {test_gene_id}...")
                    try:
                        id_response = await client.get_gene_by_id(test_gene_id)
                        if id_response.genes:
                            console.print(f"[green]✓ Retrieved gene by ID: {id_response.genes[0].symbol}[/green]")
                        else:
                            console.print(f"[yellow]⚠ No gene found for ID {test_gene_id}[/yellow]")
                    except Exception as e:
                        console.print(f"[yellow]⚠ Gene ID search failed: {e}[/yellow]")

                # Test 3: Get gene by accession (if provided)
                if accession:
                    console.print(f"\n5. Testing gene retrieval by accession: {accession}...")
                    try:
                        acc_response = await client.get_gene_by_accession(accession)
                        if acc_response.genes:
                            console.print(f"[green]✓ Retrieved gene by accession: {acc_response.genes[0].symbol}[/green]")
                        else:
                            console.print(f"[yellow]⚠ No gene found for accession {accession}[/yellow]")
                    except Exception as e:
                        console.print(f"[yellow]⚠ Gene accession search failed: {e}[/yellow]")

            console.print("\n[bold green]✓ NCBI Datasets API test completed successfully![/bold green]")

        except Exception as e:
            console.print(f"\n[bold red]✗ NCBI Datasets API test failed: {e}[/bold red]")
            logger.error("Datasets API test failed", error=str(e))
            raise typer.Exit(1)

    # Run the async test
    asyncio.run(run_datasets_test())


@app.command()
def cache(
    action: str = typer.Argument(help="Cache action: stats, clear"),
) -> None:
    """Manage cache operations."""
    from gquery.utils.cache import get_cache_manager

    cache_manager = get_cache_manager()

    if action == "stats":
        stats = cache_manager.get_stats()

        table = Table(title="Cache Statistics")
        table.add_column("Metric", style="cyan")
        table.add_column("Value", style="green")

        for key, value in stats.items():
            table.add_row(str(key), str(value))

        console.print(table)

    elif action == "clear":
        async def clear_cache():
            await cache_manager.clear_all()
            console.print("[bold green]✓ Cache cleared successfully![/bold green]")

        asyncio.run(clear_cache())

    else:
        console.print(f"[bold red]Unknown action: {action}[/bold red]")
        console.print("Available actions: stats, clear")
        raise typer.Exit(1)


# Phase 2 Agent Commands

@app.command()
def query(
    query_text: str = typer.Argument(..., help="Natural language query to process"),
    synthesis: bool = typer.Option(True, help="Enable data synthesis"),
    verbose: bool = typer.Option(False, help="Verbose output"),
    output_format: str = typer.Option("table", help="Output format: table, json"),
) -> None:
    """Process a natural language query using AI agents."""

    async def process_query():
        from gquery.agents import orchestrate_query, AgentConfig

        try:
            console.print(f"[bold blue]Processing query:[/bold blue] {query_text}")
            console.print("[dim]Using AI agents to analyze and orchestrate...[/dim]")

            # Load configuration
            config = AgentConfig.from_env()

            # Orchestrate the query
            result = await orchestrate_query(query_text, config)

            if output_format == "json":
                import json

                # Convert to a JSON-serializable format
                output = {
                    "query": result.query,
                    "success": result.success,
                    "execution_time_ms": result.execution_time_ms,
                    "analysis": {
                        "query_type": result.analysis.query_type.value if result.analysis else None,
                        "confidence": result.analysis.confidence if result.analysis else None,
                        "databases_needed": result.analysis.databases_needed if result.analysis else [],
                        "entity_count": len(result.analysis.entities) if result.analysis else 0
                    },
                    "database_results": {
                        db: bool(data) for db, data in result.database_results.items()
                    },
                    "synthesis_available": bool(result.synthesis),
                    "errors": result.errors
                }
                console.print(json.dumps(output, indent=2))
                return

            # Table output format
            if result.success:
                console.print("[bold green]✓ Query processed successfully![/bold green]")

                # Analysis results
                if result.analysis:
                    analysis_table = Table(title="Query Analysis")
                    analysis_table.add_column("Aspect", style="cyan")
                    analysis_table.add_column("Value", style="green")

                    analysis_table.add_row("Query Type", result.analysis.query_type.value)
                    analysis_table.add_row("Confidence", f"{result.analysis.confidence:.2f}")
                    analysis_table.add_row("Complexity", result.analysis.complexity)
                    analysis_table.add_row("Databases Used", ", ".join(result.analysis.databases_needed))
                    analysis_table.add_row("Entities Found", str(len(result.analysis.entities)))

                    console.print(analysis_table)

                # Database results
                if result.database_results:
                    db_table = Table(title="Database Results")
                    db_table.add_column("Database", style="cyan")
                    db_table.add_column("Status", style="green")
                    db_table.add_column("Records", style="yellow")

                    for db_name, data in result.database_results.items():
                        if data:
                            # Count records based on data structure
                            record_count = 0
                            if "gene_info" in data:
                                record_count = 1
                            elif "articles" in data:
                                record_count = len(data["articles"])
                            elif "variants" in data:
                                record_count = len(data["variants"])

                            db_table.add_row(db_name.upper(), "✓ Success", str(record_count))
                        else:
                            db_table.add_row(db_name.upper(), "⚠ No data", "0")

                    console.print(db_table)

                # Synthesis results
                if synthesis and result.synthesis:
                    console.print("\n[bold blue]Data Synthesis:[/bold blue]")
                    console.print("[bold]Executive Summary:[/bold]")
                    console.print(result.synthesis.get("executive_summary", "No summary available"))

                    if "key_findings" in result.synthesis and result.synthesis["key_findings"]:
                        console.print("\n[bold]Key Findings:[/bold]")
                        for i, finding in enumerate(result.synthesis["key_findings"], 1):
                            console.print(f"{i}. {finding}")

                    # Display source URLs
                    if "source_urls" in result.synthesis and result.synthesis["source_urls"]:
                        console.print("\n[bold]Source URLs:[/bold]")
                        for db_name, urls in result.synthesis["source_urls"].items():
                            console.print(f"\n[bold cyan]{db_name}:[/bold cyan]")
                            for url in urls[:5]:  # Show first 5 URLs
                                console.print(f"  • {url}")
                            if len(urls) > 5:
                                console.print(f"  • ... and {len(urls) - 5} more URLs")

                    # Display data sources used
                    if "data_sources_used" in result.synthesis and result.synthesis["data_sources_used"]:
                        console.print("\n[bold]Data Sources Used:[/bold]")
                        for source in result.synthesis["data_sources_used"]:
                            console.print(f"  • {source}")

                    # Processing time for synthesis
                    if "processing_time_ms" in result.synthesis:
                        console.print(f"\n[dim]Synthesis processing time: {result.synthesis['processing_time_ms']}ms[/dim]")

                # Performance metrics
                console.print(f"\n[dim]Execution time: {result.execution_time_ms}ms[/dim]")

            else:
                console.print("[bold red]✗ Query processing failed![/bold red]")
                for error in result.errors:
                    console.print(f"[red]Error: {error}[/red]")

        except Exception as e:
            console.print(f"[bold red]Error processing query: {e}[/bold red]")
            if verbose:
                import traceback
                console.print(traceback.format_exc())

    asyncio.run(process_query())


@app.command()
def analyze(
    query_text: str = typer.Argument(..., help="Query to analyze"),
    verbose: bool = typer.Option(False, help="Verbose output"),
) -> None:
    """Analyze query intent and extract entities."""

    async def analyze_query():
        from gquery.agents import analyze_query_intent, AgentConfig

        try:
            console.print(f"[bold blue]Analyzing query:[/bold blue] {query_text}")

            config = AgentConfig.from_env()
            analysis = await analyze_query_intent(query_text, config)

            # Display results
            table = Table(title="Query Analysis Results")
            table.add_column("Attribute", style="cyan")
            table.add_column("Value", style="green")

            table.add_row("Query Type", analysis.query_type.value)
            table.add_row("Intent", analysis.intent)
            table.add_row("Complexity", analysis.complexity)
            table.add_row("Confidence", f"{analysis.confidence:.3f}")
            table.add_row("Databases Needed", ", ".join(analysis.databases_needed))
            table.add_row("Processing Time", f"{analysis.processing_time_ms}ms")

            console.print(table)

            # Show entities
            if analysis.entities:
                entity_table = Table(title="Extracted Entities")
                entity_table.add_column("Name", style="yellow")
                entity_table.add_column("Type", style="cyan")
                entity_table.add_column("Confidence", style="green")

                for entity in analysis.entities:
                    entity_table.add_row(
                        entity.name,
                        entity.entity_type,
                        f"{entity.confidence:.3f}"
                    )

                console.print(entity_table)

        except Exception as e:
            console.print(f"[bold red]Analysis failed: {e}[/bold red]")
            if verbose:
                import traceback
                console.print(traceback.format_exc())

    asyncio.run(analyze_query())


@app.command()
def resolve(
    entities: list[str] = typer.Argument(..., help="Entities to resolve"),
    verbose: bool = typer.Option(False, help="Verbose output"),
) -> None:
    """Resolve biomedical entities to standard identifiers."""

    async def resolve_entities():
        from gquery.agents import resolve_biomedical_entities, AgentConfig

        try:
            console.print(f"[bold blue]Resolving entities:[/bold blue] {', '.join(entities)}")

            config = AgentConfig.from_env()
            result = await resolve_biomedical_entities(entities, config)

            # Display resolution results
            if result.resolved_entities:
                resolved_table = Table(title="Resolved Entities")
                resolved_table.add_column("Original", style="yellow")
                resolved_table.add_column("Standardized", style="green")
                resolved_table.add_column("Type", style="cyan")
                resolved_table.add_column("Confidence", style="blue")
                resolved_table.add_column("Identifiers", style="magenta")

                for entity in result.resolved_entities:
                    identifiers = ", ".join(f"{ident.database}:{ident.identifier}" for ident in entity.identifiers)
                    resolved_table.add_row(
                        entity.original_name,
                        entity.standardized_name,
                        entity.entity_type,
                        f"{entity.confidence:.3f}",
                        identifiers
                    )

                console.print(resolved_table)

            # Show unresolved entities
            if result.unresolved_entities:
                console.print(f"\n[bold yellow]Unresolved entities:[/bold yellow] {', '.join(result.unresolved_entities)}")

            # Show summary
            console.print(f"\n[dim]Resolution confidence: {result.resolution_confidence:.3f}[/dim]")
            console.print(f"[dim]Processing time: {result.processing_time_ms}ms[/dim]")

        except Exception as e:
            console.print(f"[bold red]Entity resolution failed: {e}[/bold red]")
            if verbose:
                import traceback
                console.print(traceback.format_exc())

    asyncio.run(resolve_entities())


@app.command()
def synthesize(
    datasets_file: Optional[str] = typer.Option(None, help="JSON file with datasets data"),
    pmc_file: Optional[str] = typer.Option(None, help="JSON file with PMC data"),
    clinvar_file: Optional[str] = typer.Option(None, help="JSON file with ClinVar data"),
    query_text: str = typer.Option("Data synthesis", help="Context query for synthesis"),
    verbose: bool = typer.Option(False, help="Verbose output"),
) -> None:
    """Synthesize data from multiple biomedical databases."""

    async def synthesize_data():
        from gquery.agents import synthesize_biomedical_data, AgentConfig
        import json

        try:
            console.print("[bold blue]Synthesizing biomedical data...[/bold blue]")

            # Load data files
            datasets_data = None
            pmc_data = None
            clinvar_data = None

            if datasets_file:
                with open(datasets_file) as f:
                    datasets_data = json.load(f)
                console.print(f"[dim]Loaded datasets data from {datasets_file}[/dim]")

            if pmc_file:
                with open(pmc_file) as f:
                    pmc_data = json.load(f)
                console.print(f"[dim]Loaded PMC data from {pmc_file}[/dim]")

            if clinvar_file:
                with open(clinvar_file) as f:
                    clinvar_data = json.load(f)
                console.print(f"[dim]Loaded ClinVar data from {clinvar_file}[/dim]")

            if not any([datasets_data, pmc_data, clinvar_data]):
                console.print("[bold red]No data files provided for synthesis![/bold red]")
                console.print("Use --datasets-file, --pmc-file, or --clinvar-file options")
                return

            config = AgentConfig.from_env()
            result = await synthesize_biomedical_data(
                query_text, datasets_data, pmc_data, clinvar_data, config
            )

            # Display synthesis results
            console.print("\n[bold green]Synthesis Results[/bold green]")
            console.print("[bold]Executive Summary:[/bold]")
            console.print(result.executive_summary)

            if result.key_findings:
                console.print("\n[bold]Key Findings:[/bold]")
                for i, finding in enumerate(result.key_findings, 1):
                    console.print(f"{i}. {finding}")

            if result.gaps_and_limitations:
                console.print("\n[bold]Limitations and Gaps:[/bold]")
                for gap in result.gaps_and_limitations:
                    console.print(f"• {gap}")

            if result.recommendations:
                console.print("\n[bold]Recommendations:[/bold]")
                for rec in result.recommendations:
                    console.print(f"• {rec}")

            # Data sources used
            console.print(f"\n[dim]Data sources: {', '.join(result.data_sources_used)}[/dim]")
            console.print(f"[dim]Processing time: {result.processing_time_ms}ms[/dim]")

        except Exception as e:
            console.print(f"[bold red]Synthesis failed: {e}[/bold red]")
            if verbose:
                import traceback
                console.print(traceback.format_exc())

    asyncio.run(synthesize_data())


@app.command()
def agent_health() -> None:
    """Check the health of AI agent components."""

    async def check_agent_health():
        from gquery.agents import AgentConfig

        try:
            console.print("[bold blue]Checking AI Agent Health...[/bold blue]")

            config = AgentConfig.from_env()

            health_table = Table(title="Agent Health Status")
            health_table.add_column("Component", style="cyan")
            health_table.add_column("Status", style="green")
            health_table.add_column("Details", style="yellow")

            # Check OpenAI API key
            if config.openai_api_key:
                health_table.add_row("OpenAI API Key", "✓ Configured", f"Model: {config.model}")
            else:
                health_table.add_row("OpenAI API Key", "✗ Missing", "Set OPENAI__API_KEY in .env")

            # Check database clients (instantiation is the smoke test)
            try:
                from gquery.tools.datasets_client import DatasetsClient
                DatasetsClient()
                health_table.add_row("Datasets Client", "✓ Ready", "NCBI Datasets integration")
            except Exception as e:
                health_table.add_row("Datasets Client", "✗ Error", str(e))

            try:
                from gquery.tools.pmc_client import PMCClient
                PMCClient()
                health_table.add_row("PMC Client", "✓ Ready", "Literature search integration")
            except Exception as e:
                health_table.add_row("PMC Client", "✗ Error", str(e))

            try:
                from gquery.tools.clinvar_client import ClinVarClient
                ClinVarClient()
                health_table.add_row("ClinVar Client", "✓ Ready", "Clinical variant integration")
            except Exception as e:
                health_table.add_row("ClinVar Client", "✗ Error", str(e))

            # Test basic agent functionality
            try:
                from gquery.agents import QueryAnalyzer
                QueryAnalyzer(config)
                health_table.add_row("Query Analyzer", "✓ Ready", f"Confidence threshold: {config.confidence_threshold}")
            except Exception as e:
                health_table.add_row("Query Analyzer", "✗ Error", str(e))

            try:
                from gquery.agents import DataSynthesizer
                DataSynthesizer(config)
                health_table.add_row("Data Synthesizer", "✓ Ready", f"Synthesis depth: {config.synthesis_depth}")
            except Exception as e:
                health_table.add_row("Data Synthesizer", "✗ Error", str(e))

            try:
                from gquery.agents import EntityResolver
                EntityResolver(config)
                health_table.add_row("Entity Resolver", "✓ Ready", "Biomedical entity resolution")
            except Exception as e:
                health_table.add_row("Entity Resolver", "✗ Error", str(e))

            console.print(health_table)

            # Agent configuration summary
            config_table = Table(title="Agent Configuration")
            config_table.add_column("Setting", style="cyan")
            config_table.add_column("Value", style="green")

            config_table.add_row("Model", config.model)
            config_table.add_row("Temperature", str(config.temperature))
            config_table.add_row("Max Tokens", str(config.max_tokens))
            config_table.add_row("Max Retries", str(config.max_retries))
            config_table.add_row("Confidence Threshold", str(config.confidence_threshold))
            config_table.add_row("Synthesis Depth", config.synthesis_depth)
            config_table.add_row("Concurrent Queries", str(config.concurrent_queries))

            console.print(config_table)

        except Exception as e:
            console.print(f"[bold red]Health check failed: {e}[/bold red]")

    asyncio.run(check_agent_health())


def main() -> None:
    """Main CLI entry point."""
    app()


if __name__ == "__main__":
    main()
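
A quick way to sanity-check these commands without a shell is Typer's built-in test runner. The sketch below is illustrative and not part of the commit; it assumes the `app` object above is importable from `gquery.cli` and that a valid `.env` is present.

# Minimal smoke test for the CLI using typer.testing.CliRunner
# (assumption: `app` is exposed by gquery.cli as shown in this file).
from typer.testing import CliRunner

from gquery.cli import app

runner = CliRunner()

# `health` needs no network access, so it is a safe first check.
result = runner.invoke(app, ["health"])
print(result.stdout)
assert result.exit_code == 0

# `analyze` exercises the Phase 2 agent path; it requires OPENAI__API_KEY.
result = runner.invoke(app, ["analyze", "BRCA1 pathogenic variants"])
print(result.stdout)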
gquery/src/gquery/config/__init__.py
ADDED
@@ -0,0 +1,6 @@
"""
Core configuration management for GQuery AI.

This module handles all configuration loading, validation, and environment management
following the DEVELOPMENT_RULES.md specifications.
"""
gquery/src/gquery/config/__pycache__/__init__.cpython-310 2.pyc
ADDED
Binary file (387 Bytes).
gquery/src/gquery/config/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (387 Bytes).
gquery/src/gquery/config/__pycache__/settings.cpython-310 2.pyc
ADDED
Binary file (8.03 kB).
gquery/src/gquery/config/__pycache__/settings.cpython-310.pyc
ADDED
Binary file (8.03 kB).
gquery/src/gquery/config/settings.py
ADDED
@@ -0,0 +1,200 @@
"""
Application settings and configuration management.

This module handles loading configuration from environment variables,
.env files, and provides typed configuration objects using Pydantic.
"""

from typing import List, Optional

from pydantic import Field, field_validator
from pydantic_settings import BaseSettings


class DatabaseSettings(BaseSettings):
    """Database configuration settings."""

    host: str = Field(default="localhost", description="Database host")
    port: int = Field(default=5432, description="Database port")
    name: str = Field(default="gquery", description="Database name")
    user: str = Field(default="postgres", description="Database user")
    password: str = Field(default="", description="Database password")

    model_config = {"env_prefix": "DATABASE__"}

    @property
    def url(self) -> str:
        """Generate database URL."""
        return f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.name}"


class RedisSettings(BaseSettings):
    """Redis configuration settings."""

    host: str = Field(default="localhost", description="Redis host")
    port: int = Field(default=6379, description="Redis port")
    db: int = Field(default=0, description="Redis database number")
    password: Optional[str] = Field(default=None, description="Redis password")

    model_config = {"env_prefix": "REDIS__"}

    @property
    def url(self) -> str:
        """Generate Redis URL."""
        auth = f":{self.password}@" if self.password else ""
        return f"redis://{auth}{self.host}:{self.port}/{self.db}"


class NCBISettings(BaseSettings):
    """NCBI API configuration settings."""

    api_key: Optional[str] = Field(default=None, description="NCBI API key")
    email: str = Field(default="[email protected]", description="Email for NCBI API")
    base_url: str = Field(default="https://eutils.ncbi.nlm.nih.gov", description="NCBI base URL")
    rate_limit: float = Field(default=3.0, description="Requests per second")
    timeout: int = Field(default=30, description="Request timeout in seconds")

    model_config = {"env_prefix": "NCBI__"}

    @field_validator("email")
    @classmethod
    def validate_email(cls, v):
        """Validate email format."""
        if "@" not in v:
            raise ValueError("Invalid email format")
        return v


class OpenAISettings(BaseSettings):
    """OpenAI API configuration settings."""

    api_key: str = Field(default="sk-test-key-replace-in-production", description="OpenAI API key")
    model: str = Field(default="gpt-4", description="Default OpenAI model")
    temperature: float = Field(default=0.1, description="Model temperature")
    max_tokens: int = Field(default=4000, description="Maximum tokens per request")
    timeout: int = Field(default=60, description="Request timeout in seconds")

    model_config = {"env_prefix": "OPENAI__"}


class LoggingSettings(BaseSettings):
    """Logging configuration settings."""

    level: str = Field(default="INFO", description="Log level")
    format: str = Field(default="json", description="Log format (json|text)")
    file_enabled: bool = Field(default=True, description="Enable file logging")
    file_path: str = Field(default="logs/gquery.log", description="Log file path")
    console_enabled: bool = Field(default=True, description="Enable console logging")

    model_config = {"env_prefix": "LOGGING__"}

    @field_validator("level")
    @classmethod
    def validate_level(cls, v):
        """Validate log level."""
        valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        if v.upper() not in valid_levels:
            raise ValueError(f"Invalid log level. Must be one of: {valid_levels}")
        return v.upper()


class SecuritySettings(BaseSettings):
    """Security configuration settings."""

    secret_key: str = Field(default="dev-secret-key-change-in-production", description="Secret key for JWT tokens")
    algorithm: str = Field(default="HS256", description="JWT algorithm")
    access_token_expire_minutes: int = Field(default=30, description="Access token expiry")
    cors_origins: List[str] = Field(default=["http://localhost:3000"], description="CORS origins")

    @field_validator("cors_origins", mode="before")
    @classmethod
    def parse_cors_origins(cls, v):
        """Parse CORS origins from a comma-separated string or list."""
        if isinstance(v, str):
            return [origin.strip() for origin in v.split(",") if origin.strip()]
        return v

    model_config = {"env_prefix": "SECURITY__"}


class Settings(BaseSettings):
    """Main application settings."""

    # Application
    app_name: str = Field(default="GQuery AI", description="Application name")
    version: str = Field(default="0.1.0", description="Application version")
    debug: bool = Field(default=False, description="Debug mode")
    environment: str = Field(default="development", description="Environment")

    # API
    host: str = Field(default="0.0.0.0", description="API host")
    port: int = Field(default=8000, description="API port")
    workers: int = Field(default=1, description="Number of workers")

    # Component settings
    database: DatabaseSettings = Field(default_factory=DatabaseSettings)
    redis: RedisSettings = Field(default_factory=RedisSettings)
    ncbi: NCBISettings = Field(default_factory=NCBISettings)
    openai: OpenAISettings = Field(default_factory=OpenAISettings)
    logging: LoggingSettings = Field(default_factory=LoggingSettings)
    security: SecuritySettings = Field(default_factory=SecuritySettings)

    # Compatibility properties for flat access (backwards compatibility with agents)
    @property
    def openai_api_key(self) -> str:
        """Get OpenAI API key from nested settings."""
        return self.openai.api_key

    @property
    def ncbi_api_key(self) -> Optional[str]:
        """Get NCBI API key from nested settings."""
        return self.ncbi.api_key

    @property
    def ncbi_email(self) -> str:
        """Get NCBI email from nested settings."""
        return self.ncbi.email

    @property
    def model(self) -> str:
        """Get OpenAI model from nested settings."""
        return self.openai.model

    @property
    def temperature(self) -> float:
        """Get OpenAI temperature from nested settings."""
        return self.openai.temperature

    @property
    def max_tokens(self) -> int:
        """Get OpenAI max_tokens from nested settings."""
        return self.openai.max_tokens

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "env_nested_delimiter": "__",
        "case_sensitive": False,
        "extra": "ignore"
    }


# Global settings instance
_settings: Optional[Settings] = None


def get_settings() -> Settings:
    """Get application settings singleton."""
    global _settings
    if _settings is None:
        _settings = Settings()
    return _settings


def reload_settings() -> Settings:
    """Reload settings (useful for testing)."""
    global _settings
    _settings = None
    return get_settings()
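
Because `env_nested_delimiter` is `"__"` and each sub-settings class carries its own `env_prefix`, nested fields map onto double-underscore environment variables. A minimal sketch of that mapping, assuming the module import path `gquery.config.settings` from this commit; the variable values are illustrative only:

import os

from gquery.config.settings import reload_settings

# Nested fields are populated from double-underscore variables, e.g. the
# NCBISettings.api_key field (env_prefix "NCBI__") reads NCBI__API_KEY.
os.environ["NCBI__API_KEY"] = "my-ncbi-key"
os.environ["LOGGING__LEVEL"] = "debug"  # the validator upper-cases this
os.environ["SECURITY__CORS_ORIGINS"] = "http://localhost:3000,https://example.org"

settings = reload_settings()  # drop the cached singleton and re-read the env
assert settings.ncbi.api_key == "my-ncbi-key"
assert settings.logging.level == "DEBUG"
assert len(settings.security.cors_origins) == 2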
gquery/src/gquery/interfaces/__init__.py
ADDED
@@ -0,0 +1,6 @@
"""
Abstract base classes and protocols for GQuery AI.

This module defines interfaces and contracts between components
to ensure loose coupling and maintainability.
"""
gquery/src/gquery/models/__init__.py
ADDED
@@ -0,0 +1,40 @@
"""
Data models for GQuery AI.

This package contains all Pydantic models used throughout the application
for data validation, serialization, and API responses.
"""

from gquery.models.base import (
    APIResponse,
    BaseModel,
    ErrorDetail,
    HealthCheck,
    PaginatedResponse,
    ValidationError,
)
from gquery.models.pmc import (
    PMCArticle,
    PMCArticleMetadata,
    PMCSearchFilters,
    PMCSearchResponse,
    PMCSearchResult,
    VariantMention,
)

__all__ = [
    # Base models
    "BaseModel",
    "APIResponse",
    "PaginatedResponse",
    "HealthCheck",
    "ErrorDetail",
    "ValidationError",
    # PMC models
    "PMCArticle",
    "PMCArticleMetadata",
    "PMCSearchFilters",
    "PMCSearchResponse",
    "PMCSearchResult",
    "VariantMention",
]
gquery/src/gquery/models/__pycache__/__init__.cpython-310 2.pyc
ADDED
Binary file (806 Bytes).
gquery/src/gquery/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (806 Bytes).
gquery/src/gquery/models/__pycache__/base.cpython-310 2.pyc
ADDED
Binary file (4.18 kB).
gquery/src/gquery/models/__pycache__/base.cpython-310.pyc
ADDED
Binary file (4.18 kB).
gquery/src/gquery/models/__pycache__/clinvar.cpython-310 2.pyc
ADDED
Binary file (13 kB).
gquery/src/gquery/models/__pycache__/clinvar.cpython-310.pyc
ADDED
Binary file (13 kB).
gquery/src/gquery/models/__pycache__/datasets.cpython-310 2.pyc
ADDED
Binary file (20.7 kB).
gquery/src/gquery/models/__pycache__/datasets.cpython-310.pyc
ADDED
Binary file (20.7 kB).
gquery/src/gquery/models/__pycache__/pmc.cpython-310 2.pyc
ADDED
Binary file (11.6 kB).
gquery/src/gquery/models/__pycache__/pmc.cpython-310.pyc
ADDED
Binary file (11.6 kB).
gquery/src/gquery/models/base.py
ADDED
@@ -0,0 +1,89 @@
"""
Base data models for GQuery AI.

This module provides base Pydantic models and common schemas
used throughout the application.
"""

from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from uuid import UUID, uuid4

from pydantic import BaseModel as PydanticBaseModel, Field, ConfigDict


class BaseModel(PydanticBaseModel):
    """
    Base model for all GQuery AI data models.

    Provides common functionality like ID generation, timestamps,
    and serialization methods.
    """

    model_config = ConfigDict(
        use_enum_values=True,
        validate_assignment=True,
        arbitrary_types_allowed=True,
    )

    id: UUID = Field(default_factory=uuid4, description="Unique identifier")
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), description="Creation timestamp")
    updated_at: Optional[datetime] = Field(default=None, description="Last update timestamp")

    def update_timestamp(self) -> None:
        """Update the updated_at timestamp."""
        self.updated_at = datetime.now(timezone.utc)


class APIResponse(BaseModel):
    """Standard API response wrapper."""

    success: bool = Field(description="Whether the request was successful")
    message: str = Field(description="Response message")
    data: Optional[Any] = Field(default=None, description="Response data")
    errors: List[str] = Field(default_factory=list, description="Error messages")
    meta: Dict[str, Any] = Field(default_factory=dict, description="Response metadata")


class PaginatedResponse(APIResponse):
    """Paginated API response."""

    page: int = Field(ge=1, description="Current page number")
    page_size: int = Field(ge=1, description="Number of items per page")
    total_items: int = Field(ge=0, description="Total number of items")
    total_pages: int = Field(ge=0, description="Total number of pages")
    has_next: bool = Field(description="Whether there is a next page")
    has_previous: bool = Field(description="Whether there is a previous page")


class HealthCheck(BaseModel):
    """Health check response model."""

    status: str = Field(description="Service status")
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    version: str = Field(description="Application version")
    uptime: float = Field(description="Uptime in seconds")
    checks: Dict[str, bool] = Field(description="Component health checks")


class ErrorDetail(BaseModel):
    """Detailed error information."""

    code: str = Field(description="Error code")
    message: str = Field(description="Error message")
    field: Optional[str] = Field(default=None, description="Field that caused the error")
    context: Dict[str, Any] = Field(default_factory=dict, description="Additional error context")


class ValidationError(BaseModel):
    """Validation error response."""

    message: str = Field(description="Validation error message")
    errors: List[ErrorDetail] = Field(description="Detailed validation errors")


# Type aliases for common patterns
ID = UUID
Timestamp = datetime
JSONData = Dict[str, Any]
QueryParams = Dict[str, Any]
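
Every model above inherits the `id`/`created_at` defaults from `BaseModel`, so an envelope can be built with only its own fields. A minimal sketch of composing `PaginatedResponse`; the record values here are illustrative, not real data:

from gquery.models.base import PaginatedResponse

# Field constraints (page >= 1, total_pages >= 0, ...) are enforced by
# Pydantic at construction time; id and created_at are filled in by defaults.
page = PaginatedResponse(
    success=True,
    message="2 of 57 variants",
    data=[{"variation_id": "55577"}, {"variation_id": "55601"}],  # illustrative
    page=1,
    page_size=2,
    total_items=57,
    total_pages=29,
    has_next=True,
    has_previous=False,
)
print(page.model_dump_json(indent=2))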
gquery/src/gquery/models/clinvar.py
ADDED
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
ClinVar data models for GQuery AI.
|
3 |
+
|
4 |
+
This module defines Pydantic models for ClinVar variants, clinical significance,
|
5 |
+
and API responses used throughout the application.
|
6 |
+
"""
|
7 |
+
|
8 |
+
from datetime import datetime
|
9 |
+
from enum import Enum
|
10 |
+
from typing import Any, Dict, List, Optional
|
11 |
+
import re
|
12 |
+
|
13 |
+
from pydantic import Field, field_validator
|
14 |
+
|
15 |
+
from gquery.models.base import BaseModel
|
16 |
+
|
17 |
+
|
18 |
+
class ClinicalSignificance(str, Enum):
|
19 |
+
"""Clinical significance classification for variants."""
|
20 |
+
|
21 |
+
PATHOGENIC = "Pathogenic"
|
22 |
+
LIKELY_PATHOGENIC = "Likely pathogenic"
|
23 |
+
UNCERTAIN_SIGNIFICANCE = "Uncertain significance"
|
24 |
+
LIKELY_BENIGN = "Likely benign"
|
25 |
+
BENIGN = "Benign"
|
26 |
+
CONFLICTING = "Conflicting interpretations of pathogenicity"
|
27 |
+
NOT_PROVIDED = "not provided"
|
28 |
+
OTHER = "other"
|
29 |
+
|
30 |
+
|
31 |
+
class ReviewStatus(str, Enum):
|
32 |
+
"""Review status for ClinVar submissions."""
|
33 |
+
|
34 |
+
PRACTICE_GUIDELINE = "practice guideline"
|
35 |
+
REVIEWED_BY_EXPERT_PANEL = "reviewed by expert panel"
|
36 |
+
CRITERIA_PROVIDED_MULTIPLE_SUBMITTERS = "criteria provided, multiple submitters, no conflicts"
|
37 |
+
CRITERIA_PROVIDED_CONFLICTING = "criteria provided, conflicting interpretations"
|
38 |
+
CRITERIA_PROVIDED_SINGLE_SUBMITTER = "criteria provided, single submitter"
|
39 |
+
NO_ASSERTION_CRITERIA = "no assertion criteria provided"
|
40 |
+
NO_ASSERTION_PROVIDED = "no assertion provided"
|
41 |
+
|
42 |
+
|
43 |
+
class VariationType(str, Enum):
|
44 |
+
"""Type of genetic variation."""
|
45 |
+
|
46 |
+
SNV = "single nucleotide variant"
|
47 |
+
DELETION = "Deletion"
|
48 |
+
DUPLICATION = "Duplication"
|
49 |
+
INSERTION = "Insertion"
|
50 |
+
INDEL = "Indel"
|
51 |
+
INVERSION = "Inversion"
|
52 |
+
CNV = "copy number variation"
|
53 |
+
STRUCTURAL_VARIANT = "structural variant"
|
54 |
+
COMPLEX = "complex"
|
55 |
+
OTHER = "other"
|
56 |
+
|
57 |
+
|
+class ClinVarSubmission(BaseModel):
+    """Individual submission to ClinVar."""
+
+    submitter: str = Field(description="Submitter organization")
+    submission_date: Optional[datetime] = Field(default=None, description="Date of submission")
+    clinical_significance: ClinicalSignificance = Field(description="Reported clinical significance")
+    review_status: ReviewStatus = Field(description="Review status of submission")
+    assertion_method: Optional[str] = Field(default=None, description="Method used for assertion")
+    description: Optional[str] = Field(default=None, description="Submission description")
+
+    @field_validator("submission_date", mode="before")
+    @classmethod
+    def parse_submission_date(cls, v: Any) -> Optional[datetime]:
+        """Parse submission date from various formats."""
+        if v is None or v == "":
+            return None
+
+        if isinstance(v, datetime):
+            return v
+
+        if isinstance(v, str):
+            # Handle various date formats from ClinVar
+            date_patterns = [
+                r"(\d{4})-(\d{1,2})-(\d{1,2})",  # "2016-11-02"
+                r"(\d{4})/(\d{1,2})/(\d{1,2})",  # "2016/11/02"
+                r"(\d{1,2})/(\d{1,2})/(\d{4})",  # "11/02/2016"
+            ]
+
+            for pattern in date_patterns:
+                match = re.match(pattern, v.strip())
+                if match:
+                    try:
+                        if pattern.startswith(r"(\d{4})"):  # Year first
+                            year, month, day = match.groups()
+                        else:  # Month/day first
+                            month, day, year = match.groups()
+
+                        return datetime(int(year), int(month), int(day))
+                    except (ValueError, TypeError):
+                        continue
+
+            # Try ISO format
+            try:
+                return datetime.fromisoformat(v.replace("Z", "+00:00"))
+            except ValueError:
+                pass
+
+        return None
+
+
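The mode="before" validator runs on raw input, so the mixed date formats ClinVar returns are normalized at construction time. A usage sketch, assuming the gquery package is importable and its BaseModel is Pydantic-based:

# Usage sketch (assumes gquery is installed; values are illustrative).
from gquery.models.clinvar import (
    ClinicalSignificance,
    ClinVarSubmission,
    ReviewStatus,
)

for raw in ("2016-11-02", "11/02/2016", "2016-11-02T08:30:00Z", "unknown"):
    sub = ClinVarSubmission(
        submitter="Example Lab",  # hypothetical submitter name
        submission_date=raw,
        clinical_significance=ClinicalSignificance.PATHOGENIC,
        review_status=ReviewStatus.CRITERIA_PROVIDED_SINGLE_SUBMITTER,
    )
    # Unparseable strings fall through to None rather than raising.
    print(f"{raw!r} -> {sub.submission_date}")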
+class ClinVarVariant(BaseModel):
+    """
+    ClinVar variant model.
+
+    Represents a genetic variant with clinical significance
+    and submission information from ClinVar.
+    """
+
+    variation_id: str = Field(description="ClinVar Variation ID")
+    name: str = Field(description="Variant name/description")
+
+    # Genomic coordinates
+    gene_symbol: Optional[str] = Field(default=None, description="Associated gene symbol")
+    chromosome: Optional[str] = Field(default=None, description="Chromosome")
+    start_position: Optional[int] = Field(default=None, description="Start position")
+    stop_position: Optional[int] = Field(default=None, description="Stop position")
+    reference_allele: Optional[str] = Field(default=None, description="Reference allele")
+    alternate_allele: Optional[str] = Field(default=None, description="Alternate allele")
+
+    # Variant classification
+    variation_type: Optional[VariationType] = Field(default=None, description="Type of variation")
+    clinical_significance: ClinicalSignificance = Field(description="Overall clinical significance")
+    review_status: ReviewStatus = Field(description="Overall review status")
+
+    # HGVS nomenclature
+    hgvs_genomic: Optional[str] = Field(default=None, description="HGVS genomic notation")
+    hgvs_coding: Optional[str] = Field(default=None, description="HGVS coding notation")
+    hgvs_protein: Optional[str] = Field(default=None, description="HGVS protein notation")
+
+    # Submissions and evidence
+    submissions: List[ClinVarSubmission] = Field(default_factory=list, description="Individual submissions")
+    number_of_submissions: int = Field(default=0, ge=0, description="Total number of submissions")
+
+    # Cross-references
+    rs_id: Optional[str] = Field(default=None, description="dbSNP rs ID")
+    allele_id: Optional[str] = Field(default=None, description="ClinVar Allele ID")
+
+    # Metadata
+    last_evaluated: Optional[datetime] = Field(default=None, description="Date of last evaluation")
+    created_date: Optional[datetime] = Field(default=None, description="Date variant was created in ClinVar")
+    updated_date: Optional[datetime] = Field(default=None, description="Date variant was last updated")
+
+    # Quality metrics
+    star_rating: int = Field(default=0, ge=0, le=4, description="ClinVar star rating (0-4)")
+    confidence_score: float = Field(default=1.0, ge=0.0, le=1.0, description="Confidence in classification")
+
+    @field_validator("variation_id")
+    @classmethod
+    def validate_variation_id(cls, v: str) -> str:
+        """Validate ClinVar Variation ID format."""
+        if not v.isdigit():
+            raise ValueError("ClinVar Variation ID must be numeric")
+        return v
+
+    @field_validator("rs_id")
+    @classmethod
+    def validate_rs_id(cls, v: Optional[str]) -> Optional[str]:
+        """Validate dbSNP rs ID format."""
+        if v is not None and v != "" and not v.startswith("rs"):
+            raise ValueError("dbSNP ID must start with 'rs'")
+        return v
+
+    @field_validator("hgvs_genomic", "hgvs_coding", "hgvs_protein")
+    @classmethod
+    def validate_hgvs_format(cls, v: Optional[str]) -> Optional[str]:
+        """Basic HGVS format validation."""
+        if v is not None and v != "":
+            # Intentionally lenient: ClinVar also carries plain-text variant
+            # descriptions, so values without a recognized HGVS prefix are
+            # accepted rather than rejected.
+            if not any(token in v for token in ["c.", "p.", "g.", "n.", "r.", "NM_", "NP_", "NC_", "NR_"]):
+                pass
+        return v
+
+    @property
+    def clinvar_url(self) -> str:
+        """Generate ClinVar URL for this variant."""
+        return f"https://www.ncbi.nlm.nih.gov/clinvar/variation/{self.variation_id}/"
+
+    @property
+    def dbsnp_url(self) -> Optional[str]:
+        """Generate dbSNP URL if rs ID is available."""
+        if self.rs_id:
+            return f"https://www.ncbi.nlm.nih.gov/snp/{self.rs_id}"
+        return None
+
+    @property
+    def is_pathogenic(self) -> bool:
+        """Check if variant is considered pathogenic."""
+        return self.clinical_significance in [
+            ClinicalSignificance.PATHOGENIC,
+            ClinicalSignificance.LIKELY_PATHOGENIC,
+        ]
+
+    @property
+    def is_benign(self) -> bool:
+        """Check if variant is considered benign."""
+        return self.clinical_significance in [
+            ClinicalSignificance.BENIGN,
+            ClinicalSignificance.LIKELY_BENIGN,
+        ]
+
+    @property
+    def has_conflicting_evidence(self) -> bool:
+        """Check if variant has conflicting evidence."""
+        return self.clinical_significance == ClinicalSignificance.CONFLICTING
+
+
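The field validators and derived properties compose as below; all field values are hypothetical and chosen only to satisfy the validators:

# Usage sketch with hypothetical values (assumes gquery is installed).
from gquery.models.clinvar import (
    ClinicalSignificance,
    ClinVarVariant,
    ReviewStatus,
)

variant = ClinVarVariant(
    variation_id="12345",  # non-numeric IDs raise a validation error
    name="NM_000000.0(GENE1):c.100A>G (p.Lys34Glu)",  # made-up variant name
    gene_symbol="GENE1",
    rs_id="rs123",         # must start with "rs"
    clinical_significance=ClinicalSignificance.LIKELY_PATHOGENIC,
    review_status=ReviewStatus.REVIEWED_BY_EXPERT_PANEL,
    star_rating=3,
)

print(variant.is_pathogenic)  # True (Pathogenic or Likely pathogenic)
print(variant.clinvar_url)    # https://www.ncbi.nlm.nih.gov/clinvar/variation/12345/
print(variant.dbsnp_url)      # https://www.ncbi.nlm.nih.gov/snp/rs123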
+class ClinVarSearchFilters(BaseModel):
+    """
+    Search filters for ClinVar API queries.
+
+    Provides structured filtering options for ClinVar variant searches.
+    """
+
+    # Gene filters
+    gene_symbols: List[str] = Field(default_factory=list, description="Filter by gene symbols")
+
+    # Clinical significance filters
+    clinical_significance: List[ClinicalSignificance] = Field(
+        default_factory=list,
+        description="Filter by clinical significance"
+    )
+    review_status: List[ReviewStatus] = Field(
+        default_factory=list,
+        description="Filter by review status"
+    )
+
+    # Variant type filters
+    variation_types: List[VariationType] = Field(
+        default_factory=list,
+        description="Filter by variation types"
+    )
+
+    # Quality filters
+    min_star_rating: int = Field(default=0, ge=0, le=4, description="Minimum star rating")
+    min_submissions: int = Field(default=0, ge=0, description="Minimum number of submissions")
+
+    # Date filters
+    date_from: Optional[datetime] = Field(default=None, description="Variants updated after this date")
+    date_to: Optional[datetime] = Field(default=None, description="Variants updated before this date")
+
+    # Genomic location filters
+    chromosome: Optional[str] = Field(default=None, description="Filter by chromosome")
+    position_start: Optional[int] = Field(default=None, description="Start position for range")
+    position_end: Optional[int] = Field(default=None, description="End position for range")
+
+    @staticmethod
+    def _join_values(items: List[Any]) -> str:
+        """Serialize enum members via .value; plain strings pass through."""
+        return ",".join(item.value if isinstance(item, Enum) else str(item) for item in items)
+
+    def to_query_params(self) -> Dict[str, Any]:
+        """Convert filters to query parameters for API calls."""
+        params: Dict[str, Any] = {}
+
+        if self.gene_symbols:
+            params["gene_symbols"] = ",".join(self.gene_symbols)
+        if self.clinical_significance:
+            params["clinical_significance"] = self._join_values(self.clinical_significance)
+        if self.review_status:
+            params["review_status"] = self._join_values(self.review_status)
+        if self.variation_types:
+            params["variation_types"] = self._join_values(self.variation_types)
+        if self.min_star_rating > 0:
+            params["min_star_rating"] = str(self.min_star_rating)
+        if self.min_submissions > 0:
+            params["min_submissions"] = str(self.min_submissions)
+        if self.date_from:
+            params["date_from"] = self.date_from.strftime("%Y/%m/%d")
+        if self.date_to:
+            params["date_to"] = self.date_to.strftime("%Y/%m/%d")
+        if self.chromosome:
+            params["chromosome"] = self.chromosome
+        if self.position_start is not None:
+            params["position_start"] = str(self.position_start)
+        if self.position_end is not None:
+            params["position_end"] = str(self.position_end)
+
+        return params
+
+
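to_query_params flattens the typed filters into the comma-separated string parameters the client layer sends. A sketch, again assuming an installed gquery:

# Usage sketch: typed filters -> flat string query parameters.
from datetime import datetime

from gquery.models.clinvar import ClinicalSignificance, ClinVarSearchFilters

filters = ClinVarSearchFilters(
    gene_symbols=["BRCA1", "BRCA2"],
    clinical_significance=[
        ClinicalSignificance.PATHOGENIC,
        ClinicalSignificance.LIKELY_PATHOGENIC,
    ],
    min_star_rating=2,
    date_from=datetime(2020, 1, 1),
)

print(filters.to_query_params())
# {'gene_symbols': 'BRCA1,BRCA2',
#  'clinical_significance': 'Pathogenic,Likely pathogenic',
#  'min_star_rating': '2', 'date_from': '2020/01/01'}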
+class ClinVarSearchResult(BaseModel):
+    """
+    ClinVar search result with metadata and relevance information.
+
+    Represents a single search result with scoring and metadata
+    for efficient result processing.
+    """
+
+    variant: ClinVarVariant = Field(description="Variant information")
+    relevance_score: float = Field(ge=0.0, le=1.0, description="Query relevance score")
+    match_highlights: List[str] = Field(default_factory=list, description="Text highlights showing matches")
+
+    # Search context
+    query_terms: List[str] = Field(default_factory=list, description="Query terms that matched")
+    search_filters: Optional[ClinVarSearchFilters] = Field(default=None, description="Applied search filters")
+
+
+class ClinVarSearchResponse(BaseModel):
+    """
+    Complete ClinVar search response.
+
+    Contains search results, pagination information, and metadata
+    for a ClinVar search operation.
+    """
+
+    query: str = Field(description="Original search query")
+    total_count: int = Field(ge=0, description="Total number of matching variants")
+    results: List[ClinVarSearchResult] = Field(default_factory=list, description="Search results")
+
+    # Pagination
+    page: int = Field(ge=1, description="Current page number")
+    page_size: int = Field(ge=1, description="Number of results per page")
+    total_pages: int = Field(ge=0, description="Total number of pages")
+
+    # Search metadata
+    search_filters: Optional[ClinVarSearchFilters] = Field(default=None, description="Applied search filters")
+    processing_time_ms: float = Field(default=0.0, ge=0, description="Search processing time in milliseconds")
+
+    # Quality metrics
+    average_star_rating: float = Field(default=0.0, ge=0.0, le=4.0, description="Average star rating of results")
+    pathogenic_count: int = Field(default=0, ge=0, description="Number of pathogenic/likely pathogenic variants")
+    benign_count: int = Field(default=0, ge=0, description="Number of benign/likely benign variants")
+
+    @property
+    def has_next_page(self) -> bool:
+        """Check if there are more pages available."""
+        return self.page < self.total_pages
+
+    @property
+    def has_previous_page(self) -> bool:
+        """Check if there are previous pages available."""
+        return self.page > 1
+
+    @property
+    def pathogenic_percentage(self) -> float:
+        """Percentage of pathogenic variants among all matches (total_count)."""
+        if self.total_count == 0:
+            return 0.0
+        return (self.pathogenic_count / self.total_count) * 100
+
+    @property
+    def benign_percentage(self) -> float:
+        """Percentage of benign variants among all matches (total_count)."""
+        if self.total_count == 0:
+            return 0.0
+        return (self.benign_count / self.total_count) * 100
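The percentage properties are computed against total_count (all matches, not just the current page). A construction sketch, assuming gquery is installed:

# Usage sketch: reading pagination and quality metrics from a response.
from gquery.models.clinvar import ClinVarSearchResponse

resp = ClinVarSearchResponse(
    query="BRCA1",
    total_count=40,
    results=[],  # omitted here; would hold ClinVarSearchResult items
    page=1,
    page_size=20,
    total_pages=2,
    pathogenic_count=25,
    benign_count=5,
)

print(resp.has_next_page)                    # True (page 1 of 2)
print(f"{resp.pathogenic_percentage:.1f}%")  # 62.5%
print(f"{resp.benign_percentage:.1f}%")      # 12.5%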