Spaces:
Running
Running
File size: 3,434 Bytes
9f895e9 14a4318 1007582 14a4318 9f895e9 14a4318 1007582 9f895e9 1007582 9f895e9 14a4318 9f895e9 1007582 14a4318 1007582 14a4318 1007582 14a4318 1007582 14a4318 1007582 14a4318 1007582 14a4318 1007582 14a4318 1007582 14a4318 1007582 14a4318 1007582 14a4318 1007582 14a4318 1007582 14a4318 1007582 14a4318 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import logging
import os
from enum import Enum
from typing import List, Optional
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from gradio_client import Client
from pydantic import BaseModel, ConfigDict, Field, constr
from docs import description, tags_metadata
# Load variables from a local .env file (if any) into os.environ before any
# configuration is read below.
load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="ACRES RAG API",
    description=description,
    openapi_tags=tags_metadata,
)

# Base URL of the upstream Gradio app that every endpoint below delegates to;
# overridable through the GRADIO_URL environment variable.
GRADIO_URL = os.getenv("GRADIO_URL", "http://localhost:7860/")
logger.info(f"GRADIO_URL: {GRADIO_URL}")

# Single shared client used by all endpoints in this module.
# NOTE(review): gradio_client.Client appears to contact the server when it is
# constructed — confirm; if so, the Gradio app must be reachable at import.
client = Client(GRADIO_URL)
# Studies selectable in POST /study_variables. Each member's value is the
# study name sent to the Gradio app (as `study_name`).
class StudyVariables(str, Enum):
    ebola_virus = "Ebola Virus"
    vaccine_coverage = "Vaccine coverage"
    genexpert = "GeneXpert"
# Prompting styles selectable in POST /study_variables. Each member's value is
# sent to the Gradio app as `prompt_type`.
class PromptType(str, Enum):
    default = "Default"
    highlight = "Highlight"
    evidence_based = "Evidence-based"
# Request body for POST /study_variables.
class StudyVariableRequest(BaseModel):
    # Permit model_validate() on attribute-bearing objects, not only dicts.
    model_config = ConfigDict(from_attributes=True)

    # Which study to query; limited to the StudyVariables choices.
    study_variable: StudyVariables
    # Prompting style forwarded to the Gradio app.
    prompt_type: PromptType
    # Text forwarded to the Gradio app, e.g. a comma-separated list such as
    # "study id, study title, study design, study summary". Leading/trailing
    # whitespace is stripped and the value must be non-empty.
    text: constr(strip_whitespace=True, min_length=1)  # type: ignore
# Request body for POST /download_csv.
class DownloadCSV(BaseModel):
    # Permit model_validate() on attribute-bearing objects, not only dicts.
    model_config = ConfigDict(from_attributes=True)

    # Markdown content handed to the Gradio app's "/download_as_csv"
    # endpoint; whitespace-stripped and required to be non-empty.
    text: constr(strip_whitespace=True, min_length=1)  # type: ignore
# Request body for POST /get_study_info.
class Study(BaseModel):
    # Permit model_validate() on attribute-bearing objects, not only dicts.
    model_config = ConfigDict(from_attributes=True)

    # Name of the study to look up; whitespace-stripped, must be non-empty.
    study_name: constr(strip_whitespace=True, min_length=1)  # type: ignore
# Request body for POST /process_zotero_library_items.
class ZoteroCredentials(BaseModel):
    # Permit model_validate() on attribute-bearing objects, not only dicts.
    model_config = ConfigDict(from_attributes=True)

    # Zotero library identifier; whitespace-stripped, must be non-empty.
    library_id: constr(strip_whitespace=True, min_length=1)  # type: ignore
    # Zotero API key for that library; whitespace-stripped, must be non-empty.
    api_access_key: constr(strip_whitespace=True, min_length=1)  # type: ignore
@app.post("/process_zotero_library_items", tags=["zotero"])
def process_zotero_library_items(zotero_credentials: ZoteroCredentials):
    # Relay the caller's Zotero credentials to the Gradio app's
    # "/process_zotero_library_items" endpoint and wrap whatever it returns.
    creds = zotero_credentials
    return {
        "result": client.predict(
            zotero_library_id=creds.library_id,
            zotero_api_access_key=creds.api_access_key,
            api_name="/process_zotero_library_items",
        )
    }
@app.post("/get_study_info", tags=["zotero"])
def get_study_info(study: Study):
    # Look up the named study through the Gradio app's "/get_study_info"
    # endpoint and return its reply unchanged under the "result" key.
    info = client.predict(
        study_name=study.study_name,
        api_name="/get_study_info",
    )
    return dict(result=info)
@app.post("/study_variables", tags=["zotero"])
def process_study_variables(
    study_request: StudyVariableRequest,
):
    """Extract the requested variables for a study via the Gradio app.

    Returns the first output of the Gradio "/process_multi_input" endpoint.
    """
    # Send the plain enum values (e.g. "Ebola Virus", "Default") rather than
    # the Enum members, so the payload does not depend on how the client
    # library serializes str-based Enum subclasses.
    result = client.predict(
        text=study_request.text,
        study_name=study_request.study_variable.value,
        prompt_type=study_request.prompt_type.value,
        api_name="/process_multi_input",
    )
    # Diagnostics go through the module logger (not print) so they respect
    # the logging configuration and stay out of stdout.
    logger.debug("process_multi_input returned %s", type(result).__name__)
    # Without this guard an empty/None reply would surface as an opaque 500
    # from result[0]; report it explicitly as an upstream failure instead.
    if not result:
        raise HTTPException(
            status_code=502, detail="Empty response from processing backend"
        )
    return {"result": result[0]}
@app.post("/download_csv", tags=["zotero"])
def download_csv(download_request: DownloadCSV):
    """Convert markdown content to CSV via the Gradio app and return the file."""
    result = client.predict(
        markdown_content=download_request.text, api_name="/download_as_csv"
    )
    # Diagnostics go through the module logger (not print) so they respect
    # the logging configuration and stay out of stdout.
    logger.debug("download_as_csv returned %r", result)
    file_path = result

    # Only serve a path to an existing *regular file*. os.path.exists() is
    # also true for directories, which FileResponse cannot send (it errors at
    # response time, i.e. a 500). The isinstance check keeps a non-path reply
    # (e.g. a dict) from raising TypeError inside os.path.
    is_path = isinstance(file_path, (str, os.PathLike))
    if not is_path or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")

    return FileResponse(
        file_path,
        media_type="text/csv",
        # Suggested download name for the client: the file's own basename.
        filename=os.path.basename(file_path),
    )
|