Joint Image & Text Embeddings

Published

January 1, 2022


Installation

First, we need to install sentence-transformers.

!pip install sentence-transformers
Successfully installed huggingface-hub-0.1.2 pyyaml-6.0 sacremoses-0.0.46 sentence-transformers-2.1.0 sentencepiece-0.1.96 tokenizers-0.10.3 transformers-4.12.3

Load CLIP Model

Next, we load the CLIP model using SentenceTransformer. The model is downloaded automatically on first use.

import sentence_transformers
from sentence_transformers import SentenceTransformer, util
from PIL import Image
import glob
import torch
import pickle
import zipfile
from IPython.display import display
from IPython.display import Image as IPImage
import os
from tqdm.autonotebook import tqdm
torch.set_num_threads(4)

# First, we load the CLIP model
model = SentenceTransformer('clip-ViT-B-32')
ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy.
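
CLIP embeds images and text into the same vector space, so a caption and a photo can be compared directly. The snippet below is a minimal sketch of this idea; 'example.jpg' is a hypothetical placeholder for any local image file you may have at hand.

# Minimal sketch: compare one caption and one image in the joint embedding space.
# 'example.jpg' is a hypothetical placeholder for any local image file.
text_emb = model.encode("a photo of a cat", convert_to_tensor=True)
single_img_emb = model.encode(Image.open("example.jpg"), convert_to_tensor=True)
print(util.cos_sim(text_emb, single_img_emb))  # higher score = better text/image match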
# Next, we get about 25k images from Unsplash 
img_folder = 'photos/'
if not os.path.exists(img_folder) or len(os.listdir(img_folder)) == 0:
    os.makedirs(img_folder, exist_ok=True)
    
    photo_filename = 'unsplash-25k-photos.zip'
    if not os.path.exists(photo_filename):   # Download the dataset if it does not exist yet
        util.http_get('http://sbert.net/datasets/'+photo_filename, photo_filename)
        
    #Extract all images
    with zipfile.ZipFile(photo_filename, 'r') as zf:
        for member in tqdm(zf.infolist(), desc='Extracting'):
            zf.extract(member, img_folder)
        
# Now, we need to compute the embeddings.
# To speed things up, we distribute pre-computed embeddings.
# Otherwise, you can also encode the images yourself.
# To encode an image, you can use the following code:
# from PIL import Image
# img_emb = model.encode(Image.open(filepath))

use_precomputed_embeddings = True

if use_precomputed_embeddings: 
    emb_filename = 'unsplash-25k-photos-embeddings.pkl'
    if not os.path.exists(emb_filename):   # Download the embeddings if they do not exist yet
        util.http_get('http://sbert.net/datasets/'+emb_filename, emb_filename)
        
    with open(emb_filename, 'rb') as fIn:
        img_names, img_emb = pickle.load(fIn)  
    print("Images:", len(img_names))
else:
    img_names = [os.path.basename(p) for p in glob.glob(os.path.join(img_folder, '*.jpg'))]
    print("Images:", len(img_names))
    img_emb = model.encode([Image.open(os.path.join(img_folder, name)) for name in img_names], batch_size=128, convert_to_tensor=True, show_progress_bar=True)
Images: 24996
# Next, we define a search function.
def search(query, k=3):
    # First, we encode the query (which can either be an image or a text string)
    query_emb = model.encode([query], convert_to_tensor=True, show_progress_bar=False)
    
    # Then, we use the util.semantic_search function, which computes the cosine-similarity
    # between the query embedding and all image embeddings.
    # It then returns the top_k highest ranked images, which we output
    hits = util.semantic_search(query_emb, img_emb, top_k=k)[0]
    
    print("Query:")
    display(query)
    for hit in hits:
        print(img_names[hit['corpus_id']])
        display(IPImage(os.path.join(img_folder, img_names[hit['corpus_id']]), width=200))
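
Under the hood, util.semantic_search scores the query against every image embedding with cosine similarity and keeps the top-k hits. A rough hand-written equivalent looks like this (the query string below is just an illustrative example):

# Rough manual equivalent of util.semantic_search for a single query
query_emb = model.encode(["Two dogs playing in the snow"], convert_to_tensor=True)
cos_scores = util.cos_sim(query_emb, img_emb)[0]     # one similarity score per image
top_scores, top_idx = torch.topk(cos_scores, k=3)    # keep the 3 best matches
for score, idx in zip(top_scores, top_idx):
    print(img_names[int(idx)], float(score))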
search("Two cats playing on the street")
Query:
4mA9_5vbZ_s.jpg
w6tMRf7kGLA.jpg
n4pNuXxyIr4.jpg
'Two cats playing on the street'

search("A sunset on the montain")
Query:
Zf4jpcGEinM.jpg
G5JDRSKi3uY.jpg
ig9yVlj5YYg.jpg
'A sunset on the montain'

search("Oslo")
Query:
uHsQou9tWTQ.jpg
0ABCZ9bTsw4.jpg
d5_hjWQ4NwA.jpg
'Oslo'

search("A dog in a park")
Query:
IVyZrLp41D0.jpg
0O9A0F_d1qA.jpg
KVeogBZzl4M.jpg
'A dog in a park'

search("A beach with palm trees")
Query:
7rrgPPljqYU.jpg
kmihWgpbDEg.jpg
ZyfOq52b0cs.jpg
'A beach with palm trees'
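
Since the query is simply passed through model.encode, the same search function also accepts an image instead of a text string, which turns it into an image-similarity search. A quick sketch, using one of the downloaded photos as the query:

# Image-as-query: find photos similar to an existing photo.
# img_names[0] is just an arbitrary pick from the downloaded set.
search(Image.open(os.path.join(img_folder, img_names[0])))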