zainimam committed on
Commit
3cdb7c9
·
verified ·
1 Parent(s): 222fb60

Updated main.py

Browse files
Files changed (1) hide show
  1. main.py +39 -14
main.py CHANGED
@@ -1,22 +1,47 @@
 
 
1
  from transformers import AutoModelForCausalLM, AutoProcessor
2
  from PIL import Image
3
  import requests
 
4
 
5
# One-shot script: load Molmo-7B-D, download a sample image, and print the
# model's description of it.

# GenerationConfig is needed by the generation call below and is not among the
# names imported at the top of the file.
from transformers import GenerationConfig

# Load the processor and model.
# NOTE: trust_remote_code=True executes Python code shipped in the model repo.
processor = AutoProcessor.from_pretrained('allenai/Molmo-7B-D-0924', trust_remote_code=True, device_map='auto')
model = AutoModelForCausalLM.from_pretrained('allenai/Molmo-7B-D-0924', trust_remote_code=True, device_map='auto')

# Download an image. The timeout keeps the script from hanging forever on a
# stalled connection, and raise_for_status() stops an HTTP error page from
# being handed to PIL as if it were image data.
image_url = "https://picsum.photos/id/237/536/354"
response = requests.get(image_url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

# Process the image with some input text. The Molmo model card documents
# `processor.process(...)` as the entry point (not calling the processor).
inputs = processor.process(images=[image], text="Describe this image.")

# Move every tensor to the model's device and add a leading batch dimension
# of size 1, as the model card usage does before generate_from_batch.
inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}

# Generate text based on the input, using the call shape from the model card.
output = model.generate_from_batch(
    inputs,
    GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
    tokenizer=processor.tokenizer,
)

# Decode and print only the newly generated tokens: the output sequence
# starts with the prompt tokens, so skip the first input_ids.size(1) of them.
generated_tokens = output[0, inputs['input_ids'].size(1):]
generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
print(generated_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
  from transformers import AutoModelForCausalLM, AutoProcessor
4
  from PIL import Image
5
  import requests
6
+ import torch
7
 
8
# Hugging Face Hub checkpoint used for both the processor and the model.
MODEL_ID = 'allenai/Molmo-7B-D-0924'

# Application object that the route decorators below register on.
app = FastAPI()

# Load the processor and model once, at import time, so every request reuses
# the same objects.
# NOTE: trust_remote_code=True executes Python code shipped in the model repo.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)

# Pick CUDA when torch reports it as available, otherwise stay on the CPU,
# then move the model's weights to that device.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
18
 
19
+ # Request body structure
20
class GenerateRequest(BaseModel):
    """Request body for the text-generation endpoint."""

    # Address the image is downloaded from.
    image_url: str

    # Prompt text handed to the processor together with the image.
    text_input: str
23
 
24
+ # API root endpoint
25
@app.get("/")
def root():
    """Return a fixed status message for the API root path."""
    status = {"message": "Molmo-7B-D API is up and running!"}
    return status
28
+
29
+ # Text generation endpoint
30
@app.post("/generate/")
def generate_text(request: GenerateRequest):
    """Download the image at ``request.image_url`` and generate text about it.

    The image and ``request.text_input`` are preprocessed together and passed
    to the model as one batch, so the image actually conditions the output.
    Only the newly generated tokens are decoded; the prompt is not echoed.

    Args:
        request: Body carrying the image URL and the prompt text.

    Returns:
        A dict with a single key, ``"generated_text"``.

    Raises:
        HTTPException: 400 when the URL is not http(s), the download fails,
            or the payload cannot be opened as an image; 500 when
            preprocessing or generation fails.
    """
    # Function-scope imports: these names are not imported at module level.
    from io import BytesIO

    from transformers import GenerationConfig

    # SECURITY NOTE: image_url is caller-supplied, so this endpoint makes the
    # server issue requests to arbitrary addresses (SSRF). Only the scheme is
    # checked here; deployments exposed to untrusted callers should also
    # restrict the allowed hosts.
    if not request.image_url.lower().startswith(("http://", "https://")):
        raise HTTPException(status_code=400, detail="image_url must be an http(s) URL")

    # Fetch the image. The timeout keeps a stalled remote host from tying up
    # the worker, and raise_for_status() rejects HTTP error pages early.
    try:
        response = requests.get(request.image_url, timeout=10)
        response.raise_for_status()
    except requests.RequestException as exc:
        raise HTTPException(status_code=400, detail=f"Could not fetch image: {exc}") from exc

    # Decode from the fully downloaded, content-decoded bytes rather than the
    # raw socket stream. RGB conversion follows the Molmo model card's advice
    # for images that are not already RGB.
    try:
        image = Image.open(BytesIO(response.content)).convert("RGB")
    except (OSError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=f"Could not read image: {exc}") from exc

    try:
        # Call shapes below follow the allenai/Molmo-7B-D-0924 model card:
        # process() builds the inputs, each tensor gets a leading batch
        # dimension of size 1, and generate_from_batch() takes the full batch.
        inputs = processor.process(images=[image], text=request.text_input)
        batch = {name: tensor.to(device).unsqueeze(0) for name, tensor in inputs.items()}

        output_ids = model.generate_from_batch(
            batch,
            GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
            tokenizer=processor.tokenizer,
        )

        # The output sequence begins with the prompt tokens; keep only what
        # was generated after them.
        prompt_length = batch["input_ids"].size(1)
        generated_text = processor.tokenizer.decode(
            output_ids[0, prompt_length:], skip_special_tokens=True
        )
    except Exception as e:
        # Boundary handler: surface inference failures as a 500 with the
        # error text, as the endpoint did before.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"generated_text": generated_text}