class ImageCaptionLoader(BaseLoader):
    """Load image captions.

    By default, the loader utilizes the pre-trained Salesforce BLIP image
    captioning model.
    https://huggingface.co/Salesforce/blip-image-captioning-base
    """

    # Seconds to wait on the remote host when an image is given as a URL.
    # `requests` applies no timeout by default, so an unresponsive server
    # would otherwise block `load()` forever.
    _REQUEST_TIMEOUT = 30

    def __init__(
        self,
        images: Union[str, Path, bytes, List[Union[str, bytes, Path]]],
        blip_processor: str = "Salesforce/blip-image-captioning-base",
        blip_model: str = "Salesforce/blip-image-captioning-base",
    ):
        """Initialize with a list of image data (bytes) or file paths.

        Args:
            images: Either a single image or a list of images. Accepts image
                data (bytes) or file paths to images. Strings starting with
                ``http://`` or ``https://`` are fetched as URLs at load time.
            blip_processor: The name of the pre-trained BLIP processor.
            blip_model: The name of the pre-trained BLIP model.
        """
        # Normalize a single image to a one-element list so `load` can always
        # iterate over `self.images`.
        if isinstance(images, (str, Path, bytes)):
            self.images = [images]
        else:
            self.images = images

        self.blip_processor = blip_processor
        self.blip_model = blip_model

    def load(self) -> List[Document]:
        """Load from a list of image data or file paths.

        Returns:
            One ``Document`` per configured image, in input order, with the
            generated caption as ``page_content``.

        Raises:
            ImportError: If the ``transformers`` package is not installed.
            ValueError: If any image cannot be read or decoded.
        """
        try:
            from transformers import BlipForConditionalGeneration, BlipProcessor
        except ImportError as e:
            raise ImportError(
                "`transformers` package not found, please install with "
                "`pip install transformers`."
            ) from e

        # Instantiate the processor and model once and reuse them for every
        # image.
        processor = BlipProcessor.from_pretrained(self.blip_processor)
        model = BlipForConditionalGeneration.from_pretrained(self.blip_model)

        results = []
        for image in self.images:
            caption, metadata = self._get_captions_and_metadata(
                model=model, processor=processor, image=image
            )
            results.append(Document(page_content=caption, metadata=metadata))

        return results

    def _get_captions_and_metadata(
        self, model: Any, processor: Any, image: Union[str, Path, bytes]
    ) -> Tuple[str, dict]:
        """Helper function for getting the captions and metadata of an image.

        Args:
            model: Object whose ``generate(**inputs)`` produces caption token
                sequences (the BLIP model created in ``load``).
            processor: Callable that builds the model inputs and whose
                ``decode`` turns a token sequence into text.
            image: Raw image bytes, an ``http(s)://`` URL, or a file path.

        Returns:
            A ``(caption, metadata)`` tuple. ``metadata`` records the image
            path/URL, or a fixed placeholder when raw bytes were supplied.

        Raises:
            ImportError: If ``PIL`` (pillow) is not installed.
            ValueError: If the image cannot be fetched, opened, or decoded.
                The underlying error is attached as ``__cause__``.
        """
        try:
            from PIL import Image
        except ImportError as e:
            raise ImportError(
                "`PIL` package not found, please install with `pip install pillow`"
            ) from e

        is_raw_bytes = isinstance(image, bytes)

        try:
            if is_raw_bytes:
                source: Any = BytesIO(image)  # type: ignore[arg-type]
            elif isinstance(image, str) and image.startswith(
                ("http://", "https://")
            ):
                # Read the body inside a context manager so the connection is
                # always released, and fail early on HTTP error statuses
                # instead of handing an error page to PIL.
                with requests.get(image, timeout=self._REQUEST_TIMEOUT) as response:
                    response.raise_for_status()
                    source = BytesIO(response.content)
            else:
                source = image

            # `convert` returns a new in-memory image, so the handle opened
            # here can be closed as soon as the conversion is done.
            with Image.open(source) as opened:
                pil_image = opened.convert("RGB")
        except Exception as e:
            if is_raw_bytes:
                msg = "Could not get image data from bytes"
            else:
                msg = f"Could not get image data for {image}"
            # Chain the original error so the root cause (missing file, HTTP
            # failure, corrupt data, ...) stays visible in the traceback.
            raise ValueError(msg) from e

        # "an image of" is the text prompt the caption generation starts from.
        inputs = processor(pil_image, "an image of", return_tensors="pt")
        output = model.generate(**inputs)

        caption: str = processor.decode(output[0])

        if is_raw_bytes:
            metadata: dict = {"image_source": "Image bytes provided"}
        else:
            metadata = {"image_path": str(image)}

        return caption, metadata