import os import torch import gradio as gr from transformers import AutoTokenizer, AutoModelForCausalLM # Configurar caché y gestión de memoria os.environ["TRANSFORMERS_CACHE"] = "/root/.cache/huggingface/" os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # Nombre del modelo model_name = "BSC-LT/ALIA-40b" # Cargar modelo desde caché si es posible try: tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=os.getenv("TRANSFORMERS_CACHE"), local_files_only=True) model = AutoModelForCausalLM.from_pretrained( model_name, cache_dir=os.getenv("TRANSFORMERS_CACHE"), local_files_only=True, device_map="auto", offload_folder="offload_cache", torch_dtype=torch.bfloat16 ) print("Modelo cargado desde caché.") except Exception as e: print("El modelo no se encontró en caché. Descargando...") tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=os.getenv("TRANSFORMERS_CACHE")) model = AutoModelForCausalLM.from_pretrained( model_name, cache_dir=os.getenv("TRANSFORMERS_CACHE"), device_map="auto", offload_folder="offload_cache", torch_dtype=torch.bfloat16 ) tokenizer.save_pretrained("/root/model_storage/") model.save_pretrained("/root/model_storage/") print("Modelo guardado en caché para futuras cargas.") # Mostrar en qué dispositivo está el modelo print(f"Modelo cargado en: {next(model.parameters()).device}") def generar_texto(entrada): torch.cuda.empty_cache() # Liberar caché antes de inferencia input_ids = tokenizer(entrada, return_tensors="pt").input_ids.to("cuda") output = model.generate( input_ids, max_length=100, temperature=0.1, top_p=0.95, repetition_penalty=1.2, do_sample=True ) return tokenizer.decode(output[0], skip_special_tokens=True) # Crear la interfaz de Gradio interfaz = gr.Interface( fn=generar_texto, inputs=gr.Textbox(lines=2, placeholder="Escribe tu prompt aquí...", interactive=True), outputs=gr.Textbox(interactive=True), title="Generador de Texto con ALIA-40b", description="Este modelo genera texto utilizando ALIA-40b, un modelo LLM entrenado por BSC-LT." ) if __name__ == "__main__": interfaz.launch(share=True, server_port=7860)