library_name: keras-hub
pipeline_tag: text-generation
language:
- en
tags:
- gemma2b
- gemma
- google
- gemini
- gemma data science
- gemma 2b data science
- data science model
datasets:
- soufyane/DATA_SCIENCE_QA
This is a Gemma model uploaded using the KerasNLP library; it can be used with the JAX, TensorFlow, and PyTorch backends.
This model is intended for the causal language modeling (CausalLM) task.
Model config:
- name: gemma_backbone
- trainable: True
- vocabulary_size: 256000
- num_layers: 18
- num_query_heads: 8
- num_key_value_heads: 1
- hidden_dim: 2048
- intermediate_dim: 32768
- head_dim: 256
- layer_norm_epsilon: 1e-06
- dropout: 0
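To make the config concrete, the sketch below recomputes the parameter count and the per-sequence key/value cache size from the values listed above. It assumes the KerasNLP Gemma layout: tied input/output embeddings, no bias terms, one RMSNorm scale vector before attention and one before the feed-forward network in each layer plus a final norm, and a gated feed-forward network whose three projections each use intermediate_dim // 2 units. These layout details are assumptions about the implementation rather than something this card states, and the helper names are hypothetical.

```python
# Config values copied from the "Model config" list.
CONFIG = {
    "vocabulary_size": 256000,
    "num_layers": 18,
    "num_query_heads": 8,
    "num_key_value_heads": 1,
    "hidden_dim": 2048,
    "intermediate_dim": 32768,
    "head_dim": 256,
}


def count_parameters(cfg: dict) -> int:
    """Estimate weight count, assuming the KerasNLP Gemma layout (no biases)."""
    d = cfg["hidden_dim"]
    embedding = cfg["vocabulary_size"] * d  # assumed tied with the output projection
    q_and_o = 2 * d * cfg["num_query_heads"] * cfg["head_dim"]
    k_and_v = 2 * d * cfg["num_key_value_heads"] * cfg["head_dim"]
    # Assumed: two gating projections plus one output projection, each intermediate_dim // 2 wide.
    ffn = 3 * d * (cfg["intermediate_dim"] // 2)
    norms = 2 * d  # pre-attention and pre-FFN RMSNorm scales
    per_layer = q_and_o + k_and_v + ffn + norms
    final_norm = d
    return embedding + cfg["num_layers"] * per_layer + final_norm


def kv_cache_bytes(cfg: dict, seq_len: int, bytes_per_value: int = 2) -> int:
    """Key + value cache for one sequence; bytes_per_value=2 assumes 16-bit floats."""
    values_per_token = 2 * cfg["num_layers"] * cfg["num_key_value_heads"] * cfg["head_dim"]
    return values_per_token * seq_len * bytes_per_value


print(f"parameters: {count_parameters(CONFIG):,}")
print(f"KV cache for 256 tokens: {kv_cache_bytes(CONFIG, 256) / 2**20:.1f} MiB")
```

Under these assumptions the count is 2,506,172,416 weights (about 2.5 billion). Because all 8 query heads share a single key/value head (multi-query attention), the cache for a 256-token generation is 4.5 MiB in 16-bit precision, eight times smaller than with one key/value head per query head.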
Model Details:
- Architecture: Gemma 2B is a decoder-only transformer language model. As the config shows, it has 18 layers, a hidden size of 2048, and multi-query attention (8 query heads sharing a single key/value head of dimension 256).
- Fine-tuning framework: Gemma 2B was fine-tuned with KerasNLP, the Keras library of pretrained NLP models and building blocks, which runs on JAX, TensorFlow, or PyTorch.
- Training data: the model was fine-tuned on data science question-answer data from the soufyane/DATA_SCIENCE_QA dataset: https://huggingface.co/datasets/soufyane/DATA_SCIENCE_QA
- Preprocessing: input text is tokenized with Gemma's SentencePiece tokenizer (256,000-token vocabulary). The preset bundles the matching preprocessor, so generate() accepts raw strings.
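The training-time preprocessing for a causal LM can be sketched in plain Python. This is a simplified illustration of the idea, not the KerasNLP preprocessor itself: it assumes the text has already been converted to token ids (the real preprocessor runs the SentencePiece tokenizer and adds start/end tokens), and the pad id of 0 and the function name are hypothetical. It pads or truncates to a fixed length, builds a padding mask, and shifts the sequence by one position so each position is trained to predict the next token.

```python
from typing import List, Tuple

PAD_ID = 0  # hypothetical padding token id


def prepare_causal_lm_example(
    token_ids: List[int], sequence_length: int
) -> Tuple[List[int], List[int], List[int]]:
    """Return (inputs, labels, sample_weights) for next-token prediction."""
    # Keep one extra position so inputs and labels both end up sequence_length long.
    ids = token_ids[: sequence_length + 1]
    mask = [1] * len(ids)
    pad = sequence_length + 1 - len(ids)
    ids = ids + [PAD_ID] * pad
    mask = mask + [0] * pad
    inputs = ids[:-1]          # what the model sees
    labels = ids[1:]           # the token that should come next at each position
    sample_weights = mask[1:]  # no loss on padded label positions
    return inputs, labels, sample_weights


inputs, labels, weights = prepare_causal_lm_example([2, 17, 42, 99, 1], sequence_length=6)
print(inputs)   # [2, 17, 42, 99, 1, 0]
print(labels)   # [17, 42, 99, 1, 0, 0]
print(weights)  # [1, 1, 1, 1, 0, 0]
```

Padding to sequence_length + 1 before slicing is what lets the inputs and the shifted labels share the same length without a separate alignment step.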
Use it on Kaggle: I recommend running the model on Kaggle (free P100 GPU) for fast responses. Here's the link to my notebook: https://www.kaggle.com/code/sufyen/gemma-2b-data-science-from-hugging-face
How to use:
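# Install the necessary packages (quote "keras>=3" so the shell does not treat ">" as a redirect)
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"
import os
# Set the environment BEFORE importing Keras/KerasNLP: the backend is read at import time
os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Let JAX preallocate all GPU memory (its default is 75%)
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"
import warnings
warnings.filterwarnings("ignore")
from keras_nlp.models import GemmaCausalLM
# Load the fine-tuned model and its preprocessor from the Hugging Face Hub
model = GemmaCausalLM.from_preset("hf://soufyane/gemma_data_science")
# Interactive question loop; interrupt the kernel to stop
while True:
    x = input("enter your question: ")
    print(model.generate(f"question: {x}", max_length=256))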
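For non-interactive use, the same "question: ..." prompt format can be applied to a batch of questions. The helpers below (ask, strip_prompt) are hypothetical illustrations. They assume that generate accepts a list of prompts and returns each completion with the prompt echoed at the start; KerasNLP's generate behaves this way for Gemma, but it is worth checking the raw output on your version. generate_fn is passed in so the helper can be exercised without loading the model.

```python
from typing import Callable, List


def strip_prompt(prompt: str, output: str) -> str:
    """Drop the echoed prompt, if present, and surrounding whitespace."""
    if output.startswith(prompt):
        output = output[len(prompt):]
    return output.strip()


def ask(
    generate_fn: Callable[..., List[str]],
    questions: List[str],
    max_length: int = 256,
) -> List[str]:
    """Format questions like the interactive loop, generate in one batch, return answers only."""
    prompts = [f"question: {q}" for q in questions]
    outputs = generate_fn(prompts, max_length=max_length)
    return [strip_prompt(p, o) for p, o in zip(prompts, outputs)]


def fake_generate(prompts: List[str], max_length: int) -> List[str]:
    """Offline stand-in for model.generate that echoes each prompt."""
    return [p + " answer: 42" for p in prompts]


print(ask(fake_generate, ["What is overfitting?"]))  # ['answer: 42']
```

With the model loaded as in the loop, ask(model.generate, ["What is a p-value?", "Explain cross-validation."]) should return one answer string per question. Batching the prompts lets the backend run them in a single generate call instead of one call per question.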