root/trunk/extra/tsp/src/tsp.c

Revision 338, 12.0 KB (checked in by anton, 14 months ago)

See #160

Line 
1/*
2 * Traveling Salesman Problem solution algorithm for PostgreSQL
3 *
4 * Copyright (c) 2006 Anton A. Patrushev, Orkney, Inc.
5 *
6 * This program is free software; you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation; either version 2 of the License, or
9 * (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with this program; if not, write to the Free Software
18 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
19 *
20 */
21
22#include "tsp.h"
23
24#include "postgres.h"
25#include "executor/spi.h"
26#include "funcapi.h"
27#include "catalog/pg_type.h"
28
29#include "string.h"
30#include "math.h"
31
32#include "fmgr.h"
33
/*
 * Emit the module magic block only when fmgr.h defines the macro; the
 * guard keeps the file building against headers that lack it.
 */
#if defined(PG_MODULE_MAGIC)
PG_MODULE_MAGIC;
#endif
37
38
39// ------------------------------------------------------------------------
40
/*
 * Define PROFILE to have elapsed-time profiling NOTICEs emitted.
 */
//#define PROFILE

#ifdef PROFILE
#include <sys/time.h>

struct timeval prof_tsp, prof_store, prof_extract, prof_total;
long proftime[5];
long profipts1, profipts2, profopts;

/*
 * Record the current time in the struct timeval lvalue x.
 * No trailing ';' after while (0): the call site supplies it, so the macro
 * behaves like a single statement (e.g. inside an unbraced if/else).
 */
#define profstart(x) do { gettimeofday(&(x), NULL); } while (0)

/*
 * Emit a NOTICE labelled n with the time elapsed since profstart(x).
 * The second and microsecond differences are taken before scaling so the
 * intermediate value cannot overflow a 32-bit long.
 */
#define profstop(n, x) do { struct timeval _profstop;                  \
        long _proftime;                                               \
        gettimeofday(&_profstop, NULL);                               \
        _proftime = (long) (_profstop.tv_sec - (x).tv_sec) * 1000000L \
                  + (long) (_profstop.tv_usec - (x).tv_usec);          \
        elog(NOTICE,                                                  \
             "PRF(%s) %ld (%f ms)",                                   \
             (n),                                                     \
             _proftime, _proftime / 1000.0);                          \
        } while (0)

#else

/* Profiling disabled: both macros are single no-op statements. */
#define profstart(x) do { } while (0)
#define profstop(n, x) do { } while (0)

#endif // PROFILE
70
71// ------------------------------------------------------------------------
72
73Datum tsp(PG_FUNCTION_ARGS);
74
#undef DEBUG
//#define DEBUG 1

/*
 * DBG(format, ...) emits elog(NOTICE, format, ...) when DEBUG is defined
 * and expands to a no-op statement otherwise (its arguments are then not
 * evaluated).  Written with standard C99 __VA_ARGS__ rather than the
 * GNU-only "arg..." / ", ## arg" form; a bare format string with no
 * further arguments still works because the format is part of __VA_ARGS__.
 */
#ifdef DEBUG
#define DBG(...) elog(NOTICE, __VA_ARGS__)
#else
#define DBG(...) do { ; } while (0)
#endif

// The number of tuples to fetch from the SPI cursor at each iteration
#define TUPLIMIT 1000
87
88// Apologies for using fixed-length arrays.  But this is an example, not
89// production code ;)
90//#define MAX_TOWNS 40
91
92float DISTANCE[MAX_TOWNS][MAX_TOWNS];
93float x[MAX_TOWNS],y[MAX_TOWNS];
94int total_tuples;
95
96
97static char *
98text2char(text *in)
99{
100  char *out = (char*)palloc(VARSIZE(in));
101
102  memcpy(out, VARDATA(in), VARSIZE(in) - VARHDRSZ);
103  out[VARSIZE(in) - VARHDRSZ] = '\0';
104  return out;
105}
106
107static int
108finish(int code, int ret)
109{
110  code = SPI_finish();
111  if (code  != SPI_OK_FINISH )
112  {
113    elog(ERROR,"couldn't disconnect from SPI");
114    return -1 ;
115  }
116  return ret;
117}
118                 
119
120typedef struct point_columns
121{
122  int id;
123  float8 x;
124  float8 y;
125} point_columns_t;
126
127
128static int
129fetch_point_columns(SPITupleTable *tuptable, point_columns_t *point_columns)
130{
131  DBG("Fetching point");
132
133  point_columns->id = SPI_fnumber(SPI_tuptable->tupdesc, "source_id");
134  point_columns->x = SPI_fnumber(SPI_tuptable->tupdesc, "x");
135  point_columns->y = SPI_fnumber(SPI_tuptable->tupdesc, "y");
136  if (point_columns->id == SPI_ERROR_NOATTRIBUTE ||
137      point_columns->x == SPI_ERROR_NOATTRIBUTE ||
138      point_columns->y == SPI_ERROR_NOATTRIBUTE)
139    {
140      elog(ERROR, "Error, query must return columns "
141           "'source_id', 'x' and 'y'");
142      return -1;
143    }
144   
145  DBG("* Point %i [%f, %f]", point_columns->id, point_columns->x,
146      point_columns->y);
147
148  return 0;
149}
150
151static void
152fetch_point(HeapTuple *tuple, TupleDesc *tupdesc,
153            point_columns_t *point_columns, point_t *point)
154{
155  Datum binval;
156  bool isnull;
157
158  DBG("inside fetch_point\n");
159
160  binval = SPI_getbinval(*tuple, *tupdesc, point_columns->id, &isnull);
161  DBG("Got id\n");
162
163  if (isnull)
164    elog(ERROR, "id contains a null value");
165
166  point->id = DatumGetInt32(binval);
167
168  DBG("id = %i\n", point->id);
169
170  binval = SPI_getbinval(*tuple, *tupdesc, point_columns->x, &isnull);
171  if (isnull)
172    elog(ERROR, "x contains a null value");
173
174  point->x = DatumGetFloat8(binval);
175
176  DBG("x = %f\n", point->x);
177
178  binval = SPI_getbinval(*tuple, *tupdesc, point_columns->y, &isnull);
179
180  if (isnull)
181    elog(ERROR, "y contains a null value");
182
183  point->y = DatumGetFloat8(binval);
184
185  DBG("y = %f\n", point->y);
186}
187
188
189static int solve_tsp(char* sql, char* p_ids,
190                     int source, path_element_t* path)
191{
192  int SPIcode;
193  void *SPIplan;
194  Portal SPIportal;
195  bool moredata = TRUE;
196  int ntuples;
197
198  //todo replace path (vector of path_element_t) with this array
199  int ids[MAX_TOWNS];
200
201  point_t *points=NULL;
202  point_columns_t point_columns = {id: -1, x: -1, y:-1};
203
204  char *err_msg = NULL;
205  int ret = -1;
206   
207  char *p;
208  int   z = 0;
209 
210  int    tt, cc;
211  double dx, dy;
212  float  fit=0.0;
213
214  DBG("inside tsp\n");
215
216  //int total_tuples = 0;
217  total_tuples = 0;
218
219  p = strtok(p_ids, ",");
220  while(p != NULL)
221    {
222      //      ((path_element_t*)path)[z].vertex_id = atoi(p);
223      ids[z]=atoi(p);
224      p = strtok(NULL, ",");
225      z++;
226      if(z >= MAX_TOWNS)
227      {
228        elog(ERROR, "Number of points exeeds max number.");
229        break;
230      }
231    }
232   
233  DBG("ZZZ %i\n",z);
234  DBG("start tsp\n");
235       
236  SPIcode = SPI_connect();
237
238  if (SPIcode  != SPI_OK_CONNECT)
239    {
240      elog(ERROR, "tsp: couldn't open a connection to SPI");
241      return -1;
242    }
243
244  SPIplan = SPI_prepare(sql, 0, NULL);
245
246  if (SPIplan  == NULL)
247    {
248      elog(ERROR, "tsp: couldn't create query plan via SPI");
249      return -1;
250    }
251
252  if ((SPIportal = SPI_cursor_open(NULL, SPIplan, NULL, NULL, true)) == NULL)
253    {
254      elog(ERROR, "tsp: SPI_cursor_open('%s') returns NULL", sql);
255      return -1;
256    }
257   
258  DBG("Query: %s\n",sql);
259  DBG("Query executed\n");
260
261  while (moredata == TRUE)
262    {
263      SPI_cursor_fetch(SPIportal, TRUE, TUPLIMIT);
264
265      if (point_columns.id == -1)
266        {
267          if (fetch_point_columns(SPI_tuptable, &point_columns) == -1)
268            return finish(SPIcode, ret);
269        }
270
271      ntuples = SPI_processed;
272
273      total_tuples += ntuples;
274
275      DBG("Tuples: %i\n", total_tuples);
276
277      if (!points)
278        points = palloc(total_tuples * sizeof(point_t));
279      else
280        points = repalloc(points, total_tuples * sizeof(point_t));
281                                       
282      if (points == NULL)
283        {
284          elog(ERROR, "Out of memory");
285          return finish(SPIcode, ret);
286        }
287
288      if (ntuples > 0)
289        {
290          int t;
291          SPITupleTable *tuptable = SPI_tuptable;
292          TupleDesc tupdesc = SPI_tuptable->tupdesc;
293
294          DBG("Got tuple desc\n");
295               
296          for (t = 0; t < ntuples; t++)
297            {
298              HeapTuple tuple = tuptable->vals[t];
299              DBG("Before point fetched\n");
300              fetch_point(&tuple, &tupdesc, &point_columns,
301                          &points[total_tuples - ntuples + t]);
302              DBG("Point fetched\n");
303            }
304
305          SPI_freetuptable(tuptable);
306        }
307      else
308        {
309          moredata = FALSE;
310        }                                       
311    }
312
313 
314  DBG("Calling TSP\n");
315       
316  profstop("extract", prof_extract);
317  profstart(prof_tsp);
318
319  DBG("Total tuples: %i\n", total_tuples);
320
321  for(tt=0;tt<total_tuples;++tt)
322    {
323      //((path_element_t*)path)[tt].vertex_id = points[tt].id;
324      ids[tt] = points[tt].id;
325      x[tt] = points[tt].x;
326      y[tt] = points[tt].y;
327 
328      DBG("Point at %i: %i [%f, %f]\n",  tt, ids[tt], x[tt], y[tt]);
329           
330      // ((path_element_t*)path)[tt].vertex_id, x[tt], y[tt]);
331
332      for(cc=0;cc<total_tuples;++cc)
333        {
334          dx=x[tt]-x[cc]; dy=y[tt]-y[cc];
335          DISTANCE[tt][cc] = DISTANCE[cc][tt] = sqrt(dx*dx+dy*dy);
336        }
337    }
338
339  DBG("DISTANCE counted\n");
340  pfree(points);
341   
342  //ret = find_tsp_solution(total_tuples, DISTANCE,
343  //   path, source, &fit, &err_msg);
344
345  ret = find_tsp_solution(total_tuples, DISTANCE, ids,
346                          source, &fit, err_msg);
347
348  for(tt=0;tt<total_tuples;++tt)
349    {
350      ((path_element_t*)path)[tt].vertex_id = ids[tt];
351    }
352   
353  DBG("TSP solved!\n");
354  DBG("Score: %f\n", fit);
355
356  profstop("tsp", prof_tsp);
357  profstart(prof_store);
358
359  DBG("Profile changed and ret is %i", ret);
360
361  if (ret < 0)
362    {
363      //elog(ERROR, "Error computing path: %s", err_msg);
364      ereport(ERROR, (errcode(ERRCODE_E_R_E_CONTAINING_SQL_NOT_PERMITTED), errmsg("Error computing path: %s", err_msg)));
365    }
366
367  return finish(SPIcode, ret);   
368}
369
370PG_FUNCTION_INFO_V1(tsp);
371Datum
372tsp(PG_FUNCTION_ARGS)
373{
374  FuncCallContext     *funcctx;
375  int                  call_cntr;
376  int                  max_calls;
377  TupleDesc            tuple_desc;
378  path_element_t        *path;
379   
380  /* stuff done only on the first call of the function */
381  if (SRF_IS_FIRSTCALL())
382    {
383      MemoryContext   oldcontext;
384      //int path_count;
385      int ret=-1;
386
387      // XXX profiling messages are not thread safe
388      profstart(prof_total);
389      profstart(prof_extract);
390
391      /* create a function context for cross-call persistence */
392      funcctx = SRF_FIRSTCALL_INIT();
393
394      /* switch to memory context appropriate for multiple function calls */
395      oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
396       
397
398      path = (path_element_t *)palloc(sizeof(path_element_t)*MAX_TOWNS);
399
400
401      ret = solve_tsp(text2char(PG_GETARG_TEXT_P(0)),
402                      text2char(PG_GETARG_TEXT_P(1)),
403                      PG_GETARG_INT32(2),
404                      path);
405
406      /* total number of tuples to be returned */
407      DBG("Counting tuples number\n");
408
409      funcctx->max_calls = total_tuples;
410
411      funcctx->user_fctx = path;
412
413      funcctx->tuple_desc = BlessTupleDesc(
414                              RelationNameGetTupleDesc("path_result"));
415      MemoryContextSwitchTo(oldcontext);
416    }
417
418  /* stuff done on every call of the function */
419  funcctx = SRF_PERCALL_SETUP();
420
421  call_cntr = funcctx->call_cntr;
422  max_calls = funcctx->max_calls;
423  tuple_desc = funcctx->tuple_desc;
424
425  path = (path_element_t *)funcctx->user_fctx;
426
427  DBG("Trying to allocate some memory\n");
428  DBG("call_cntr = %i, max_calls = %i\n", call_cntr, max_calls);
429
430  if (call_cntr < max_calls)    /* do when there is more left to send */
431    {
432      HeapTuple    tuple;
433      Datum        result;
434      Datum *values;
435      char* nulls;
436
437      /* This will work for some compilers. If it crashes with segfault, try to change the following block with this one   
438
439      values = palloc(4 * sizeof(Datum));
440      nulls = palloc(4 * sizeof(char));
441
442      values[0] = call_cntr;
443      nulls[0] = ' ';
444      values[1] = Int32GetDatum(path[call_cntr].vertex_id);
445      nulls[1] = ' ';
446      values[2] = Int32GetDatum(path[call_cntr].edge_id);
447      nulls[2] = ' ';
448      values[3] = Float8GetDatum(path[call_cntr].cost);
449      nulls[3] = ' ';
450      */
451   
452      values = palloc(3 * sizeof(Datum));
453      nulls = palloc(3 * sizeof(char));
454
455      values[0] = Int32GetDatum(path[call_cntr].vertex_id);
456      nulls[0] = ' ';
457      values[1] = Int32GetDatum(path[call_cntr].edge_id);
458      nulls[1] = ' ';
459      values[2] = Float8GetDatum(path[call_cntr].cost);
460      nulls[2] = ' ';
461     
462      DBG("Heap making\n");
463
464      tuple = heap_formtuple(tuple_desc, values, nulls);
465
466      DBG("Datum making\n");
467
468      /* make the tuple into a datum */
469      result = HeapTupleGetDatum(tuple);
470
471      DBG("VAL: %i\n, %i", values[0], result);
472      DBG("Trying to free some memory\n");
473   
474      /* clean up (this is not really necessary) */
475      pfree(values);
476      pfree(nulls);
477       
478
479      SRF_RETURN_NEXT(funcctx, result);
480    }
481  else    /* do when there is no more left */
482    {
483      DBG("Ending function\n");
484      profstop("store", prof_store);
485      profstop("total", prof_total);
486      DBG("Profiles stopped\n");
487
488      pfree(path);
489
490      DBG("Path cleared\n");
491
492      SRF_RETURN_DONE(funcctx);
493    }
494}
Note: See TracBrowser for help on using the browser.