from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

def mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sequence, ignoring masked positions.

    Args:
        model_output: Indexable model result whose element ``0`` holds the
            per-token embeddings (assumed ``(batch, seq, hidden)`` -- the
            reduction below runs over dimension 1).
        attention_mask: Mask with one entry per token; it is broadcast to the
            embedding shape and cast to float, so positions holding ``0``
            contribute nothing to the average.

    Returns:
        A tensor with dimension 1 reduced away: the mask-weighted sum of the
        token embeddings divided by the mask total.
    """
    hidden_states = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()

    weighted_sum = (hidden_states * mask).sum(dim=1)
    # The clamp keeps the division finite when a row of the mask is all zeros.
    token_counts = mask.sum(dim=1).clamp(min=1e-9)
    return weighted_sum / token_counts

def cosine_similarity(u, v):
    """Return the cosine similarity of ``u`` and ``v`` computed along dimension 1.

    Thin wrapper over ``torch.nn.functional.cosine_similarity`` that pins
    ``dim=1``.
    """
    scores = F.cosine_similarity(u, v, dim=1)
    return scores


_NEGATION_MODEL_NAME = "dmlls/all-mpnet-base-v2-negation"
# Filled on first use so the tokenizer and weights are loaded once per process
# instead of on every call to compare().
_negation_model_cache = {}


def _load_negation_model():
    """Return the cached ``(tokenizer, model)`` pair, loading it on first use."""
    if "pair" not in _negation_model_cache:
        # Stored only after both loads succeed, so a failed load is retried
        # on the next call rather than leaving a half-initialised entry.
        _negation_model_cache["pair"] = (
            AutoTokenizer.from_pretrained(_NEGATION_MODEL_NAME),
            AutoModel.from_pretrained(_NEGATION_MODEL_NAME),
        )
    return _negation_model_cache["pair"]


def compare(text1, text2):
    """Return the cosine similarity between the embeddings of two texts.

    Both texts are tokenized together (padded and truncated), passed through
    the model without gradient tracking, mean-pooled with the attention mask,
    L2-normalised, and then compared.

    Args:
        text1: First text.
        text2: Second text.

    Returns:
        The similarity score as a Python float.

    Raises:
        Whatever the tokenizer/model loading or the forward pass raises; no
        exception is handled here.
    """
    tokenizer, model = _load_negation_model()

    encoded_input = tokenizer(
        [text1, text2], padding=True, truncation=True, return_tensors="pt"
    )

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        model_output = model(**encoded_input)

    sentence_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

    # unsqueeze(0) restores a leading dimension so the dim=1 reduction inside
    # cosine_similarity() runs over the embedding values of each text.
    similarity_score = cosine_similarity(
        sentence_embeddings[0].unsqueeze(0), sentence_embeddings[1].unsqueeze(0)
    )
    return similarity_score.item()


#--------------------------------------------------------------------------------------------------------------------
from fastapi import FastAPI

# Application instance; the route decorators in this module register their
# handlers on it.
app = FastAPI()

@app.get("/")
def greet_json():
    """Root endpoint: reply with a fixed greeting object."""
    greeting = {"Hello": "World!"}
    return greeting

#--------------------------------------------------------------------------------------------------------------------

from transformers import pipeline

# Summarization pipeline backed by facebook/bart-large-cnn. It is created when
# this module is imported and reused by every Summerized_Text() call.
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

def Summerized_Text(text):
    """Summarize ``text`` with the module-level ``summarizer`` pipeline.

    Leading and trailing whitespace is stripped before the text is passed to
    the pipeline, which is called with ``max_length=130``, ``min_length=30``
    and sampling disabled. The raw pipeline result is printed to stdout.

    Returns:
        The ``summary_text`` field of the first pipeline result.
    """
    cleaned = text.strip()
    results = summarizer(cleaned, max_length=130, min_length=30, do_sample=False)
    print(results)
    first_result = results[0]
    return first_result["summary_text"]

#--------------------------------------------------------------------------------------------------------------------

from fastapi.responses import JSONResponse
from pydantic import BaseModel
from fastapi import FastAPI

class StrRequest(BaseModel):
    # Request body for POST /api/summerized.
    text: str  # the text get_summerized() hands to Summerized_Text()


class CompareRequest(BaseModel):
    # Request body for POST /api/compare; compareTexts() passes both fields
    # to compare() and echoes them back in the response.
    summary: str
    text: str


@app.get("/api/check")
def check_connection():
    """Connectivity probe: reply with a fixed success envelope.

    Returns:
        HTTP 200 with ``{"status": 200, "message": ...}``; if building the
        response raises, HTTP 500 with the exception text as the message.
    """
    try:
        ok_payload = {"status": 200, "message": "Message Successfully Sent"}
        return JSONResponse(ok_payload, status_code=200)
    except Exception as err:
        print("Error => ", err)
        failure_payload = {"status": 500, "message": str(err)}
        return JSONResponse(failure_payload, status_code=500)


@app.post("/api/summerized")
async def get_summerized(request: StrRequest):
    """Summarize ``request.text`` and return the result in a JSON envelope.

    Responses:
        * HTTP 422, body status 422 -- the text is empty or whitespace-only.
        * body status 500 -- the summary is empty or contains
          ``"No abstract text."``.
        * body status 200 -- success; the summary is in ``data``.
        * HTTP 500, body status 500 -- any exception raised while handling
          the request; the exception text is the message.
    """
    import asyncio  # local import: only this handler needs it

    try:
        print(request)
        # Strip before validating so whitespace-only input is rejected here
        # instead of reaching the summarizer as an empty string.
        text = request.text.strip()
        if not text:
            return JSONResponse(
                {"status": 422, "message": "Invalid Input"}, status_code=422
            )

        # Summerized_Text() runs the model synchronously. Calling it directly
        # inside this coroutine would stall the event loop for every other
        # request, so it is executed in the loop's default executor instead.
        loop = asyncio.get_running_loop()
        summary = await loop.run_in_executor(None, Summerized_Text, text)

        # NOTE(review): the two "not found" responses below pass no
        # status_code, so they go out with JSONResponse's default HTTP status
        # even though the body says 500 -- confirm whether clients rely on it.
        if "No abstract text." in summary:
            return JSONResponse(
                {"status": 500, "message": "No matching text found", "data": "None"}
            )

        if not summary:
            return JSONResponse(
                {"status": 500, "message": "No matching text found", "data": {}}
            )

        return JSONResponse(
            {"status": 200, "message": "Matching text found", "data": summary}
        )

    except Exception as e:
        print("Error => ", e)
        return JSONResponse({"status": 500, "message": str(e)}, status_code=500)


@app.post("/api/compare")
def compareTexts(request: CompareRequest):
    """Score how similar ``request.text`` and ``request.summary`` are.

    Returns:
        * HTTP 422 when either field is empty.
        * A 200 envelope carrying the score from ``compare()`` under
          ``value`` together with both input strings.
        * HTTP 500 with the exception text if anything raises.
    """
    try:
        text, summary = request.text, request.summary

        # Both fields must be non-empty.
        if not (summary and text):
            invalid = {"status": 422, "message": "Invalid Input"}
            return JSONResponse(invalid, status_code=422)

        payload = {
            "status": 200,
            "message": "Comparisons made",
            "value": compare(text, summary),
            "text": text,
            "summary": summary,
        }
        return JSONResponse(payload)
    except Exception as err:
        print("Error => ", err)
        failure = {"status": 500, "message": str(err)}
        return JSONResponse(failure, status_code=500)