# Charm_10 / charm_c10_nlp.py
import torch
import torch.nn as nn
from transformers import AutoTokenizer
import firebase_admin
from firebase_admin import credentials, firestore
import logging
# Import classes and functions from model.py
from inference.model import ModelArgs, Transformer
# Setup Firebase (guarded so re-imports don't raise on double initialization)
cred = credentials.Certificate("firebase-config.json")  # Replace with your Firebase Admin SDK JSON
if not firebase_admin._apps:
    firebase_admin.initialize_app(cred)
db = firestore.client()
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class CharmC10NLP(nn.Module):
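    """Wraps the Transformer from inference.model with a Hugging Face tokenizer,
    generation settings, and Firestore persistence of prompt/response pairs."""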
def __init__(
self,
model_name="meta-llama/Llama-2-13b",
freeze_encoder=False,
use_fp16=False,
max_new_tokens=200,
):
        super().__init__()
        # Load tokenizer; Llama tokenizers ship without a pad token, so fall back to EOS
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
# Initialize ModelArgs
self.args = ModelArgs()
# Initialize Transformer model
self.model = Transformer(self.args)
# Freeze the model if required
if freeze_encoder:
for param in self.model.parameters():
param.requires_grad = False
logger.info("Model frozen for fine-tuning.")
        # Mixed precision (FP16): only meaningful on GPU, and the weights must
        # actually be converted, not just flagged
        self.use_fp16 = use_fp16 and torch.cuda.is_available()
        if self.use_fp16:
            self.model = self.model.half()
            logger.info("FP16 enabled for mixed-precision inference.")
# Generation settings
self.max_new_tokens = max_new_tokens
# Auto-detect device (GPU/CPU)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
logger.info(f"Model running on {self.device}")
def generate_response(self, prompt, temperature=0.7, top_p=0.9):
"""Generates an AI response based on input prompt."""
inputs = self.tokenizer(
prompt,
return_tensors="pt",
max_length=512,
padding=True,
truncation=True
).to(self.device)
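        # Inference only: no_grad() skips autograd bookkeeping; temperature scales
        # sampling randomness and top_p applies nucleus filtering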
with torch.no_grad():
output_ids = self.model.generate(
inputs["input_ids"],
max_new_tokens=self.max_new_tokens,
temperature=temperature,
top_p=top_p,
do_sample=True
)
response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Store AI response in Firestore
self.store_response(prompt, response)
return response
def store_response(self, prompt, response):
"""Saves the AI-generated response to Firestore."""
doc_ref = db.collection("chatflare_responses").document()
doc_ref.set({
"prompt": prompt,
"response": response,
"timestamp": firestore.SERVER_TIMESTAMP
})
logger.info("Response stored in Firestore.")