root/wavelet.py

Revision 57, 4.0 kB (checked in by steve, 7 months ago)

fixed delete behaviour so the trash dir is made in the image base folder, not the working folder
refactored code into a class
cleaned up signature comparisong code
sped up signature comparison code

Line 
1 #!/usr/bin/env python
2 # encoding: utf-8
3 """
4 wavelet.py
5
6 Created by Shu Ning Bian on 2007-09-22.
7 Copyright (c) 2007 . All rights reserved.
8 Licensed for distribution under the GPL version 2, check COPYING for details.
9
10 Signature returned by this class is a dictionary with key of (band, location) value of
11 either 1 or -1. Location is z = y*columns + x.
12 """
13
14 import sys
15 import os
16 import Image
17
18 import config
19
20 #import pyx
21
22 def open(path):
23         """opens a file as a WaveletImage"""
24         return WaveletImage(path)
25        
26 class WaveletImage(object):
27         """A lazy wavelet transform class"""
28         def __init__(self, path):
29                 super(WaveletImage, self).__init__()
30                
31                 """
32                 Load the image from given path, using PIL, converts it into YIQ colour space
33                 """
34                 rgb2yiq = (
35                          0.299,  0.587,  0.114, 0,
36                          0.595716,      -0.274453,      -0.321263,      0,
37                          0.211456,      -0.522591,       0.311135,      0
38                         )
39                        
40                 rgb2yuv = (
41                          0.299,  0.587,  0.114,  0,
42                         -0.14713,        -0.28886,       0.436,  0,
43                          0.615, -0.51498,       -0.10001,        0,
44                         )
45                        
46                 im = Image.open(path)
47                 self.size = im.size
48                 im.thumbnail(config.img_size,Image.ANTIALIAS)
49                 im = im.convert("RGB")
50                
51                 self.data = im.convert("RGB", rgb2yiq).getdata()
52                 #self.data = im.convert("RGB", rgb2yuv).getdata()
53                 self.im = im
54                 self.wavelets = None
55                 self.signature = None
56        
57         def cleanup(self):
58                 """
59                 Frees wavelets data and reverts image to rgb. This method should only be called
60                 *after* calling .signature() at least once
61                 """
62                 self.wavelets = None
63                 self.data = None
64                
65         def pix_sum(self, x,y):
66                 """Returns a tuple which is the sum of the tuples given"""
67                 return (x[0]+y[0], x[1]+y[1], x[2]+y[2])
68        
69         def pix_diff(self, x,y):
70                 """Returns a tuple which is the difference of the tuples given"""
71                 return (x[0]-y[0], x[1]-y[1], x[2]-y[2])
72        
73         def get_signature(self):
74                 """Returns a signature tuple based on the input which is expected to be the
75                 wavelet transform of an image
76                 """             
77                 if not self.wavelets:
78                         self.transform()
79                 if self.signature:
80                         return self.signature
81
82                 self.signature = {}
83                 input = self.wavelets   
84                 length=len(input)/8
85        
86                 tmp=[0]*length
87                 for band in xrange(3):
88                         # copy the values in the current band into a tmp buffer
89                         for i in xrange(length):
90                                 tmp[i] = (input[i][band], i)
91                        
92                         sig = []
93                        
94                         # sorting by x[0]^2 because we want both negative and positive significant
95                         # coefficients
96                         tmp.sort(key = lambda x: x[0]*x[0])
97                        
98                         # take config.taps number of significan coefficients if we have enough
99                         if length > config.taps:
100                                 sig = tmp[:config.taps]
101                         else:
102                                 # otherwise take the entire tmp
103                                 sig = tmp[:]
104                        
105                         # normalise the signatures, research shows this is better       
106                         for s in sig:
107                                 v = s[0]
108                                 l = s[1]
109                                
110                                 # clamp to [-1,1]
111                                 v = max(-1, min(1, v))
112                                 k = (band, l)
113                                 self.signature[k] = v
114
115                 return self.signature
116        
117         def transform_array(self, input):
118                 """
119                 Performs wavelet transform on the input, destorys input,
120                 see http://en.wikipedia.org/wiki/Discrete_wavelet_transform
121                 """
122                 length = len(input)
123                 output = [0]*length
124                
125                 while(True):
126                         length/=2
127                        
128                         for i in xrange(length):
129                                 x = input[i*2]
130                                 y = input[i*2+1]
131                                 output[i] = (x[0]+y[0], x[1]+y[1], x[2]+y[2])
132                                 output[length+i] = (x[0]-y[0], x[1]-y[1], x[2]-y[2])
133
134                         if length <= 1:
135                                 return output
136                        
137                         input = output[:length*2]
138
139                 raise Exception
140                
141         def transform(self):
142                 """
143                 Performs a wavelet transform on the given image
144                 """
145                 assert self.data != None, "Did you call .cleanup() before .signature()?"
146                 input = list(self.data)
147                
148                 self.wavelets = self.transform_array(input)
149                 return
150                
151         def compare(self, other):
152                 """Compars this wavelet transform to another, returning a tuple of similarness"""
153                 sig1 = other.get_signature()
154                 sig2 = self.get_signature()
155
156                 return signature_compare(sig1, sig2)
157
158                        
159 def     signature_compare(sig1, sig2):
160         score = 0
161         for key,value in sig1.items():
162                 if sig2.has_key(key):
163                         if sig2[key] == value:
164                                 score += config.weights[key[0]]
165                         else:
166                                 score += config.weights[key[0]] * 0.5
167
168                
169         return score
170
171 def main():
172         pass
173
174
175 if __name__ == '__main__':
176         main()
177
Note: See TracBrowser for help on using the browser.