import os
import pickle
import torch
from PIL import Image
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    FluxPipeline,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
)
from diffusers.utils import export_to_video
from transformers import (
    pipeline as transformers_pipeline,
    AutoModelForCausalLM,
    AutoTokenizer,
    GPT2Tokenizer,
    GPT2LMHeadModel,
    AutoModel,
)
from audiocraft.models import musicgen
import torchaudio
import gradio as gr
from huggingface_hub import snapshot_download, HfApi, HfFolder
import io
import time
from tqdm import tqdm
from google.cloud import storage
import json

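# Required environment variables: HF_TOKEN (Hugging Face access token),
# GCS_CREDENTIALS (service-account JSON string), GCS_BUCKET_NAME (cache bucket).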
hf_token = os.getenv("HF_TOKEN")
gcs_credentials = json.loads(os.getenv("GCS_CREDENTIALS"))
gcs_bucket_name = os.getenv("GCS_BUCKET_NAME")

HfFolder.save_token(hf_token)

storage_client = storage.Client.from_service_account_info(gcs_credentials)
bucket = storage_client.bucket(gcs_bucket_name)


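# GCS-backed cache helpers: arbitrary Python objects (generated media, pickled
# pipelines) are stored in the bucket so repeated prompts and model downloads
# can be served without recomputation.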
def load_object_from_gcs(blob_name):
    blob = bucket.blob(blob_name)
    if blob.exists():
        return pickle.loads(blob.download_as_bytes())
    return None


def save_object_to_gcs(blob_name, obj):
    blob = bucket.blob(blob_name)
    blob.upload_from_string(pickle.dumps(obj))


def get_model_or_download(model_id, blob_name, loader_func):
    model = load_object_from_gcs(blob_name)
    if model is not None:
        return model
    try:
        with tqdm(total=1, desc=f"Downloading {model_id}") as pbar:
            model = loader_func(model_id, torch_dtype=torch.float16)
            pbar.update(1)
        save_object_to_gcs(blob_name, model)
        return model
    except Exception as e:
        print(f"Failed to load or save model: {e}")
        return None


def generate_image(prompt):
    blob_name = f"diffusers/generated_image:{prompt}"
    image_bytes = load_object_from_gcs(blob_name)
    if not image_bytes:
        try:
            with tqdm(total=1, desc="Generating image") as pbar:
                image = text_to_image_pipeline(prompt).images[0]
                pbar.update(1)
            buffered = io.BytesIO()
            image.save(buffered, format="JPEG")
            image_bytes = buffered.getvalue()
            save_object_to_gcs(blob_name, image_bytes)
        except Exception as e:
            print(f"Failed to generate image: {e}")
            return None
    # Gradio's Image output expects a PIL image, not raw bytes.
    return Image.open(io.BytesIO(image_bytes))


def edit_image_with_prompt(image, prompt, strength=0.75):
    # Gradio's Image input with type="pil" already provides a PIL image.
    blob_name = f"diffusers/edited_image:{prompt}:{strength}"
    edited_image_bytes = load_object_from_gcs(blob_name)
    if not edited_image_bytes:
        try:
            with tqdm(total=1, desc="Editing image") as pbar:
                edited_image = img2img_pipeline(
                    prompt=prompt, image=image, strength=strength
                ).images[0]
                pbar.update(1)
            buffered = io.BytesIO()
            edited_image.save(buffered, format="JPEG")
            edited_image_bytes = buffered.getvalue()
            save_object_to_gcs(blob_name, edited_image_bytes)
        except Exception as e:
            print(f"Failed to edit image: {e}")
            return None
    # Return a PIL image for Gradio's Image output.
    return Image.open(io.BytesIO(edited_image_bytes))


def generate_song(prompt, duration=10):
    blob_name = f"music/generated_song:{prompt}:{duration}"
    audio = load_object_from_gcs(blob_name)
    if audio is None:
        try:
            with tqdm(total=1, desc="Generating song") as pbar:
                music_gen.set_generation_params(duration=duration)
                wav = music_gen.generate([prompt])
                pbar.update(1)
            # Gradio's numpy Audio output expects a (sample_rate, samples) tuple.
            audio = (music_gen.sample_rate, wav[0].cpu().numpy().T)
            save_object_to_gcs(blob_name, audio)
        except Exception as e:
            print(f"Failed to generate song: {e}")
            return None
    return audio


def generate_text(prompt):
    blob_name = f"transformers/generated_text:{prompt}"
    text = load_object_from_gcs(blob_name)
    if not text:
        try:
            with tqdm(total=1, desc="Generating text") as pbar:
                text = text_gen_pipeline(prompt, max_new_tokens=256)[0][
                    "generated_text"
                ].strip()
                pbar.update(1)
            save_object_to_gcs(blob_name, text)
        except Exception as e:
            print(f"Failed to generate text: {e}")
            return None
    return text


def generate_flux_image(prompt):
    blob_name = f"diffusers/generated_flux_image:{prompt}"
    flux_image_bytes = load_object_from_gcs(blob_name)
    if not flux_image_bytes:
        try:
            with tqdm(total=1, desc="Generating FLUX image") as pbar:
                flux_image = flux_pipeline(
                    prompt,
                    guidance_scale=0.0,
                    num_inference_steps=4,
                    max_sequence_length=256,
                    generator=torch.Generator("cpu").manual_seed(0),
                ).images[0]
                pbar.update(1)
            buffered = io.BytesIO()
            flux_image.save(buffered, format="JPEG")
            flux_image_bytes = buffered.getvalue()
            save_object_to_gcs(blob_name, flux_image_bytes)
        except Exception as e:
            print(f"Failed to generate flux image: {e}")
            return None
    # Return a PIL image for Gradio's Image output.
    return Image.open(io.BytesIO(flux_image_bytes))


def generate_code(prompt):
    blob_name = f"transformers/generated_code:{prompt}"
    code = load_object_from_gcs(blob_name)
    if not code:
        try:
            with tqdm(total=1, desc="Generating code") as pbar:
                inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt")
                outputs = starcoder_model.generate(inputs, max_new_tokens=256)
                code = starcoder_tokenizer.decode(outputs[0])
                pbar.update(1)
            save_object_to_gcs(blob_name, code)
        except Exception as e:
            print(f"Failed to generate code: {e}")
            return None
    return code


def test_model_meta_llama():
    blob_name = "transformers/meta_llama_test_response"
    response = load_object_from_gcs(blob_name)
    if not response:
        try:
            messages = [
                {
                    "role": "system",
                    "content": "You are a pirate chatbot who always responds in pirate speak!",
                },
                {"role": "user", "content": "Who are you?"},
            ]
            with tqdm(total=1, desc="Testing Meta-Llama") as pbar:
                # With chat-style input, the pipeline returns the full message
                # list; the assistant's reply is the last message.
                response = meta_llama_pipeline(messages, max_new_tokens=256)[0][
                    "generated_text"
                ][-1]["content"].strip()
                pbar.update(1)
            save_object_to_gcs(blob_name, response)
        except Exception as e:
            print(f"Failed to test Meta-Llama: {e}")
            return None
    return response


def generate_image_sdxl(prompt):
    blob_name = f"diffusers/generated_image_sdxl:{prompt}"
    image_bytes = load_object_from_gcs(blob_name)
    if not image_bytes:
        try:
            with tqdm(total=1, desc="Generating SDXL image") as pbar:
                image = base(
                    prompt=prompt,
                    num_inference_steps=40,
                    denoising_end=0.8,
                    output_type="latent",
                ).images
                image = refiner(
                    prompt=prompt,
                    num_inference_steps=40,
                    denoising_start=0.8,
                    image=image,
                ).images[0]
                pbar.update(1)
            buffered = io.BytesIO()
            image.save(buffered, format="JPEG")
            image_bytes = buffered.getvalue()
            save_object_to_gcs(blob_name, image_bytes)
        except Exception as e:
            print(f"Failed to generate SDXL image: {e}")
            return None
    # Return a PIL image for Gradio's Image output.
    return Image.open(io.BytesIO(image_bytes))


def generate_musicgen_melody(prompt):
    blob_name = f"music/generated_musicgen_melody:{prompt}"
    audio = load_object_from_gcs(blob_name)
    if audio is None:
        try:
            with tqdm(total=1, desc="Generating MusicGen melody") as pbar:
                melody, sr = torchaudio.load("./assets/bach.mp3")
                # One melody conditioning signal per prompt in the batch.
                wav = music_gen_melody.generate_with_chroma(
                    [prompt], melody[None], sr
                )
                pbar.update(1)
            # Gradio's numpy Audio output expects a (sample_rate, samples) tuple.
            audio = (music_gen_melody.sample_rate, wav[0].cpu().numpy().T)
            save_object_to_gcs(blob_name, audio)
        except Exception as e:
            print(f"Failed to generate MusicGen melody: {e}")
            return None
    return audio


def generate_musicgen_large(prompt):
    blob_name = f"music/generated_musicgen_large:{prompt}"
    audio = load_object_from_gcs(blob_name)
    if audio is None:
        try:
            with tqdm(total=1, desc="Generating MusicGen large") as pbar:
                wav = music_gen_large.generate([prompt])
                pbar.update(1)
            # Gradio's numpy Audio output expects a (sample_rate, samples) tuple.
            audio = (music_gen_large.sample_rate, wav[0].cpu().numpy().T)
            save_object_to_gcs(blob_name, audio)
        except Exception as e:
            print(f"Failed to generate MusicGen large: {e}")
            return None
    return audio


def transcribe_audio(audio_sample):
    # Gradio's numpy Audio input arrives as a (sample_rate, samples) tuple of int16 PCM.
    sampling_rate, waveform = audio_sample
    blob_name = f"transformers/transcribed_audio:{hash(waveform.tobytes())}"
    text = load_object_from_gcs(blob_name)
    if not text:
        try:
            with tqdm(total=1, desc="Transcribing audio") as pbar:
                text = whisper_pipeline(
                    {
                        "sampling_rate": sampling_rate,
                        "raw": waveform.astype("float32") / 32768.0,
                    },
                    batch_size=8,
                )["text"]
                pbar.update(1)
            save_object_to_gcs(blob_name, text)
        except Exception as e:
            print(f"Failed to transcribe audio: {e}")
            return None
    return text


def generate_mistral_instruct(prompt):
    blob_name = f"transformers/generated_mistral_instruct:{prompt}"
    response = load_object_from_gcs(blob_name)
    if not response:
        try:
            conversation = [{"role": "user", "content": prompt}]
            with tqdm(total=1, desc="Generating Mistral Instruct response") as pbar:
                inputs = mistral_instruct_tokenizer.apply_chat_template(
                    conversation,
                    tools=tools,
                    add_generation_prompt=True,
                    return_dict=True,
                    return_tensors="pt",
                )
                # Move the tokenized prompt to the model's device before generating.
                inputs = {
                    k: v.to(mistral_instruct_model.device) for k, v in inputs.items()
                }
                outputs = mistral_instruct_model.generate(
                    **inputs, max_new_tokens=1000
                )
                response = mistral_instruct_tokenizer.decode(
                    outputs[0], skip_special_tokens=True
                )
                pbar.update(1)
            save_object_to_gcs(blob_name, response)
        except Exception as e:
            print(f"Failed to generate Mistral Instruct response: {e}")
            return None
    return response


def generate_mistral_nemo(prompt):
    blob_name = f"transformers/generated_mistral_nemo:{prompt}"
    response = load_object_from_gcs(blob_name)
    if not response:
        try:
            conversation = [{"role": "user", "content": prompt}]
            with tqdm(total=1, desc="Generating Mistral Nemo response") as pbar:
                inputs = mistral_nemo_tokenizer.apply_chat_template(
                    conversation,
                    tools=tools,
                    add_generation_prompt=True,
                    return_dict=True,
                    return_tensors="pt",
                )
                # Move the tokenized prompt to the model's device before generating.
                inputs = {
                    k: v.to(mistral_nemo_model.device) for k, v in inputs.items()
                }
                outputs = mistral_nemo_model.generate(**inputs, max_new_tokens=1000)
                response = mistral_nemo_tokenizer.decode(
                    outputs[0], skip_special_tokens=True
                )
                pbar.update(1)
            save_object_to_gcs(blob_name, response)
        except Exception as e:
            print(f"Failed to generate Mistral Nemo response: {e}")
            return None
    return response


def generate_gpt2_xl(prompt):
    blob_name = f"transformers/generated_gpt2_xl:{prompt}"
    response = load_object_from_gcs(blob_name)
    if not response:
        try:
            with tqdm(total=1, desc="Generating GPT-2 XL response") as pbar:
                inputs = gpt2_xl_tokenizer(prompt, return_tensors="pt")
                # Use generate() on the LM-head model; a bare forward pass only
                # returns hidden states, which cannot be decoded into text.
                outputs = gpt2_xl_model.generate(**inputs, max_new_tokens=256)
                response = gpt2_xl_tokenizer.decode(
                    outputs[0], skip_special_tokens=True
                )
                pbar.update(1)
            save_object_to_gcs(blob_name, response)
        except Exception as e:
            print(f"Failed to generate GPT-2 XL response: {e}")
            return None
    return response


def store_user_question(question):
    blob_name = "user_questions.txt"
    blob = bucket.blob(blob_name)
    if blob.exists():
        blob.download_to_filename("user_questions.txt")
    with open("user_questions.txt", "a") as f:
        f.write(question + "\n")
    blob.upload_from_filename("user_questions.txt")


def retrain_models():
    # Placeholder: model retraining is not implemented yet.
    pass


def generate_text_to_video_ms_1_7b(prompt, num_frames=200):
    blob_name = f"diffusers/text_to_video_ms_1_7b:{prompt}:{num_frames}"
    video_bytes = load_object_from_gcs(blob_name)
    if not video_bytes:
        try:
            with tqdm(total=1, desc="Generating video") as pbar:
                video_frames = text_to_video_ms_1_7b_pipeline(
                    prompt, num_inference_steps=25, num_frames=num_frames
                ).frames
                pbar.update(1)
            video_path = export_to_video(video_frames)
            with open(video_path, "rb") as f:
                video_bytes = f.read()
            save_object_to_gcs(blob_name, video_bytes)
            os.remove(video_path)
        except Exception as e:
            print(f"Failed to generate video: {e}")
            return None
    # Gradio's Video output expects a file path, so materialize the (possibly cached) bytes.
    output_path = f"/tmp/text_to_video_{abs(hash(blob_name))}.mp4"
    with open(output_path, "wb") as f:
        f.write(video_bytes)
    return output_path


def generate_text_to_video_ms_1_7b_short(prompt):
    blob_name = f"diffusers/text_to_video_ms_1_7b_short:{prompt}"
    video_bytes = load_object_from_gcs(blob_name)
    if not video_bytes:
        try:
            with tqdm(total=1, desc="Generating short video") as pbar:
                video_frames = text_to_video_ms_1_7b_short_pipeline(
                    prompt, num_inference_steps=25
                ).frames
                pbar.update(1)
            video_path = export_to_video(video_frames)
            with open(video_path, "rb") as f:
                video_bytes = f.read()
            save_object_to_gcs(blob_name, video_bytes)
            os.remove(video_path)
        except Exception as e:
            print(f"Failed to generate short video: {e}")
            return None
    # Gradio's Video output expects a file path, so materialize the (possibly cached) bytes.
    output_path = f"/tmp/text_to_video_short_{abs(hash(blob_name))}.mp4"
    with open(output_path, "wb") as f:
        f.write(video_bytes)
    return output_path


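# Model and pipeline initialization. Everything below is loaded eagerly at
# startup; together these checkpoints need a very large amount of GPU/CPU
# memory, so trim this list to the tabs you actually intend to serve.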
text_to_image_pipeline = get_model_or_download(
    "stabilityai/stable-diffusion-2",
    "diffusers/text_to_image_model",
    StableDiffusionPipeline.from_pretrained,
)
img2img_pipeline = get_model_or_download(
    "CompVis/stable-diffusion-v1-4",
    "diffusers/img2img_model",
    StableDiffusionImg2ImgPipeline.from_pretrained,
)
flux_pipeline = get_model_or_download(
    "black-forest-labs/FLUX.1-schnell",
    "diffusers/flux_model",
    FluxPipeline.from_pretrained,
)
text_gen_pipeline = transformers_pipeline(
    "text-generation", model="google/gemma-2-9b", tokenizer="google/gemma-2-9b"
)
music_gen = (
    load_object_from_gcs("music/music_gen")
    or musicgen.MusicGen.get_pretrained("melody")
)
meta_llama_pipeline = get_model_or_download(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "transformers/meta_llama_model",
    # transformers_pipeline expects the task as its first argument, so wrap it.
    lambda model_id, torch_dtype: transformers_pipeline(
        "text-generation", model=model_id, torch_dtype=torch_dtype
    ),
)
starcoder_model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder")
starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")

base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)
music_gen_melody = musicgen.MusicGen.get_pretrained("melody")
music_gen_melody.set_generation_params(duration=8)
music_gen_large = musicgen.MusicGen.get_pretrained("large")
music_gen_large.set_generation_params(duration=8)
whisper_pipeline = transformers_pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    chunk_length_s=30,
)
mistral_instruct_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-Large-Instruct-2407",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
mistral_instruct_tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-Large-Instruct-2407"
)
mistral_nemo_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-Nemo-Instruct-2407",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
mistral_nemo_tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-Nemo-Instruct-2407"
)
gpt2_xl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-xl")
gpt2_xl_model = GPT2LMHeadModel.from_pretrained("gpt2-xl")

llama_3_groq_70b_tool_use_pipeline = transformers_pipeline(
    "text-generation", model="Groq/Llama-3-Groq-70B-Tool-Use"
)
phi_3_5_mini_instruct_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3.5-mini-instruct", torch_dtype="auto", trust_remote_code=True
)
phi_3_5_mini_instruct_tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3.5-mini-instruct"
)
phi_3_5_mini_instruct_pipeline = transformers_pipeline(
    "text-generation",
    model=phi_3_5_mini_instruct_model,
    tokenizer=phi_3_5_mini_instruct_tokenizer,
)
meta_llama_3_1_8b_pipeline = transformers_pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3.1-8B",
    model_kwargs={"torch_dtype": torch.bfloat16},
)
meta_llama_3_1_70b_pipeline = transformers_pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3.1-70B",
    model_kwargs={"torch_dtype": torch.bfloat16},
)
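# "your/medical_text_summarization_model" is a placeholder id; point this at a
# real summarization checkpoint before launching.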
medical_text_summarization_pipeline = transformers_pipeline(
    "summarization", model="your/medical_text_summarization_model"
)
bart_large_cnn_summarization_pipeline = transformers_pipeline(
    "summarization", model="facebook/bart-large-cnn"
)
flux_1_dev_pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
flux_1_dev_pipeline.enable_model_cpu_offload()
gemma_2_9b_pipeline = transformers_pipeline("text-generation", model="google/gemma-2-9b")
gemma_2_9b_it_pipeline = transformers_pipeline(
    "text-generation",
    model="google/gemma-2-9b-it",
    model_kwargs={"torch_dtype": torch.bfloat16},
)
gemma_2_2b_pipeline = transformers_pipeline("text-generation", model="google/gemma-2-2b")
gemma_2_2b_it_pipeline = transformers_pipeline(
    "text-generation",
    model="google/gemma-2-2b-it",
    model_kwargs={"torch_dtype": torch.bfloat16},
)
gemma_2_27b_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-27b")
gemma_2_27b_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-27b")
gemma_2_27b_it_pipeline = transformers_pipeline(
    "text-generation",
    model="google/gemma-2-27b-it",
    model_kwargs={"torch_dtype": torch.bfloat16},
)
text_to_video_ms_1_7b_pipeline = DiffusionPipeline.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
)
text_to_video_ms_1_7b_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
    text_to_video_ms_1_7b_pipeline.scheduler.config
)
text_to_video_ms_1_7b_pipeline.enable_model_cpu_offload()
text_to_video_ms_1_7b_pipeline.enable_vae_slicing()
text_to_video_ms_1_7b_short_pipeline = DiffusionPipeline.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
)
text_to_video_ms_1_7b_short_pipeline.scheduler = (
    DPMSolverMultistepScheduler.from_config(
        text_to_video_ms_1_7b_short_pipeline.scheduler.config
    )
)
text_to_video_ms_1_7b_short_pipeline.enable_model_cpu_offload()

tools = []
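

# Helper wrappers (additions, not in the original script). Raw transformers
# pipelines return lists of dicts and FluxPipeline returns a pipeline output
# object, while Gradio's Textbox and Image components want plain strings and
# PIL images; the tabs below use these wrappers instead of the raw pipelines.
def make_text_fn(pipe, key="generated_text", **gen_kwargs):
    # `key` is "generated_text" for text-generation pipelines and
    # "summary_text" for summarization pipelines.
    def run(prompt):
        return pipe(prompt, **gen_kwargs)[0][key]

    return run


def generate_flux_1_dev_image(prompt):
    # Extract the first PIL image from the FluxPipeline output.
    return flux_1_dev_pipeline(prompt).images[0]
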

gen_image_tab = gr.Interface(
    fn=generate_image,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Image(type="pil"),
    title="Generate Image",
)
edit_image_tab = gr.Interface(
    fn=edit_image_with_prompt,
    inputs=[
        gr.Image(type="pil", label="Image:"),
        gr.Textbox(label="Prompt:"),
        gr.Slider(0.1, 1.0, 0.75, step=0.05, label="Strength:"),
    ],
    outputs=gr.Image(type="pil"),
    title="Edit Image",
)
generate_song_tab = gr.Interface(
    fn=generate_song,
    inputs=[
        gr.Textbox(label="Prompt:"),
        gr.Slider(5, 60, 10, step=1, label="Duration (s):"),
    ],
    outputs=gr.Audio(type="numpy"),
    title="Generate Songs",
)
generate_text_tab = gr.Interface(
    fn=generate_text,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="Generated Text:"),
    title="Generate Text",
)
generate_flux_image_tab = gr.Interface(
    fn=generate_flux_image,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Image(type="pil"),
    title="Generate FLUX Images",
)
generate_code_tab = gr.Interface(
    fn=generate_code,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="Generated Code:"),
    title="Generate Code",
)
model_meta_llama_test_tab = gr.Interface(
    fn=test_model_meta_llama,
    inputs=None,
    outputs=gr.Textbox(label="Model Output:"),
    title="Test Meta-Llama",
)
generate_image_sdxl_tab = gr.Interface(
    fn=generate_image_sdxl,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Image(type="pil"),
    title="Generate SDXL Image",
)
generate_musicgen_melody_tab = gr.Interface(
    fn=generate_musicgen_melody,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Audio(type="numpy"),
    title="Generate MusicGen Melody",
)
generate_musicgen_large_tab = gr.Interface(
    fn=generate_musicgen_large,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Audio(type="numpy"),
    title="Generate MusicGen Large",
)
transcribe_audio_tab = gr.Interface(
    fn=transcribe_audio,
    inputs=gr.Audio(type="numpy", label="Audio Sample:"),
    outputs=gr.Textbox(label="Transcribed Text:"),
    title="Transcribe Audio",
)
generate_mistral_instruct_tab = gr.Interface(
    fn=generate_mistral_instruct,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="Mistral Instruct Response:"),
    title="Generate Mistral Instruct Response",
)
generate_mistral_nemo_tab = gr.Interface(
    fn=generate_mistral_nemo,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="Mistral Nemo Response:"),
    title="Generate Mistral Nemo Response",
)
generate_gpt2_xl_tab = gr.Interface(
    fn=generate_gpt2_xl,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="GPT-2 XL Response:"),
    title="Generate GPT-2 XL Response",
)
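# answer_question_minicpm is referenced by the tab below but was never defined.
# This is a minimal sketch assuming the openbmb/MiniCPM-V-2_6 checkpoint and
# its chat() API (both the model id and the call signature are assumptions).
minicpm_model = AutoModel.from_pretrained(
    "openbmb/MiniCPM-V-2_6", trust_remote_code=True, torch_dtype=torch.bfloat16
).eval()
minicpm_tokenizer = AutoTokenizer.from_pretrained(
    "openbmb/MiniCPM-V-2_6", trust_remote_code=True
)


def answer_question_minicpm(image, question):
    # The image and the question are passed together as a single user message.
    msgs = [{"role": "user", "content": [image, question]}]
    try:
        return minicpm_model.chat(image=None, msgs=msgs, tokenizer=minicpm_tokenizer)
    except Exception as e:
        print(f"Failed to answer question with MiniCPM: {e}")
        return None

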
answer_question_minicpm_tab = gr.Interface(
    fn=answer_question_minicpm,
    inputs=[
        gr.Image(type="pil", label="Image:"),
        gr.Textbox(label="Question:"),
    ],
    outputs=gr.Textbox(label="MiniCPM Answer:"),
    title="Answer Question with MiniCPM",
)
llama_3_groq_70b_tool_use_tab = gr.Interface(
    fn=make_text_fn(llama_3_groq_70b_tool_use_pipeline, max_new_tokens=256),
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Llama 3 Groq 70B Tool Use Response:"),
    title="Llama 3 Groq 70B Tool Use",
)
phi_3_5_mini_instruct_tab = gr.Interface(
    fn=make_text_fn(phi_3_5_mini_instruct_pipeline, max_new_tokens=256),
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Phi 3.5 Mini Instruct Response:"),
    title="Phi 3.5 Mini Instruct",
)
meta_llama_3_1_8b_tab = gr.Interface(
    fn=make_text_fn(meta_llama_3_1_8b_pipeline, max_new_tokens=256),
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Meta Llama 3.1 8B Response:"),
    title="Meta Llama 3.1 8B",
)
meta_llama_3_1_70b_tab = gr.Interface(
    fn=make_text_fn(meta_llama_3_1_70b_pipeline, max_new_tokens=256),
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Meta Llama 3.1 70B Response:"),
    title="Meta Llama 3.1 70B",
)
medical_text_summarization_tab = gr.Interface(
    fn=make_text_fn(medical_text_summarization_pipeline, key="summary_text"),
    inputs=[gr.Textbox(label="Medical Document:")],
    outputs=gr.Textbox(label="Medical Text Summarization:"),
    title="Medical Text Summarization",
)
bart_large_cnn_summarization_tab = gr.Interface(
    fn=make_text_fn(bart_large_cnn_summarization_pipeline, key="summary_text"),
    inputs=[gr.Textbox(label="Article:")],
    outputs=gr.Textbox(label="Bart Large CNN Summarization:"),
    title="Bart Large CNN Summarization",
)
flux_1_dev_tab = gr.Interface(
    fn=generate_flux_1_dev_image,
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Image(type="pil"),
    title="FLUX 1 Dev",
)
gemma_2_9b_tab = gr.Interface(
    fn=make_text_fn(gemma_2_9b_pipeline, max_new_tokens=256),
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Gemma 2 9B Response:"),
    title="Gemma 2 9B",
)
gemma_2_9b_it_tab = gr.Interface(
    fn=make_text_fn(gemma_2_9b_it_pipeline, max_new_tokens=256),
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Gemma 2 9B IT Response:"),
    title="Gemma 2 9B IT",
)
gemma_2_2b_tab = gr.Interface(
    fn=make_text_fn(gemma_2_2b_pipeline, max_new_tokens=256),
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Gemma 2 2B Response:"),
    title="Gemma 2 2B",
)
gemma_2_2b_it_tab = gr.Interface(
    fn=make_text_fn(gemma_2_2b_it_pipeline, max_new_tokens=256),
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Gemma 2 2B IT Response:"),
    title="Gemma 2 2B IT",
)


def generate_gemma_2_27b(prompt):
    input_ids = gemma_2_27b_tokenizer(prompt, return_tensors="pt")
    outputs = gemma_2_27b_model.generate(**input_ids, max_new_tokens=32)
    return gemma_2_27b_tokenizer.decode(outputs[0])


gemma_2_27b_tab = gr.Interface(
    fn=generate_gemma_2_27b,
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Gemma 2 27B Response:"),
    title="Gemma 2 27B",
)
gemma_2_27b_it_tab = gr.Interface(
    fn=make_text_fn(gemma_2_27b_it_pipeline, max_new_tokens=256),
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Gemma 2 27B IT Response:"),
    title="Gemma 2 27B IT",
)
text_to_video_ms_1_7b_tab = gr.Interface(
    fn=generate_text_to_video_ms_1_7b,
    inputs=[
        gr.Textbox(label="Prompt:"),
        gr.Slider(50, 200, 200, step=1, label="Number of Frames:"),
    ],
    outputs=gr.Video(),
    title="Text to Video MS 1.7B",
)
text_to_video_ms_1_7b_short_tab = gr.Interface(
    fn=generate_text_to_video_ms_1_7b_short,
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Video(),
    title="Text to Video MS 1.7B Short",
)

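# Assemble every interface into one tabbed app; the two lists below must stay
# aligned (tab i gets title i).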
app = gr.TabbedInterface(
    [
        gen_image_tab,
        edit_image_tab,
        generate_song_tab,
        generate_text_tab,
        generate_flux_image_tab,
        generate_code_tab,
        model_meta_llama_test_tab,
        generate_image_sdxl_tab,
        generate_musicgen_melody_tab,
        generate_musicgen_large_tab,
        transcribe_audio_tab,
        generate_mistral_instruct_tab,
        generate_mistral_nemo_tab,
        generate_gpt2_xl_tab,
        answer_question_minicpm_tab,
        llama_3_groq_70b_tool_use_tab,
        phi_3_5_mini_instruct_tab,
        meta_llama_3_1_8b_tab,
        meta_llama_3_1_70b_tab,
        medical_text_summarization_tab,
        bart_large_cnn_summarization_tab,
        flux_1_dev_tab,
        gemma_2_9b_tab,
        gemma_2_9b_it_tab,
        gemma_2_2b_tab,
        gemma_2_2b_it_tab,
        gemma_2_27b_tab,
        gemma_2_27b_it_tab,
        text_to_video_ms_1_7b_tab,
        text_to_video_ms_1_7b_short_tab,
    ],
    [
        "Generate Image",
        "Edit Image",
        "Generate Song",
        "Generate Text",
        "Generate FLUX Image",
        "Generate Code",
        "Test Meta-Llama",
        "Generate SDXL Image",
        "Generate MusicGen Melody",
        "Generate MusicGen Large",
        "Transcribe Audio",
        "Generate Mistral Instruct Response",
        "Generate Mistral Nemo Response",
        "Generate GPT-2 XL Response",
        "Llama 3 Groq 70B Tool Use",
        "Phi 3.5 Mini Instruct",
        "Meta Llama 3.1 8B",
        "Meta Llama 3.1 70B",
        "Medical Text Summarization",
        "Bart Large CNN Summarization",
        "FLUX 1 Dev",
        "Gemma 2 9B",
        "Gemma 2 9B IT",
        "Gemma 2 2B",
        "Gemma 2 2B IT",
        "Gemma 2 27B",
        "Gemma 2 27B IT",
        "Text to Video MS 1.7B",
        "Text to Video MS 1.7B Short",
    ],
)

app.launch(share=True)