File size: 3,434 Bytes
9f895e9
14a4318
1007582
 
14a4318
9f895e9
14a4318
 
 
 
1007582
 
 
9f895e9
 
 
1007582
 
 
 
 
 
9f895e9
14a4318
9f895e9
1007582
14a4318
1007582
 
 
 
 
 
 
 
 
 
 
14a4318
1007582
 
 
 
 
 
 
14a4318
1007582
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14a4318
 
 
1007582
14a4318
1007582
 
 
 
14a4318
1007582
14a4318
1007582
 
 
14a4318
 
 
1007582
14a4318
 
 
 
1007582
 
14a4318
1007582
 
 
 
 
14a4318
1007582
 
14a4318
1007582
 
 
 
 
 
 
 
14a4318
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import logging
import os
from enum import Enum
from typing import List, Optional

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from gradio_client import Client
from pydantic import BaseModel, ConfigDict, Field, constr

from docs import description, tags_metadata

load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="ACRES RAG API",
    description=description,
    openapi_tags=tags_metadata,
)
GRADIO_URL = os.getenv("GRADIO_URL", "http://localhost:7860/")
logger.info(f"GRADIO_URL: {GRADIO_URL}")
client = Client(GRADIO_URL)


class StudyVariables(str, Enum):
    ebola_virus = "Ebola Virus"
    vaccine_coverage = "Vaccine coverage"
    genexpert = "GeneXpert"


class PromptType(str, Enum):
    default = "Default"
    highlight = "Highlight"
    evidence_based = "Evidence-based"


class StudyVariableRequest(BaseModel):
    study_variable: StudyVariables
    prompt_type: PromptType
    text: constr(min_length=1, strip_whitespace=True)  # type: ignore

    model_config = ConfigDict(from_attributes=True)


class DownloadCSV(BaseModel):
    text: constr(min_length=1, strip_whitespace=True)  # type: ignore

    model_config = ConfigDict(from_attributes=True)


class Study(BaseModel):
    study_name: constr(min_length=1, strip_whitespace=True)  # type: ignore

    model_config = ConfigDict(from_attributes=True)


class ZoteroCredentials(BaseModel):
    library_id: constr(min_length=1, strip_whitespace=True)  # type: ignore
    api_access_key: constr(min_length=1, strip_whitespace=True)  # type: ignore

    model_config = ConfigDict(from_attributes=True)


@app.post("/process_zotero_library_items", tags=["zotero"])
def process_zotero_library_items(zotero_credentials: ZoteroCredentials):
    result = client.predict(
        zotero_library_id=zotero_credentials.library_id,
        zotero_api_access_key=zotero_credentials.api_access_key,
        api_name="/process_zotero_library_items",
    )
    return {"result": result}


@app.post("/get_study_info", tags=["zotero"])
def get_study_info(study: Study):
    result = client.predict(study_name=study.study_name, api_name="/get_study_info")
    # print(result)
    return {"result": result}


@app.post("/study_variables", tags=["zotero"])
def process_study_variables(
    study_request: StudyVariableRequest,
):
    result = client.predict(
        text=study_request.text,  # "study id, study title, study design, study summary",
        study_name=study_request.study_variable,  # "Ebola Virus",
        prompt_type=study_request.prompt_type,  # "Default",
        api_name="/process_multi_input",
    )
    print(type(result))
    return {"result": result[0]}


@app.post("/download_csv", tags=["zotero"])
def download_csv(download_request: DownloadCSV):
    result = client.predict(
        markdown_content=download_request.text, api_name="/download_as_csv"
    )
    print(result)

    file_path = result
    if not file_path or not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail="File not found")

    # Use FileResponse to send the file to the client
    return FileResponse(
        file_path,
        media_type="text/csv",  # Specify the correct MIME type for CSV
        filename=os.path.basename(
            file_path
        ),  # Provide a default filename for the download
    )