AI-ANK committed
Commit 9149fd8 · 1 Parent(s): f21d103

Update app.py

Files changed (1): app.py +7 -15
app.py CHANGED

```diff
@@ -10,6 +10,8 @@ from llama_index import ServiceContext, VectorStoreIndex, Document
 from llama_index.memory import ChatMemoryBuffer
 import os
 import datetime
+from PIL import Image
+import io
 
 # Set up the title of the application
 st.title("Image Captioning and Chat")
@@ -28,13 +30,14 @@ model, processor = get_vision_model()
 # Function to get image caption via Kosmos2.
 @st.cache_data
 def get_image_caption(image_data):
+    # Convert BytesIO to PIL Image
+    image = Image.open(io.BytesIO(image_data))
 
     model, processor = get_vision_model()
-    #model = AutoModelForVision2Seq.from_pretrained("ydshieh/kosmos-2-patch14-224", trust_remote_code=True)
-    #processor = AutoProcessor.from_pretrained("ydshieh/kosmos-2-patch14-224", trust_remote_code=True)
 
     prompt = "<grounding>An image of"
-    inputs = processor(text=prompt, images=image_data, return_tensors="pt")
+    # Pass the PIL image to the processor
+    inputs = processor(text=prompt, images=image, return_tensors="pt")
 
     generated_ids = model.generate(
         pixel_values=inputs["pixel_values"],
@@ -48,18 +51,7 @@ def get_image_caption(image_data):
     generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
     text_description, entities = processor.post_process_generation(generated_text)
-
-    #Using replicate API
-    # input_data = {
-    #     "image": image_data,
-    #     "description_type": "Brief"
-    # }
-    # output = replicate.run(
-    #     "lucataco/kosmos-2:3e7b211c29c092f4bcc8853922cc986baa52efe255876b80cac2c2fbb4aff805",
-    #     input=input_data
-    # )
-    # # Split the output string on the newline character and take the first item
-    # text_description = output.split('\n\n')[0]
+
     return text_description
 
 # Function to create the chat engine.
```
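After this change, get_image_caption takes the raw bytes of an uploaded image, wraps them in io.BytesIO, and opens them with PIL before handing them to the Kosmos-2 processor. Below is a minimal sketch of how the updated function might be called from the Streamlit side of app.py; the uploader widget and variable names are assumptions for illustration and are not part of this commit.

```python
# Usage sketch (assumption, not part of this commit): driving the updated
# get_image_caption from a Streamlit file uploader.
import streamlit as st

uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
if uploaded_file is not None:
    # getvalue() returns the raw image bytes; get_image_caption now wraps
    # them in io.BytesIO and opens them with PIL before calling the processor.
    image_bytes = uploaded_file.getvalue()
    st.image(image_bytes)
    caption = get_image_caption(image_bytes)
    st.write("Caption:", caption)
```

Keeping the conversion inside the function means the @st.cache_data-decorated call receives plain bytes, which Streamlit can hash for caching, while the PIL image the processor needs is created internally.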